diff options
Diffstat (limited to 'bip-0324/gen_test_vectors.py')
-rw-r--r-- | bip-0324/gen_test_vectors.py | 418 |
1 files changed, 418 insertions, 0 deletions
diff --git a/bip-0324/gen_test_vectors.py b/bip-0324/gen_test_vectors.py new file mode 100644 index 0000000..05b30a8 --- /dev/null +++ b/bip-0324/gen_test_vectors.py @@ -0,0 +1,418 @@ +"""Generate the BIP-0324 test vectors.""" + +import csv +import hashlib +import os +import sys +from reference import ( + FE, + GE, + MINUS_3_SQRT, + hkdf_sha256, + SECP256K1_G, + ellswift_decode, + ellswift_ecdh_xonly, + xswiftec_inv, + xswiftec, + v2_ecdh, + initialize_v2_transport, + v2_enc_packet +) + +FILENAME_PACKET_TEST = os.path.join(sys.path[0], 'packet_encoding_test_vectors.csv') +FILENAME_XSWIFTEC_INV_TEST = os.path.join(sys.path[0], 'xswiftec_inv_test_vectors.csv') +FILENAME_ELLSWIFT_DECODE_TEST = os.path.join(sys.path[0], 'ellswift_decode_test_vectors.csv') + +def xswiftec_flagged(u, t, simplified=False): + """A variant of xswiftec which also returns 'flags', describing conditions encountered.""" + flags = [] + if u == 0: + flags.append("u%p=0") + u = FE(1) + if t == 0: + flags.append("t%p=0") + t = FE(1) + if u**3 + t**2 + 7 == 0: + flags.append("(u'^3+t'^2+7)%p=0") + t = 2 * t + X = (u**3 + 7 - t**2) / (2 * t) + Y = (X + t) / (MINUS_3_SQRT * u) + if X == 0: + if not simplified: + flags.append("(u'^3-t'^2+7)%p=0") + x3 = u + 4 * Y**2 + if GE.is_valid_x(x3): + flags.append("valid_x(x3)") + x2 = (-X / Y - u) / 2 + if GE.is_valid_x(x2): + flags.append("valid_x(x2)") + x1 = (X / Y - u) / 2 + if GE.is_valid_x(x1): + flags.append("valid_x(x1)") + for x in (x3, x2, x1): + if GE.is_valid_x(x): + break + return x, flags + + +def ellswift_create_deterministic(seed, features): + """This is a variant of ellswift_create which doesn't use randomness. + + features is an integer selecting some properties of the result: + - (f & 3) == 0: only x1 is valid on decoding (see xswiftec{_flagged}) + - (f & 3) == 1: only x2 is valid on decoding + - (f & 3) == 2: only x3 is valid on decoding + - (f & 3) == 3: x1,x2,x3 are all valid on decoding + - (f & 4) == 4: u >= p + - (f & 8) == 8: u mod n == 0 + + Returns privkey, ellswift + """ + + cnt = 0 + while True: + sec = hkdf_sha256(32, seed, (cnt).to_bytes(4, 'little'), b"sec") + xval = (int.from_bytes(sec, 'big') * SECP256K1_G).x + cnt += 1 + if features & 8: + u = 0 + if features & 4: + u += FE.SIZE + else: + udat = hkdf_sha256(64, seed, (cnt).to_bytes(4, 'little'), b"u") + if features & 4: + u = FE.SIZE + 1 + int.from_bytes(udat, 'big') % (2**256 - FE.SIZE - 1) + else: + u = 1 + int.from_bytes(udat, 'big') % (FE.SIZE - 1) + case = hkdf_sha256(1, seed, (cnt).to_bytes(4, 'little'), b"case")[0] & 7 + coru = FE(u) + ((features & 8) == 8) + t = xswiftec_inv(xval, coru, case) + if t is None: + continue + assert xswiftec(FE(u), t) == xval + x2, flags = xswiftec_flagged(FE(u), t) + assert x2 == xval + have_x1 = "valid_x(x1)" in flags + have_x2 = "valid_x(x2)" in flags + have_x3 = "valid_x(x3)" in flags + if (features & 4) == 0 and not (have_x1 and not have_x2 and not have_x3): + continue + if (features & 4) == 1 and not (not have_x1 and have_x2 and not have_x3): + continue + if (features & 4) == 2 and not (not have_x1 and not have_x2 and have_x3): + continue + if (features & 4) == 3 and not (have_x1 and have_x2 and have_x3): + continue + return sec, u.to_bytes(32, 'big') + t.to_bytes() + +def ellswift_decode_flagged(ellswift, simplified=False): + """Decode a 64-byte ElligatorSwift encoded coordinate, returning byte array + flag string.""" + uv = int.from_bytes(ellswift[:32], 'big') + tv = int.from_bytes(ellswift[32:], 'big') + x, flags = xswiftec_flagged(FE(uv), FE(tv)) + if not simplified: + if uv >= FE.SIZE: + flags.append("u>=p") + if tv >= FE.SIZE: + flags.append("t>=p") + return int(x).to_bytes(32, 'big'), ";".join(flags) + +def random_fe_int(_, seed, i, p): + """Function to use in tuple_expand, generating a random integer in 0..p-1.""" + rng_out = hkdf_sha256(64, seed, i.to_bytes(4, 'little'), b"v%i_fe" % p) + return int.from_bytes(rng_out, 'big') % FE.SIZE + +def random_fe_int_high(_, seed, i, p): + """Function to use in tuple_expand, generating a random integer in p..2^256-1.""" + rng_out = hkdf_sha256(64, seed, i.to_bytes(4, 'little'), b"v%i_fe_high" % p) + return FE.SIZE + int.from_bytes(rng_out, 'big') % (2**256 - FE.SIZE) + +def fn_of(p_in, fn): + """Function to use in tuple_expand, to pick one variable in function of another.""" + def inner(vs, _seed, _i, p): + assert p != p_in + if isinstance(vs[p_in], int): + return fn(vs[p_in]) + return None + return inner + +def tuple_expand(out, tuplespec, prio, seed=None, cnt=1): + """Given a tuple specification, expand it cnt times, and add results to out. + + Expansion is defined recursively: + - If any of the spec elements is a list, each element of the list results + in an expansion (by replacing the list with its element). + - If any of the spec elements is a function, that function is invoked with + (spec, seed, expansion count, index in spec) as arguments. If the function + needs to wait for other indices to be expanded, it can return None. + + The output consists of (prio, expansion count, SHA256(result), result, seed) + tuples.""" + + def recurse(vs, seed, i, change_pos=None, change=None): + if change_pos is not None: + vs = list(vs) + vs[change_pos] = change + for p, v in enumerate(vs): + if v is None: + return + if isinstance(v, list): + for ve in v: + recurse(vs, seed, i, p, ve) + return + if callable(v): + res = v(vs, seed, i, p) + if res is not None: + recurse(vs, seed, i, p, res) + return + h = hashlib.sha256() + for v in vs: + h.update(int(v).to_bytes(32, 'big')) + out.append((prio, i, h.digest(), vs, seed)) + for i in range(cnt): + recurse(tuplespec, seed, i) + +def gen_ellswift_decode_cases(seed, simplified=False): + """Generate a set of interesting (ellswift, x, flags) ellswift decoding cases.""" + inputs = [] + + # Aggregate for use in tuple_expand, expanding to int in 0..p-1, and one in p..2^256-1. + RANDOM_VAL = [random_fe_int, random_fe_int_high] + # Aggregate for use in tuple_expand, expanding to integers which %p equal 0. + ZERO_VAL = [0, FE.SIZE] + # Helpers for constructing u and t values such that u^3+t^2+7=0 or u^3-t^2+7=0. + T_FOR_SUM_ZERO = fn_of(0, lambda u: (-FE(u)**3 - 7).sqrts()) + T_FOR_DIFF_ZERO = fn_of(0, lambda u: (FE(u)**3 + 7).sqrts()) + U_FOR_SUM_ZERO = fn_of(1, lambda t: (-FE(t)**2 - 7).cbrts()) + U_FOR_DIFF_ZERO = fn_of(1, lambda t: (FE(t)**2 - 7).cbrts()) + + tuple_expand(inputs, [RANDOM_VAL, RANDOM_VAL], 0, seed + b"random", 64) + tuple_expand(inputs, [RANDOM_VAL, T_FOR_SUM_ZERO], 1, seed + b"t=sqrt(-u^3-7)", 64) + tuple_expand(inputs, [U_FOR_SUM_ZERO, RANDOM_VAL], 1, seed + b"u=cbrt(-t^2-7)", 64) + tuple_expand(inputs, [RANDOM_VAL, T_FOR_DIFF_ZERO], 1, seed + b"t=sqrt(u^3+7)", 64) + tuple_expand(inputs, [U_FOR_DIFF_ZERO, RANDOM_VAL], 1, seed + b"u=cbrt(t^2-7)", 64) + tuple_expand(inputs, [ZERO_VAL, RANDOM_VAL], 2, seed + b"u=0", 64) + tuple_expand(inputs, [RANDOM_VAL, ZERO_VAL], 2, seed + b"t=0", 64) + tuple_expand(inputs, [ZERO_VAL, FE(8).sqrts()], 3, seed + b"u=0;t=sqrt(8)") + tuple_expand(inputs, [FE(-8).cbrts(), ZERO_VAL], 3, seed + b"t=0;u=cbrt(-8)") + tuple_expand(inputs, [FE(-6).cbrts(), ZERO_VAL], 3, seed + b"t=0;u=cbrt(-6)") + tuple_expand(inputs, [ZERO_VAL, ZERO_VAL], 3, seed + b"u=0;t=0") + # Unused. + tuple_expand(inputs, [ZERO_VAL, FE(-8).sqrts()], 4, seed + b"u=0;t=sqrt(-8)") + + seen = set() + cases = [] + for _prio, _cnt, _hash, vs, _seed in sorted(inputs): + inp = int(vs[0]).to_bytes(32, 'big') + int(vs[1]).to_bytes(32, 'big') + outp, flags = ellswift_decode_flagged(inp, simplified) + if flags not in seen: + cases.append((inp, outp, flags)) + seen.add(flags) + + return cases + +def gen_all_ellswift_decode_vectors(fil): + """Generate all xelligatorswift decoding test vectors.""" + + cases = gen_ellswift_decode_cases(b"") + writer = csv.DictWriter(fil, ["ellswift", "x", "comment"]) + writer.writeheader() + for val, x, flags in sorted(cases): + writer.writerow({"ellswift": val.hex(), "x": x.hex(), "comment": flags}) + +def xswiftec_inv_flagged(x, u, case): + """A variant of xswiftec_inv which also returns flags, describing conditions encountered.""" + + flags = [] + + if case & 2 == 0: + if GE.is_valid_x(-x - u): + flags.append("bad[valid_x(-x-u)]") + return None, flags + v = x if case & 1 == 0 else -x - u + if v == 0: + flags.append("info[v=0]") + s = -(u**3 + 7) / (u**2 + u*v + v**2) + assert s != 0 # would imply X=0 on curve + else: + s = x - u + if s == 0: + flags.append("bad[s=0]") + return None, flags + q = (-s * (4 * (u**3 + 7) + 3 * s * u**2)) + if q == 0: + flags.append("info[q=0]") + r = q.sqrt() + if r is None: + flags.append("bad[non_square(q)]") + return None, flags + if case & 1: + if r == 0: + flags.append("bad[r=0]") + return None, flags + r = -r + v = (-u + r / s) / 2 + if v == 0: + flags.append("info[v=0]") + w = s.sqrt() + assert w != 0 + if w is None: + flags.append("bad[non_square(s)]") + return None, flags + if case & 4: + w = -w + Y = w / 2 + assert Y != 0 + X = 2 * Y * (v + u / 2) + if X == 0: + flags.append("info[X=0]") + flags.append("ok") + return w * (u * (MINUS_3_SQRT - 1) / 2 - v), flags + +def xswiftec_inv_combo_flagged(x, u): + """Compute the aggregate results and flags from xswiftec_inv_flagged for case=0..7.""" + ts = [] + allflags = [] + for case in range(8): + t, flags = xswiftec_inv_flagged(x, u, case) + if t is not None: + assert x == xswiftec(u, t) + ts.append(t) + allflags.append(f"case{case}:{'&'.join(flags)}") + return ts, ";".join(allflags) + +def gen_all_xswiftec_inv_vectors(fil): + """Generate all xswiftec_inv test vectors.""" + + # Two constants used below. Compute them only once. + C1 = (FE(MINUS_3_SQRT) - 1) / 2 + C2 = (-FE(MINUS_3_SQRT) - 1) / 2 + # Helper functions that pick x and u with special properties. + TRIGGER_Q_ZERO = fn_of(1, lambda u: (FE(u)**3 + 28) / (FE(-3) * FE(u)**2)) + TRIGGER_DIVZERO_A = fn_of(1, lambda u: FE(u) * C1) + TRIGGER_DIVZERO_B = fn_of(1, lambda u: FE(u) * C2) + TRIGGER_V_ZERO = fn_of(1, lambda u: FE(-7) / FE(u)**2) + TRIGGER_X_ZERO = fn_of(0, lambda x: FE(-2) * FE(x)) + + inputs = [] + tuple_expand(inputs, [random_fe_int, random_fe_int], 0, b"uniform", 256) + tuple_expand(inputs, [TRIGGER_Q_ZERO, random_fe_int], 1, b"x=-(u^3+28)/(3*u^2)", 64) + tuple_expand(inputs, [TRIGGER_V_ZERO, random_fe_int], 1, b"x=-7/u^2", 512) + tuple_expand(inputs, [random_fe_int, fn_of(0, lambda x: x)], 2, b"u=x", 64) + tuple_expand(inputs, [random_fe_int, fn_of(0, lambda x: -FE(x))], 2, b"u=-x", 64) + # Unused. + tuple_expand(inputs, [TRIGGER_DIVZERO_A, random_fe_int], 3, b"x=u*(sqrt(-3)-1)/2", 64) + tuple_expand(inputs, [TRIGGER_DIVZERO_B, random_fe_int], 3, b"x=u*(-sqrt(-3)-1)/2", 64) + tuple_expand(inputs, [random_fe_int, TRIGGER_X_ZERO], 3, b"u=-2x", 64) + + seen = set() + cases = [] + for _prio, _cnt, _hash, vs, _seed in sorted(inputs): + x, u = FE(vs[0]), FE(vs[1]) + if u == 0: + continue + if not GE.is_valid_x(x): + continue + ts, flags = xswiftec_inv_combo_flagged(x, u) + if flags not in seen: + cases.append((int(u), int(x), ts, flags)) + seen.add(flags) + + writer = csv.DictWriter(fil, ["u", "x"] + [f"case{c}_t" for c in range(8)] + ["comment"]) + writer.writeheader() + for u, x, ts, flags in sorted(cases): + row = {"u": FE(u), "x": FE(x), "comment": flags} + for c in range(8): + if ts[c] is not None: + row[f"case{c}_t"] = FE(ts[c]) + writer.writerow(row) + +def gen_packet_encoding_vector(case): + """Given a dict case with specs, construct a packet_encoding test vector as a CSV line.""" + ikm = str(case).encode('utf-8') + in_initiating = case["init"] + in_ignore = int(case["ignore"]) + in_priv_ours, in_ellswift_ours = ellswift_create_deterministic(ikm, case["features"]) + mid_x_ours = (int.from_bytes(in_priv_ours, 'big') * SECP256K1_G).x.to_bytes() + assert mid_x_ours == ellswift_decode(in_ellswift_ours) + in_ellswift_theirs = case["theirs"] + in_contents = hkdf_sha256(case["contentlen"], ikm, b"contents", b"") + contents = in_contents * case["multiply"] + in_aad = hkdf_sha256(case["aadlen"], ikm, b"aad", b"") + mid_shared_secret = v2_ecdh(in_priv_ours, in_ellswift_theirs, in_ellswift_ours, in_initiating) + + peer = initialize_v2_transport(mid_shared_secret, in_initiating) + for _ in range(case["idx"]): + v2_enc_packet(peer, b"") + ciphertext = v2_enc_packet(peer, contents, in_aad, case["ignore"]) + long_msg = len(ciphertext) > 128 + + return { + "in_idx": case['idx'], + "in_priv_ours": in_priv_ours.hex(), + "in_ellswift_ours": in_ellswift_ours.hex(), + "in_ellswift_theirs": in_ellswift_theirs.hex(), + "in_initiating": int(in_initiating), + "in_contents": in_contents.hex(), + "in_multiply": case['multiply'], + "in_aad": in_aad.hex(), + "in_ignore": in_ignore, + "mid_x_ours": mid_x_ours.hex(), + "mid_x_theirs": ellswift_decode(in_ellswift_theirs).hex(), + "mid_x_shared": ellswift_ecdh_xonly(in_ellswift_theirs, in_priv_ours).hex(), + "mid_shared_secret": mid_shared_secret.hex(), + "mid_initiator_l": peer['initiator_L'].hex(), + "mid_initiator_p": peer['initiator_P'].hex(), + "mid_responder_l": peer['responder_L'].hex(), + "mid_responder_p": peer['responder_P'].hex(), + "mid_send_garbage_terminator": peer["send_garbage_terminator"].hex(), + "mid_recv_garbage_terminator": peer["recv_garbage_terminator"].hex(), + "out_session_id": peer["session_id"].hex(), + "out_ciphertext": "" if long_msg else ciphertext.hex(), + "out_ciphertext_endswith": ciphertext[-128:].hex() if long_msg else "" + } + +def gen_all_packet_encoding_vectors(fil): + """Return a list of CSV lines, one for each packet encoding vector.""" + + ellswift = gen_ellswift_decode_cases(b"simplified_", simplified=True) + ellswift.sort(key=lambda x: hashlib.sha256(b"simplified:" + x[0]).digest()) + + fields = [ + "in_idx", "in_priv_ours", "in_ellswift_ours", "in_ellswift_theirs", "in_initiating", + "in_contents", "in_multiply", "in_aad", "in_ignore", "mid_x_ours", "mid_x_theirs", + "mid_x_shared", "mid_shared_secret", "mid_initiator_l", "mid_initiator_p", + "mid_responder_l", "mid_responder_p", "mid_send_garbage_terminator", + "mid_recv_garbage_terminator", "out_session_id", "out_ciphertext", "out_ciphertext_endswith" + ] + + writer = csv.DictWriter(fil, fields) + writer.writeheader() + for case in [ + {"init": True, "contentlen": 1, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 1, + "theirs": ellswift[0][0], "features": 0}, + {"init": False, "contentlen": 17, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 999, + "theirs": ellswift[1][0], "features": 1}, + {"init": True, "contentlen": 63, "multiply": 1, "aadlen": 4095, "ignore": False, "idx": 0, + "theirs": ellswift[2][0], "features": 2}, + {"init": False, "contentlen": 128, "multiply": 1, "aadlen": 0, "ignore": True, "idx": 223, + "theirs": ellswift[3][0], "features": 3}, + {"init": True, "contentlen": 193, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 448, + "theirs": ellswift[4][0], "features": 4}, + {"init": False, "contentlen": 41, "multiply": 97561, "aadlen": 0, "ignore": False, + "idx": 673, "theirs": ellswift[5][0], "features": 5}, + {"init": True, "contentlen": 241, "multiply": 69615, "aadlen": 0, "ignore": True, + "idx": 1024, "theirs": ellswift[6][0], "features": 6}, + ]: + writer.writerow(gen_packet_encoding_vector(case)) + +if __name__ == "__main__": + print(f"Generating {FILENAME_PACKET_TEST}...") + with open(FILENAME_PACKET_TEST, "w", encoding="utf-8") as fil_packet: + gen_all_packet_encoding_vectors(fil_packet) + print(f"Generating {FILENAME_XSWIFTEC_INV_TEST}...") + with open(FILENAME_XSWIFTEC_INV_TEST, "w", encoding="utf-8") as fil_xswiftec_inv: + gen_all_xswiftec_inv_vectors(fil_xswiftec_inv) + print(f"Generating {FILENAME_ELLSWIFT_DECODE_TEST}...") + with open(FILENAME_ELLSWIFT_DECODE_TEST, "w", encoding="utf-8") as fil_ellswift_decode: + gen_all_ellswift_decode_vectors(fil_ellswift_decode) |