diff options
Diffstat (limited to 'bip-0327/reference.py')
-rw-r--r-- | bip-0327/reference.py | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/bip-0327/reference.py b/bip-0327/reference.py index edf6e76..17831c5 100644 --- a/bip-0327/reference.py +++ b/bip-0327/reference.py @@ -317,7 +317,7 @@ SessionContext = NamedTuple('SessionContext', [('aggnonce', bytes), ('is_xonly', List[bool]), ('msg', bytes)]) -def key_agg_and_tweak(pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool]): +def key_agg_and_tweak(pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool]) -> KeyAggContext: if len(tweaks) != len(is_xonly): raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') keyagg_ctx = key_agg(pubkeys) @@ -440,8 +440,6 @@ def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session Re_s_ = point_add(R_s1, point_mul(R_s2, b)) Re_s = Re_s_ if has_even_y(R) else point_negate(Re_s_) P = cpoint(pk) - if P is None: - return False a = get_session_key_agg_coeff(session_ctx, P) g = 1 if has_even_y(Q) else n - 1 g_ = g * gacc % n @@ -523,7 +521,7 @@ def test_key_agg_vectors() -> None: assert get_xonly_pk(key_agg(pubkeys)) == expected - for i, test_case in enumerate(error_test_cases): + for test_case in error_test_cases: exception, except_fn = get_error_details(test_case) pubkeys = [X[i] for i in test_case["key_indices"]] @@ -572,7 +570,7 @@ def test_nonce_agg_vectors() -> None: expected = bytes.fromhex(test_case["expected"]) assert nonce_agg(pubnonces) == expected - for i, test_case in enumerate(error_test_cases): + for test_case in error_test_cases: exception, except_fn = get_error_details(test_case) pubnonces = [pnonce[i] for i in test_case["pnonce_indices"]] assert_raises(exception, lambda: nonce_agg(pubnonces), except_fn) @@ -598,7 +596,10 @@ def test_sign_verify_vectors() -> None: aggnonces = fromhex_all(test_data["aggnonces"]) # The aggregate of the first three elements of pnonce is at index 0 - assert(aggnonces[0] == nonce_agg([pnonce[0], pnonce[1], pnonce[2]])) + assert (aggnonces[0] == nonce_agg([pnonce[0], pnonce[1], pnonce[2]])) + # The aggregate of the first and fourth elements of pnonce is at index 1, + # which is the infinity point encoded as a zeroed 33-byte array + assert (aggnonces[1] == nonce_agg([pnonce[0], pnonce[3]])) msgs = fromhex_all(test_data["msgs"]) @@ -626,7 +627,7 @@ def test_sign_verify_vectors() -> None: assert sign(secnonce_tmp, sk, session_ctx) == expected assert partial_sig_verify(expected, pubnonces, pubkeys, [], [], msg, signer_index) - for i, test_case in enumerate(sign_error_test_cases): + for test_case in sign_error_test_cases: exception, except_fn = get_error_details(test_case) pubkeys = [X[i] for i in test_case["key_indices"]] @@ -646,7 +647,7 @@ def test_sign_verify_vectors() -> None: assert not partial_sig_verify(sig, pubnonces, pubkeys, [], [], msg, signer_index) - for i, test_case in enumerate(verify_error_test_cases): + for test_case in verify_error_test_cases: exception, except_fn = get_error_details(test_case) sig = bytes.fromhex(test_case["sig"]) @@ -702,7 +703,7 @@ def test_tweak_vectors() -> None: assert sign(secnonce_tmp, sk, session_ctx) == expected assert partial_sig_verify(expected, pubnonces, pubkeys, tweaks, is_xonly, msg, signer_index) - for i, test_case in enumerate(error_test_cases): + for test_case in error_test_cases: exception, except_fn = get_error_details(test_case) pubkeys = [X[i] for i in test_case["key_indices"]] @@ -747,7 +748,7 @@ def test_det_sign_vectors() -> None: session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) assert partial_sig_verify_internal(psig, pubnonce, pubkeys[signer_index], session_ctx) - for i, test_case in enumerate(error_test_cases): + for test_case in error_test_cases: exception, except_fn = get_error_details(test_case) pubkeys = [X[i] for i in test_case["key_indices"]] @@ -796,7 +797,7 @@ def test_sig_agg_vectors() -> None: aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) assert schnorr_verify(msg, aggpk, sig) - for i, test_case in enumerate(error_test_cases): + for test_case in error_test_cases: exception, except_fn = get_error_details(test_case) pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] |