diff options
Diffstat (limited to 'bip-0352/reference.py')
-rwxr-xr-x | bip-0352/reference.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/bip-0352/reference.py b/bip-0352/reference.py index 9f43695..b4eaf94 100755 --- a/bip-0352/reference.py +++ b/bip-0352/reference.py @@ -117,7 +117,7 @@ def decode_silent_payment_address(address: str, hrp: str = "tsp") -> Tuple[ECPub return B_scan, B_spend -def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], input_hash: bytes, recipients: List[str], hrp="tsp") -> List[str]: +def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[COutPoint], recipients: List[str], hrp="tsp") -> List[str]: G = ECKey().set(1).get_pubkey() negated_keys = [] for key, is_xonly in input_priv_keys: @@ -127,6 +127,10 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], input_hash: bytes, negated_keys.append(k) a_sum = sum(negated_keys) + if not a_sum.valid: + # Input privkeys sum is zero -> fail + return [] + input_hash = get_input_hash(outpoints, a_sum * G) silent_payment_groups: Dict[ECPubKey, List[ECPubKey]] = {} for recipient in recipients: B_scan, B_m = decode_silent_payment_address(recipient, hrp=hrp) @@ -236,9 +240,8 @@ if __name__ == "__main__": sending_outputs = [] if (len(input_pub_keys) > 0): - A_sum = reduce(lambda x, y: x + y, input_pub_keys) - input_hash = get_input_hash([vin.outpoint for vin in vins], A_sum) - sending_outputs = create_outputs(input_priv_keys, input_hash, given["recipients"], hrp="sp") + outpoints = [vin.outpoint for vin in vins] + sending_outputs = create_outputs(input_priv_keys, outpoints, given["recipients"], hrp="sp") # Note: order doesn't matter for creating/finding the outputs. However, different orderings of the recipient addresses # will produce different generated outputs if sending to multiple silent payment addresses belonging to the @@ -297,6 +300,10 @@ if __name__ == "__main__": add_to_wallet = [] if (len(input_pub_keys) > 0): A_sum = reduce(lambda x, y: x + y, input_pub_keys) + if A_sum.get_bytes() is None: + # Input pubkeys sum is point at infinity -> skip tx + assert expected["outputs"] == [] + continue input_hash = get_input_hash([vin.outpoint for vin in vins], A_sum) pre_computed_labels = { (generate_label(b_scan, label) * G).get_bytes(False).hex(): generate_label(b_scan, label).hex() |