diff options
Diffstat (limited to 'bip-0327/reference.py')
-rw-r--r-- | bip-0327/reference.py | 880 |
1 files changed, 880 insertions, 0 deletions
diff --git a/bip-0327/reference.py b/bip-0327/reference.py new file mode 100644 index 0000000..edf6e76 --- /dev/null +++ b/bip-0327/reference.py @@ -0,0 +1,880 @@ +# BIP327 reference implementation +# +# WARNING: This implementation is for demonstration purposes only and _not_ to +# be used in production environments. The code is vulnerable to timing attacks, +# for example. + +from typing import Any, List, Optional, Tuple, NewType, NamedTuple +import hashlib +import secrets +import time + +# +# The following helper functions were copied from the BIP-340 reference implementation: +# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py +# + +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 + +# Points are tuples of X and Y coordinates and the point at infinity is +# represented by the None keyword. +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) + +Point = Tuple[int, int] + +# This implementation can be sped up by storing the midstate after hashing +# tag_hash instead of rehashing it all the time. +def tagged_hash(tag: str, msg: bytes) -> bytes: + tag_hash = hashlib.sha256(tag.encode()).digest() + return hashlib.sha256(tag_hash + tag_hash + msg).digest() + +def is_infinite(P: Optional[Point]) -> bool: + return P is None + +def x(P: Point) -> int: + assert not is_infinite(P) + return P[0] + +def y(P: Point) -> int: + assert not is_infinite(P) + return P[1] + +def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: + if P1 is None: + return P2 + if P2 is None: + return P1 + if (x(P1) == x(P2)) and (y(P1) != y(P2)): + return None + if P1 == P2: + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p + else: + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p + x3 = (lam * lam - x(P1) - x(P2)) % p + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) + +def point_mul(P: Optional[Point], n: int) -> Optional[Point]: + R = None + for i in range(256): + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R + +def bytes_from_int(x: int) -> bytes: + return x.to_bytes(32, byteorder="big") + +def lift_x(b: bytes) -> Optional[Point]: + x = int_from_bytes(b) + if x >= p: + return None + y_sq = (pow(x, 3, p) + 7) % p + y = pow(y_sq, (p + 1) // 4, p) + if pow(y, 2, p) != y_sq: + return None + return (x, y if y & 1 == 0 else p-y) + +def int_from_bytes(b: bytes) -> int: + return int.from_bytes(b, byteorder="big") + +def has_even_y(P: Point) -> bool: + assert not is_infinite(P) + return y(P) % 2 == 0 + +def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: + if len(msg) != 32: + raise ValueError('The message must be a 32-byte array.') + if len(pubkey) != 32: + raise ValueError('The public key must be a 32-byte array.') + if len(sig) != 64: + raise ValueError('The signature must be a 64-byte array.') + P = lift_x(pubkey) + r = int_from_bytes(sig[0:32]) + s = int_from_bytes(sig[32:64]) + if (P is None) or (r >= p) or (s >= n): + return False + e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n + R = point_add(point_mul(G, s), point_mul(P, n - e)) + if (R is None) or (not has_even_y(R)) or (x(R) != r): + return False + return True + +# +# End of helper functions copied from BIP-340 reference implementation. +# + +PlainPk = NewType('PlainPk', bytes) +XonlyPk = NewType('XonlyPk', bytes) + +# There are two types of exceptions that can be raised by this implementation: +# - ValueError for indicating that an input doesn't conform to some function +# precondition (e.g. an input array is the wrong length, a serialized +# representation doesn't have the correct format). +# - InvalidContributionError for indicating that a signer (or the +# aggregator) is misbehaving in the protocol. +# +# Assertions are used to (1) satisfy the type-checking system, and (2) check for +# inconvenient events that can't happen except with negligible probability (e.g. +# output of a hash function is 0) and can't be manually triggered by any +# signer. + +# This exception is raised if a party (signer or nonce aggregator) sends invalid +# values. Actual implementations should not crash when receiving invalid +# contributions. Instead, they should hold the offending party accountable. +class InvalidContributionError(Exception): + def __init__(self, signer, contrib): + self.signer = signer + # contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig". + self.contrib = contrib + +infinity = None + +def xbytes(P: Point) -> bytes: + return bytes_from_int(x(P)) + +def cbytes(P: Point) -> bytes: + a = b'\x02' if has_even_y(P) else b'\x03' + return a + xbytes(P) + +def cbytes_ext(P: Optional[Point]) -> bytes: + if is_infinite(P): + return (0).to_bytes(33, byteorder='big') + assert P is not None + return cbytes(P) + +def point_negate(P: Optional[Point]) -> Optional[Point]: + if P is None: + return P + return (x(P), p - y(P)) + +def cpoint(x: bytes) -> Point: + if len(x) != 33: + raise ValueError('x is not a valid compressed point.') + P = lift_x(x[1:33]) + if P is None: + raise ValueError('x is not a valid compressed point.') + if x[0] == 2: + return P + elif x[0] == 3: + P = point_negate(P) + assert P is not None + return P + else: + raise ValueError('x is not a valid compressed point.') + +def cpoint_ext(x: bytes) -> Optional[Point]: + if x == (0).to_bytes(33, 'big'): + return None + else: + return cpoint(x) + +# Return the plain public key corresponding to a given secret key +def individual_pk(seckey: bytes) -> PlainPk: + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + raise ValueError('The secret key must be an integer in the range 1..n-1.') + P = point_mul(G, d0) + assert P is not None + return PlainPk(cbytes(P)) + +def key_sort(pubkeys: List[PlainPk]) -> List[PlainPk]: + pubkeys.sort() + return pubkeys + +KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point), + ('gacc', int), + ('tacc', int)]) + +def get_xonly_pk(keyagg_ctx: KeyAggContext) -> XonlyPk: + Q, _, _ = keyagg_ctx + return XonlyPk(xbytes(Q)) + +def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext: + pk2 = get_second_key(pubkeys) + u = len(pubkeys) + Q = infinity + for i in range(u): + try: + P_i = cpoint(pubkeys[i]) + except ValueError: + raise InvalidContributionError(i, "pubkey") + a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) + Q = point_add(Q, point_mul(P_i, a_i)) + # Q is not the point at infinity except with negligible probability. + assert(Q is not None) + gacc = 1 + tacc = 0 + return KeyAggContext(Q, gacc, tacc) + +def hash_keys(pubkeys: List[PlainPk]) -> bytes: + return tagged_hash('KeyAgg list', b''.join(pubkeys)) + +def get_second_key(pubkeys: List[PlainPk]) -> PlainPk: + u = len(pubkeys) + for j in range(1, u): + if pubkeys[j] != pubkeys[0]: + return pubkeys[j] + return PlainPk(b'\x00'*33) + +def key_agg_coeff(pubkeys: List[PlainPk], pk_: PlainPk) -> int: + pk2 = get_second_key(pubkeys) + return key_agg_coeff_internal(pubkeys, pk_, pk2) + +def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int: + L = hash_keys(pubkeys) + if pk_ == pk2: + return 1 + return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n + +def apply_tweak(keyagg_ctx: KeyAggContext, tweak: bytes, is_xonly: bool) -> KeyAggContext: + if len(tweak) != 32: + raise ValueError('The tweak must be a 32-byte array.') + Q, gacc, tacc = keyagg_ctx + if is_xonly and not has_even_y(Q): + g = n - 1 + else: + g = 1 + t = int_from_bytes(tweak) + if t >= n: + raise ValueError('The tweak must be less than n.') + Q_ = point_add(point_mul(Q, g), point_mul(G, t)) + if Q_ is None: + raise ValueError('The result of tweaking cannot be infinity.') + gacc_ = g * gacc % n + tacc_ = (t + g * tacc) % n + return KeyAggContext(Q_, gacc_, tacc_) + +def bytes_xor(a: bytes, b: bytes) -> bytes: + return bytes(x ^ y for x, y in zip(a, b)) + +def nonce_hash(rand: bytes, pk: PlainPk, aggpk: XonlyPk, i: int, msg_prefixed: bytes, extra_in: bytes) -> int: + buf = b'' + buf += rand + buf += len(pk).to_bytes(1, 'big') + buf += pk + buf += len(aggpk).to_bytes(1, 'big') + buf += aggpk + buf += msg_prefixed + buf += len(extra_in).to_bytes(4, 'big') + buf += extra_in + buf += i.to_bytes(1, 'big') + return int_from_bytes(tagged_hash('MuSig/nonce', buf)) + +def nonce_gen_internal(rand_: bytes, sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]: + if sk is not None: + rand = bytes_xor(sk, tagged_hash('MuSig/aux', rand_)) + else: + rand = rand_ + if aggpk is None: + aggpk = XonlyPk(b'') + if msg is None: + msg_prefixed = b'\x00' + else: + msg_prefixed = b'\x01' + msg_prefixed += len(msg).to_bytes(8, 'big') + msg_prefixed += msg + if extra_in is None: + extra_in = b'' + k_1 = nonce_hash(rand, pk, aggpk, 0, msg_prefixed, extra_in) % n + k_2 = nonce_hash(rand, pk, aggpk, 1, msg_prefixed, extra_in) % n + # k_1 == 0 or k_2 == 0 cannot occur except with negligible probability. + assert k_1 != 0 + assert k_2 != 0 + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + pk) + return secnonce, pubnonce + +def nonce_gen(sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]: + if sk is not None and len(sk) != 32: + raise ValueError('The optional byte array sk must have length 32.') + if aggpk is not None and len(aggpk) != 32: + raise ValueError('The optional byte array aggpk must have length 32.') + rand_ = secrets.token_bytes(32) + return nonce_gen_internal(rand_, sk, pk, aggpk, msg, extra_in) + +def nonce_agg(pubnonces: List[bytes]) -> bytes: + u = len(pubnonces) + aggnonce = b'' + for j in (1, 2): + R_j = infinity + for i in range(u): + try: + R_ij = cpoint(pubnonces[i][(j-1)*33:j*33]) + except ValueError: + raise InvalidContributionError(i, "pubnonce") + R_j = point_add(R_j, R_ij) + aggnonce += cbytes_ext(R_j) + return aggnonce + +SessionContext = NamedTuple('SessionContext', [('aggnonce', bytes), + ('pubkeys', List[PlainPk]), + ('tweaks', List[bytes]), + ('is_xonly', List[bool]), + ('msg', bytes)]) + +def key_agg_and_tweak(pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool]): + if len(tweaks) != len(is_xonly): + raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') + keyagg_ctx = key_agg(pubkeys) + v = len(tweaks) + for i in range(v): + keyagg_ctx = apply_tweak(keyagg_ctx, tweaks[i], is_xonly[i]) + return keyagg_ctx + +def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, int, Point, int]: + (aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx + Q, gacc, tacc = key_agg_and_tweak(pubkeys, tweaks, is_xonly) + b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + xbytes(Q) + msg)) % n + try: + R_1 = cpoint_ext(aggnonce[0:33]) + R_2 = cpoint_ext(aggnonce[33:66]) + except ValueError: + # Nonce aggregator sent invalid nonces + raise InvalidContributionError(None, "aggnonce") + R_ = point_add(R_1, point_mul(R_2, b)) + R = R_ if not is_infinite(R_) else G + assert R is not None + e = int_from_bytes(tagged_hash('BIP0340/challenge', xbytes(R) + xbytes(Q) + msg)) % n + return (Q, gacc, tacc, b, R, e) + +def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int: + (_, pubkeys, _, _, _) = session_ctx + pk = PlainPk(cbytes(P)) + if pk not in pubkeys: + raise ValueError('The signer\'s pubkey must be included in the list of pubkeys.') + return key_agg_coeff(pubkeys, pk) + +def sign(secnonce: bytearray, sk: bytes, session_ctx: SessionContext) -> bytes: + (Q, gacc, _, b, R, e) = get_session_values(session_ctx) + k_1_ = int_from_bytes(secnonce[0:32]) + k_2_ = int_from_bytes(secnonce[32:64]) + # Overwrite the secnonce argument with zeros such that subsequent calls of + # sign with the same secnonce raise a ValueError. + secnonce[:64] = bytearray(b'\x00'*64) + if not 0 < k_1_ < n: + raise ValueError('first secnonce value is out of range.') + if not 0 < k_2_ < n: + raise ValueError('second secnonce value is out of range.') + k_1 = k_1_ if has_even_y(R) else n - k_1_ + k_2 = k_2_ if has_even_y(R) else n - k_2_ + d_ = int_from_bytes(sk) + if not 0 < d_ < n: + raise ValueError('secret key value is out of range.') + P = point_mul(G, d_) + assert P is not None + pk = cbytes(P) + if not pk == secnonce[64:97]: + raise ValueError('Public key does not match nonce_gen argument') + a = get_session_key_agg_coeff(session_ctx, P) + g = 1 if has_even_y(Q) else n - 1 + d = g * gacc * d_ % n + s = (k_1 + b * k_2 + e * a * d) % n + psig = bytes_from_int(s) + R_s1 = point_mul(G, k_1_) + R_s2 = point_mul(G, k_2_) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + # Optional correctness check. The result of signing should pass signature verification. + assert partial_sig_verify_internal(psig, pubnonce, pk, session_ctx) + return psig + +def det_nonce_hash(sk_: bytes, aggothernonce: bytes, aggpk: bytes, msg: bytes, i: int) -> int: + buf = b'' + buf += sk_ + buf += aggothernonce + buf += aggpk + buf += len(msg).to_bytes(8, 'big') + buf += msg + buf += i.to_bytes(1, 'big') + return int_from_bytes(tagged_hash('MuSig/deterministic/nonce', buf)) + +def deterministic_sign(sk: bytes, aggothernonce: bytes, pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, rand: Optional[bytes]) -> Tuple[bytes, bytes]: + if rand is not None: + sk_ = bytes_xor(sk, tagged_hash('MuSig/aux', rand)) + else: + sk_ = sk + aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) + + k_1 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 0) % n + k_2 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 1) % n + # k_1 == 0 or k_2 == 0 cannot occur except with negligible probability. + assert k_1 != 0 + assert k_2 != 0 + + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None + assert R_s2 is not None + pubnonce = cbytes(R_s1) + cbytes(R_s2) + secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + individual_pk(sk)) + try: + aggnonce = nonce_agg([pubnonce, aggothernonce]) + except Exception: + raise InvalidContributionError(None, "aggothernonce") + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + psig = sign(secnonce, sk, session_ctx) + return (pubnonce, psig) + +def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool: + if len(pubnonces) != len(pubkeys): + raise ValueError('The `pubnonces` and `pubkeys` arrays must have the same length.') + if len(tweaks) != len(is_xonly): + raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.') + aggnonce = nonce_agg(pubnonces) + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx) + +def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool: + (Q, gacc, _, b, R, e) = get_session_values(session_ctx) + s = int_from_bytes(psig) + if s >= n: + return False + R_s1 = cpoint(pubnonce[0:33]) + R_s2 = cpoint(pubnonce[33:66]) + Re_s_ = point_add(R_s1, point_mul(R_s2, b)) + Re_s = Re_s_ if has_even_y(R) else point_negate(Re_s_) + P = cpoint(pk) + if P is None: + return False + a = get_session_key_agg_coeff(session_ctx, P) + g = 1 if has_even_y(Q) else n - 1 + g_ = g * gacc % n + return point_mul(G, s) == point_add(Re_s, point_mul(P, e * a * g_ % n)) + +def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> bytes: + (Q, _, tacc, _, R, e) = get_session_values(session_ctx) + s = 0 + u = len(psigs) + for i in range(u): + s_i = int_from_bytes(psigs[i]) + if s_i >= n: + raise InvalidContributionError(i, "psig") + s = (s + s_i) % n + g = 1 if has_even_y(Q) else n - 1 + s = (s + e * g * tacc) % n + return xbytes(R) + bytes_from_int(s) +# +# The following code is only used for testing. +# + +import json +import os +import sys + +def fromhex_all(l): + return [bytes.fromhex(l_i) for l_i in l] + +# Check that calling `try_fn` raises a `exception`. If `exception` is raised, +# examine it with `except_fn`. +def assert_raises(exception, try_fn, except_fn): + raised = False + try: + try_fn() + except exception as e: + raised = True + assert(except_fn(e)) + except BaseException: + raise AssertionError("Wrong exception raised in a test.") + if not raised: + raise AssertionError("Exception was _not_ raised in a test where it was required.") + +def get_error_details(test_case): + error = test_case["error"] + if error["type"] == "invalid_contribution": + exception = InvalidContributionError + if "contrib" in error: + except_fn = lambda e: e.signer == error["signer"] and e.contrib == error["contrib"] + else: + except_fn = lambda e: e.signer == error["signer"] + elif error["type"] == "value": + exception = ValueError + except_fn = lambda e: str(e) == error["message"] + else: + raise RuntimeError(f"Invalid error type: {error['type']}") + return exception, except_fn + +def test_key_sort_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'key_sort_vectors.json')) as f: + test_data = json.load(f) + + X = fromhex_all(test_data["pubkeys"]) + X_sorted = fromhex_all(test_data["sorted_pubkeys"]) + + assert key_sort(X) == X_sorted + +def test_key_agg_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'key_agg_vectors.json')) as f: + test_data = json.load(f) + + X = fromhex_all(test_data["pubkeys"]) + T = fromhex_all(test_data["tweaks"]) + valid_test_cases = test_data["valid_test_cases"] + error_test_cases = test_data["error_test_cases"] + + for test_case in valid_test_cases: + pubkeys = [X[i] for i in test_case["key_indices"]] + expected = bytes.fromhex(test_case["expected"]) + + assert get_xonly_pk(key_agg(pubkeys)) == expected + + for i, test_case in enumerate(error_test_cases): + exception, except_fn = get_error_details(test_case) + + pubkeys = [X[i] for i in test_case["key_indices"]] + tweaks = [T[i] for i in test_case["tweak_indices"]] + is_xonly = test_case["is_xonly"] + + assert_raises(exception, lambda: key_agg_and_tweak(pubkeys, tweaks, is_xonly), except_fn) + +def test_nonce_gen_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'nonce_gen_vectors.json')) as f: + test_data = json.load(f) + + for test_case in test_data["test_cases"]: + def get_value(key) -> bytes: + return bytes.fromhex(test_case[key]) + + def get_value_maybe(key) -> Optional[bytes]: + if test_case[key] is not None: + return get_value(key) + else: + return None + + rand_ = get_value("rand_") + sk = get_value_maybe("sk") + pk = PlainPk(get_value("pk")) + aggpk = get_value_maybe("aggpk") + if aggpk is not None: + aggpk = XonlyPk(aggpk) + msg = get_value_maybe("msg") + extra_in = get_value_maybe("extra_in") + expected_secnonce = get_value("expected_secnonce") + expected_pubnonce = get_value("expected_pubnonce") + + assert nonce_gen_internal(rand_, sk, pk, aggpk, msg, extra_in) == (expected_secnonce, expected_pubnonce) + +def test_nonce_agg_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'nonce_agg_vectors.json')) as f: + test_data = json.load(f) + + pnonce = fromhex_all(test_data["pnonces"]) + valid_test_cases = test_data["valid_test_cases"] + error_test_cases = test_data["error_test_cases"] + + for test_case in valid_test_cases: + pubnonces = [pnonce[i] for i in test_case["pnonce_indices"]] + expected = bytes.fromhex(test_case["expected"]) + assert nonce_agg(pubnonces) == expected + + for i, test_case in enumerate(error_test_cases): + exception, except_fn = get_error_details(test_case) + pubnonces = [pnonce[i] for i in test_case["pnonce_indices"]] + assert_raises(exception, lambda: nonce_agg(pubnonces), except_fn) + +def test_sign_verify_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'sign_verify_vectors.json')) as f: + test_data = json.load(f) + + sk = bytes.fromhex(test_data["sk"]) + X = fromhex_all(test_data["pubkeys"]) + # The public key corresponding to sk is at index 0 + assert X[0] == individual_pk(sk) + + secnonces = fromhex_all(test_data["secnonces"]) + pnonce = fromhex_all(test_data["pnonces"]) + # The public nonce corresponding to secnonces[0] is at index 0 + k_1 = int_from_bytes(secnonces[0][0:32]) + k_2 = int_from_bytes(secnonces[0][32:64]) + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None and R_s2 is not None + assert pnonce[0] == cbytes(R_s1) + cbytes(R_s2) + + aggnonces = fromhex_all(test_data["aggnonces"]) + # The aggregate of the first three elements of pnonce is at index 0 + assert(aggnonces[0] == nonce_agg([pnonce[0], pnonce[1], pnonce[2]])) + + msgs = fromhex_all(test_data["msgs"]) + + valid_test_cases = test_data["valid_test_cases"] + sign_error_test_cases = test_data["sign_error_test_cases"] + verify_fail_test_cases = test_data["verify_fail_test_cases"] + verify_error_test_cases = test_data["verify_error_test_cases"] + + for test_case in valid_test_cases: + pubkeys = [X[i] for i in test_case["key_indices"]] + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + aggnonce = aggnonces[test_case["aggnonce_index"]] + # Make sure that pubnonces and aggnonce in the test vector are + # consistent + assert nonce_agg(pubnonces) == aggnonce + msg = msgs[test_case["msg_index"]] + signer_index = test_case["signer_index"] + expected = bytes.fromhex(test_case["expected"]) + + session_ctx = SessionContext(aggnonce, pubkeys, [], [], msg) + # WARNING: An actual implementation should _not_ copy the secnonce. + # Reusing the secnonce, as we do here for testing purposes, can leak the + # secret key. + secnonce_tmp = bytearray(secnonces[0]) + assert sign(secnonce_tmp, sk, session_ctx) == expected + assert partial_sig_verify(expected, pubnonces, pubkeys, [], [], msg, signer_index) + + for i, test_case in enumerate(sign_error_test_cases): + exception, except_fn = get_error_details(test_case) + + pubkeys = [X[i] for i in test_case["key_indices"]] + aggnonce = aggnonces[test_case["aggnonce_index"]] + msg = msgs[test_case["msg_index"]] + secnonce = bytearray(secnonces[test_case["secnonce_index"]]) + + session_ctx = SessionContext(aggnonce, pubkeys, [], [], msg) + assert_raises(exception, lambda: sign(secnonce, sk, session_ctx), except_fn) + + for test_case in verify_fail_test_cases: + sig = bytes.fromhex(test_case["sig"]) + pubkeys = [X[i] for i in test_case["key_indices"]] + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + msg = msgs[test_case["msg_index"]] + signer_index = test_case["signer_index"] + + assert not partial_sig_verify(sig, pubnonces, pubkeys, [], [], msg, signer_index) + + for i, test_case in enumerate(verify_error_test_cases): + exception, except_fn = get_error_details(test_case) + + sig = bytes.fromhex(test_case["sig"]) + pubkeys = [X[i] for i in test_case["key_indices"]] + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + msg = msgs[test_case["msg_index"]] + signer_index = test_case["signer_index"] + + assert_raises(exception, lambda: partial_sig_verify(sig, pubnonces, pubkeys, [], [], msg, signer_index), except_fn) + +def test_tweak_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'tweak_vectors.json')) as f: + test_data = json.load(f) + + sk = bytes.fromhex(test_data["sk"]) + X = fromhex_all(test_data["pubkeys"]) + # The public key corresponding to sk is at index 0 + assert X[0] == individual_pk(sk) + + secnonce = bytearray(bytes.fromhex(test_data["secnonce"])) + pnonce = fromhex_all(test_data["pnonces"]) + # The public nonce corresponding to secnonce is at index 0 + k_1 = int_from_bytes(secnonce[0:32]) + k_2 = int_from_bytes(secnonce[32:64]) + R_s1 = point_mul(G, k_1) + R_s2 = point_mul(G, k_2) + assert R_s1 is not None and R_s2 is not None + assert pnonce[0] == cbytes(R_s1) + cbytes(R_s2) + + aggnonce = bytes.fromhex(test_data["aggnonce"]) + # The aggnonce is the aggregate of the first three elements of pnonce + assert(aggnonce == nonce_agg([pnonce[0], pnonce[1], pnonce[2]])) + + tweak = fromhex_all(test_data["tweaks"]) + msg = bytes.fromhex(test_data["msg"]) + + valid_test_cases = test_data["valid_test_cases"] + error_test_cases = test_data["error_test_cases"] + + for test_case in valid_test_cases: + pubkeys = [X[i] for i in test_case["key_indices"]] + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + tweaks = [tweak[i] for i in test_case["tweak_indices"]] + is_xonly = test_case["is_xonly"] + signer_index = test_case["signer_index"] + expected = bytes.fromhex(test_case["expected"]) + + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + secnonce_tmp = bytearray(secnonce) + # WARNING: An actual implementation should _not_ copy the secnonce. + # Reusing the secnonce, as we do here for testing purposes, can leak the + # secret key. + assert sign(secnonce_tmp, sk, session_ctx) == expected + assert partial_sig_verify(expected, pubnonces, pubkeys, tweaks, is_xonly, msg, signer_index) + + for i, test_case in enumerate(error_test_cases): + exception, except_fn = get_error_details(test_case) + + pubkeys = [X[i] for i in test_case["key_indices"]] + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + tweaks = [tweak[i] for i in test_case["tweak_indices"]] + is_xonly = test_case["is_xonly"] + signer_index = test_case["signer_index"] + + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + assert_raises(exception, lambda: sign(secnonce, sk, session_ctx), except_fn) + +def test_det_sign_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'det_sign_vectors.json')) as f: + test_data = json.load(f) + + sk = bytes.fromhex(test_data["sk"]) + X = fromhex_all(test_data["pubkeys"]) + # The public key corresponding to sk is at index 0 + assert X[0] == individual_pk(sk) + + msgs = fromhex_all(test_data["msgs"]) + + valid_test_cases = test_data["valid_test_cases"] + error_test_cases = test_data["error_test_cases"] + + for test_case in valid_test_cases: + pubkeys = [X[i] for i in test_case["key_indices"]] + aggothernonce = bytes.fromhex(test_case["aggothernonce"]) + tweaks = fromhex_all(test_case["tweaks"]) + is_xonly = test_case["is_xonly"] + msg = msgs[test_case["msg_index"]] + signer_index = test_case["signer_index"] + rand = bytes.fromhex(test_case["rand"]) if test_case["rand"] is not None else None + expected = fromhex_all(test_case["expected"]) + + pubnonce, psig = deterministic_sign(sk, aggothernonce, pubkeys, tweaks, is_xonly, msg, rand) + assert pubnonce == expected[0] + assert psig == expected[1] + + pubnonces = [aggothernonce, pubnonce] + aggnonce = nonce_agg(pubnonces) + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + assert partial_sig_verify_internal(psig, pubnonce, pubkeys[signer_index], session_ctx) + + for i, test_case in enumerate(error_test_cases): + exception, except_fn = get_error_details(test_case) + + pubkeys = [X[i] for i in test_case["key_indices"]] + aggothernonce = bytes.fromhex(test_case["aggothernonce"]) + tweaks = fromhex_all(test_case["tweaks"]) + is_xonly = test_case["is_xonly"] + msg = msgs[test_case["msg_index"]] + signer_index = test_case["signer_index"] + rand = bytes.fromhex(test_case["rand"]) if test_case["rand"] is not None else None + + try_fn = lambda: deterministic_sign(sk, aggothernonce, pubkeys, tweaks, is_xonly, msg, rand) + assert_raises(exception, try_fn, except_fn) + +def test_sig_agg_vectors() -> None: + with open(os.path.join(sys.path[0], 'vectors', 'sig_agg_vectors.json')) as f: + test_data = json.load(f) + + X = fromhex_all(test_data["pubkeys"]) + + # These nonces are only required if the tested API takes the individual + # nonces and not the aggregate nonce. + pnonce = fromhex_all(test_data["pnonces"]) + + tweak = fromhex_all(test_data["tweaks"]) + psig = fromhex_all(test_data["psigs"]) + + msg = bytes.fromhex(test_data["msg"]) + + valid_test_cases = test_data["valid_test_cases"] + error_test_cases = test_data["error_test_cases"] + + for test_case in valid_test_cases: + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + aggnonce = bytes.fromhex(test_case["aggnonce"]) + assert aggnonce == nonce_agg(pubnonces) + + pubkeys = [X[i] for i in test_case["key_indices"]] + tweaks = [tweak[i] for i in test_case["tweak_indices"]] + is_xonly = test_case["is_xonly"] + psigs = [psig[i] for i in test_case["psig_indices"]] + expected = bytes.fromhex(test_case["expected"]) + + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + sig = partial_sig_agg(psigs, session_ctx) + assert sig == expected + aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) + assert schnorr_verify(msg, aggpk, sig) + + for i, test_case in enumerate(error_test_cases): + exception, except_fn = get_error_details(test_case) + + pubnonces = [pnonce[i] for i in test_case["nonce_indices"]] + aggnonce = nonce_agg(pubnonces) + + pubkeys = [X[i] for i in test_case["key_indices"]] + tweaks = [tweak[i] for i in test_case["tweak_indices"]] + is_xonly = test_case["is_xonly"] + psigs = [psig[i] for i in test_case["psig_indices"]] + + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + assert_raises(exception, lambda: partial_sig_agg(psigs, session_ctx), except_fn) + +def test_sign_and_verify_random(iters: int) -> None: + for i in range(iters): + sk_1 = secrets.token_bytes(32) + sk_2 = secrets.token_bytes(32) + pk_1 = individual_pk(sk_1) + pk_2 = individual_pk(sk_2) + pubkeys = [pk_1, pk_2] + + # In this example, the message and aggregate pubkey are known + # before nonce generation, so they can be passed into the nonce + # generation function as a defense-in-depth measure to protect + # against nonce reuse. + # + # If these values are not known when nonce_gen is called, empty + # byte arrays can be passed in for the corresponding arguments + # instead. + msg = secrets.token_bytes(32) + v = secrets.randbelow(4) + tweaks = [secrets.token_bytes(32) for _ in range(v)] + is_xonly = [secrets.choice([False, True]) for _ in range(v)] + aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) + + # Use a non-repeating counter for extra_in + secnonce_1, pubnonce_1 = nonce_gen(sk_1, pk_1, aggpk, msg, i.to_bytes(4, 'big')) + + # On even iterations use regular signing algorithm for signer 2, + # otherwise use deterministic signing algorithm + if i % 2 == 0: + # Use a clock for extra_in + t = time.clock_gettime_ns(time.CLOCK_MONOTONIC) + secnonce_2, pubnonce_2 = nonce_gen(sk_2, pk_2, aggpk, msg, t.to_bytes(8, 'big')) + else: + aggothernonce = nonce_agg([pubnonce_1]) + rand = secrets.token_bytes(32) + pubnonce_2, psig_2 = deterministic_sign(sk_2, aggothernonce, pubkeys, tweaks, is_xonly, msg, rand) + + pubnonces = [pubnonce_1, pubnonce_2] + aggnonce = nonce_agg(pubnonces) + + session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + psig_1 = sign(secnonce_1, sk_1, session_ctx) + assert partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 0) + # An exception is thrown if secnonce_1 is accidentally reused + assert_raises(ValueError, lambda: sign(secnonce_1, sk_1, session_ctx), lambda e: True) + + # Wrong signer index + assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 1) + + # Wrong message + assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, secrets.token_bytes(32), 0) + + if i % 2 == 0: + psig_2 = sign(secnonce_2, sk_2, session_ctx) + assert partial_sig_verify(psig_2, pubnonces, pubkeys, tweaks, is_xonly, msg, 1) + + sig = partial_sig_agg([psig_1, psig_2], session_ctx) + assert schnorr_verify(msg, aggpk, sig) + +if __name__ == '__main__': + test_key_sort_vectors() + test_key_agg_vectors() + test_nonce_gen_vectors() + test_nonce_agg_vectors() + test_sign_verify_vectors() + test_tweak_vectors() + test_det_sign_vectors() + test_sig_agg_vectors() + test_sign_and_verify_random(6) |