summaryrefslogtreecommitdiff
path: root/bip-0324/gen_test_vectors.py
diff options
context:
space:
mode:
Diffstat (limited to 'bip-0324/gen_test_vectors.py')
-rw-r--r--bip-0324/gen_test_vectors.py418
1 files changed, 418 insertions, 0 deletions
diff --git a/bip-0324/gen_test_vectors.py b/bip-0324/gen_test_vectors.py
new file mode 100644
index 0000000..05b30a8
--- /dev/null
+++ b/bip-0324/gen_test_vectors.py
@@ -0,0 +1,418 @@
+"""Generate the BIP-0324 test vectors."""
+
+import csv
+import hashlib
+import os
+import sys
+from reference import (
+ FE,
+ GE,
+ MINUS_3_SQRT,
+ hkdf_sha256,
+ SECP256K1_G,
+ ellswift_decode,
+ ellswift_ecdh_xonly,
+ xswiftec_inv,
+ xswiftec,
+ v2_ecdh,
+ initialize_v2_transport,
+ v2_enc_packet
+)
+
+FILENAME_PACKET_TEST = os.path.join(sys.path[0], 'packet_encoding_test_vectors.csv')
+FILENAME_XSWIFTEC_INV_TEST = os.path.join(sys.path[0], 'xswiftec_inv_test_vectors.csv')
+FILENAME_ELLSWIFT_DECODE_TEST = os.path.join(sys.path[0], 'ellswift_decode_test_vectors.csv')
+
+def xswiftec_flagged(u, t, simplified=False):
+ """A variant of xswiftec which also returns 'flags', describing conditions encountered."""
+ flags = []
+ if u == 0:
+ flags.append("u%p=0")
+ u = FE(1)
+ if t == 0:
+ flags.append("t%p=0")
+ t = FE(1)
+ if u**3 + t**2 + 7 == 0:
+ flags.append("(u'^3+t'^2+7)%p=0")
+ t = 2 * t
+ X = (u**3 + 7 - t**2) / (2 * t)
+ Y = (X + t) / (MINUS_3_SQRT * u)
+ if X == 0:
+ if not simplified:
+ flags.append("(u'^3-t'^2+7)%p=0")
+ x3 = u + 4 * Y**2
+ if GE.is_valid_x(x3):
+ flags.append("valid_x(x3)")
+ x2 = (-X / Y - u) / 2
+ if GE.is_valid_x(x2):
+ flags.append("valid_x(x2)")
+ x1 = (X / Y - u) / 2
+ if GE.is_valid_x(x1):
+ flags.append("valid_x(x1)")
+ for x in (x3, x2, x1):
+ if GE.is_valid_x(x):
+ break
+ return x, flags
+
+
+def ellswift_create_deterministic(seed, features):
+ """This is a variant of ellswift_create which doesn't use randomness.
+
+ features is an integer selecting some properties of the result:
+ - (f & 3) == 0: only x1 is valid on decoding (see xswiftec{_flagged})
+ - (f & 3) == 1: only x2 is valid on decoding
+ - (f & 3) == 2: only x3 is valid on decoding
+ - (f & 3) == 3: x1,x2,x3 are all valid on decoding
+ - (f & 4) == 4: u >= p
+ - (f & 8) == 8: u mod n == 0
+
+ Returns privkey, ellswift
+ """
+
+ cnt = 0
+ while True:
+ sec = hkdf_sha256(32, seed, (cnt).to_bytes(4, 'little'), b"sec")
+ xval = (int.from_bytes(sec, 'big') * SECP256K1_G).x
+ cnt += 1
+ if features & 8:
+ u = 0
+ if features & 4:
+ u += FE.SIZE
+ else:
+ udat = hkdf_sha256(64, seed, (cnt).to_bytes(4, 'little'), b"u")
+ if features & 4:
+ u = FE.SIZE + 1 + int.from_bytes(udat, 'big') % (2**256 - FE.SIZE - 1)
+ else:
+ u = 1 + int.from_bytes(udat, 'big') % (FE.SIZE - 1)
+ case = hkdf_sha256(1, seed, (cnt).to_bytes(4, 'little'), b"case")[0] & 7
+ coru = FE(u) + ((features & 8) == 8)
+ t = xswiftec_inv(xval, coru, case)
+ if t is None:
+ continue
+ assert xswiftec(FE(u), t) == xval
+ x2, flags = xswiftec_flagged(FE(u), t)
+ assert x2 == xval
+ have_x1 = "valid_x(x1)" in flags
+ have_x2 = "valid_x(x2)" in flags
+ have_x3 = "valid_x(x3)" in flags
+ if (features & 4) == 0 and not (have_x1 and not have_x2 and not have_x3):
+ continue
+ if (features & 4) == 1 and not (not have_x1 and have_x2 and not have_x3):
+ continue
+ if (features & 4) == 2 and not (not have_x1 and not have_x2 and have_x3):
+ continue
+ if (features & 4) == 3 and not (have_x1 and have_x2 and have_x3):
+ continue
+ return sec, u.to_bytes(32, 'big') + t.to_bytes()
+
+def ellswift_decode_flagged(ellswift, simplified=False):
+ """Decode a 64-byte ElligatorSwift encoded coordinate, returning byte array + flag string."""
+ uv = int.from_bytes(ellswift[:32], 'big')
+ tv = int.from_bytes(ellswift[32:], 'big')
+ x, flags = xswiftec_flagged(FE(uv), FE(tv))
+ if not simplified:
+ if uv >= FE.SIZE:
+ flags.append("u>=p")
+ if tv >= FE.SIZE:
+ flags.append("t>=p")
+ return int(x).to_bytes(32, 'big'), ";".join(flags)
+
+def random_fe_int(_, seed, i, p):
+ """Function to use in tuple_expand, generating a random integer in 0..p-1."""
+ rng_out = hkdf_sha256(64, seed, i.to_bytes(4, 'little'), b"v%i_fe" % p)
+ return int.from_bytes(rng_out, 'big') % FE.SIZE
+
+def random_fe_int_high(_, seed, i, p):
+ """Function to use in tuple_expand, generating a random integer in p..2^256-1."""
+ rng_out = hkdf_sha256(64, seed, i.to_bytes(4, 'little'), b"v%i_fe_high" % p)
+ return FE.SIZE + int.from_bytes(rng_out, 'big') % (2**256 - FE.SIZE)
+
+def fn_of(p_in, fn):
+ """Function to use in tuple_expand, to pick one variable in function of another."""
+ def inner(vs, _seed, _i, p):
+ assert p != p_in
+ if isinstance(vs[p_in], int):
+ return fn(vs[p_in])
+ return None
+ return inner
+
+def tuple_expand(out, tuplespec, prio, seed=None, cnt=1):
+ """Given a tuple specification, expand it cnt times, and add results to out.
+
+ Expansion is defined recursively:
+ - If any of the spec elements is a list, each element of the list results
+ in an expansion (by replacing the list with its element).
+ - If any of the spec elements is a function, that function is invoked with
+ (spec, seed, expansion count, index in spec) as arguments. If the function
+ needs to wait for other indices to be expanded, it can return None.
+
+ The output consists of (prio, expansion count, SHA256(result), result, seed)
+ tuples."""
+
+ def recurse(vs, seed, i, change_pos=None, change=None):
+ if change_pos is not None:
+ vs = list(vs)
+ vs[change_pos] = change
+ for p, v in enumerate(vs):
+ if v is None:
+ return
+ if isinstance(v, list):
+ for ve in v:
+ recurse(vs, seed, i, p, ve)
+ return
+ if callable(v):
+ res = v(vs, seed, i, p)
+ if res is not None:
+ recurse(vs, seed, i, p, res)
+ return
+ h = hashlib.sha256()
+ for v in vs:
+ h.update(int(v).to_bytes(32, 'big'))
+ out.append((prio, i, h.digest(), vs, seed))
+ for i in range(cnt):
+ recurse(tuplespec, seed, i)
+
+def gen_ellswift_decode_cases(seed, simplified=False):
+ """Generate a set of interesting (ellswift, x, flags) ellswift decoding cases."""
+ inputs = []
+
+ # Aggregate for use in tuple_expand, expanding to int in 0..p-1, and one in p..2^256-1.
+ RANDOM_VAL = [random_fe_int, random_fe_int_high]
+ # Aggregate for use in tuple_expand, expanding to integers which %p equal 0.
+ ZERO_VAL = [0, FE.SIZE]
+ # Helpers for constructing u and t values such that u^3+t^2+7=0 or u^3-t^2+7=0.
+ T_FOR_SUM_ZERO = fn_of(0, lambda u: (-FE(u)**3 - 7).sqrts())
+ T_FOR_DIFF_ZERO = fn_of(0, lambda u: (FE(u)**3 + 7).sqrts())
+ U_FOR_SUM_ZERO = fn_of(1, lambda t: (-FE(t)**2 - 7).cbrts())
+ U_FOR_DIFF_ZERO = fn_of(1, lambda t: (FE(t)**2 - 7).cbrts())
+
+ tuple_expand(inputs, [RANDOM_VAL, RANDOM_VAL], 0, seed + b"random", 64)
+ tuple_expand(inputs, [RANDOM_VAL, T_FOR_SUM_ZERO], 1, seed + b"t=sqrt(-u^3-7)", 64)
+ tuple_expand(inputs, [U_FOR_SUM_ZERO, RANDOM_VAL], 1, seed + b"u=cbrt(-t^2-7)", 64)
+ tuple_expand(inputs, [RANDOM_VAL, T_FOR_DIFF_ZERO], 1, seed + b"t=sqrt(u^3+7)", 64)
+ tuple_expand(inputs, [U_FOR_DIFF_ZERO, RANDOM_VAL], 1, seed + b"u=cbrt(t^2-7)", 64)
+ tuple_expand(inputs, [ZERO_VAL, RANDOM_VAL], 2, seed + b"u=0", 64)
+ tuple_expand(inputs, [RANDOM_VAL, ZERO_VAL], 2, seed + b"t=0", 64)
+ tuple_expand(inputs, [ZERO_VAL, FE(8).sqrts()], 3, seed + b"u=0;t=sqrt(8)")
+ tuple_expand(inputs, [FE(-8).cbrts(), ZERO_VAL], 3, seed + b"t=0;u=cbrt(-8)")
+ tuple_expand(inputs, [FE(-6).cbrts(), ZERO_VAL], 3, seed + b"t=0;u=cbrt(-6)")
+ tuple_expand(inputs, [ZERO_VAL, ZERO_VAL], 3, seed + b"u=0;t=0")
+ # Unused.
+ tuple_expand(inputs, [ZERO_VAL, FE(-8).sqrts()], 4, seed + b"u=0;t=sqrt(-8)")
+
+ seen = set()
+ cases = []
+ for _prio, _cnt, _hash, vs, _seed in sorted(inputs):
+ inp = int(vs[0]).to_bytes(32, 'big') + int(vs[1]).to_bytes(32, 'big')
+ outp, flags = ellswift_decode_flagged(inp, simplified)
+ if flags not in seen:
+ cases.append((inp, outp, flags))
+ seen.add(flags)
+
+ return cases
+
+def gen_all_ellswift_decode_vectors(fil):
+ """Generate all xelligatorswift decoding test vectors."""
+
+ cases = gen_ellswift_decode_cases(b"")
+ writer = csv.DictWriter(fil, ["ellswift", "x", "comment"])
+ writer.writeheader()
+ for val, x, flags in sorted(cases):
+ writer.writerow({"ellswift": val.hex(), "x": x.hex(), "comment": flags})
+
+def xswiftec_inv_flagged(x, u, case):
+ """A variant of xswiftec_inv which also returns flags, describing conditions encountered."""
+
+ flags = []
+
+ if case & 2 == 0:
+ if GE.is_valid_x(-x - u):
+ flags.append("bad[valid_x(-x-u)]")
+ return None, flags
+ v = x if case & 1 == 0 else -x - u
+ if v == 0:
+ flags.append("info[v=0]")
+ s = -(u**3 + 7) / (u**2 + u*v + v**2)
+ assert s != 0 # would imply X=0 on curve
+ else:
+ s = x - u
+ if s == 0:
+ flags.append("bad[s=0]")
+ return None, flags
+ q = (-s * (4 * (u**3 + 7) + 3 * s * u**2))
+ if q == 0:
+ flags.append("info[q=0]")
+ r = q.sqrt()
+ if r is None:
+ flags.append("bad[non_square(q)]")
+ return None, flags
+ if case & 1:
+ if r == 0:
+ flags.append("bad[r=0]")
+ return None, flags
+ r = -r
+ v = (-u + r / s) / 2
+ if v == 0:
+ flags.append("info[v=0]")
+ w = s.sqrt()
+ assert w != 0
+ if w is None:
+ flags.append("bad[non_square(s)]")
+ return None, flags
+ if case & 4:
+ w = -w
+ Y = w / 2
+ assert Y != 0
+ X = 2 * Y * (v + u / 2)
+ if X == 0:
+ flags.append("info[X=0]")
+ flags.append("ok")
+ return w * (u * (MINUS_3_SQRT - 1) / 2 - v), flags
+
+def xswiftec_inv_combo_flagged(x, u):
+ """Compute the aggregate results and flags from xswiftec_inv_flagged for case=0..7."""
+ ts = []
+ allflags = []
+ for case in range(8):
+ t, flags = xswiftec_inv_flagged(x, u, case)
+ if t is not None:
+ assert x == xswiftec(u, t)
+ ts.append(t)
+ allflags.append(f"case{case}:{'&'.join(flags)}")
+ return ts, ";".join(allflags)
+
+def gen_all_xswiftec_inv_vectors(fil):
+ """Generate all xswiftec_inv test vectors."""
+
+ # Two constants used below. Compute them only once.
+ C1 = (FE(MINUS_3_SQRT) - 1) / 2
+ C2 = (-FE(MINUS_3_SQRT) - 1) / 2
+ # Helper functions that pick x and u with special properties.
+ TRIGGER_Q_ZERO = fn_of(1, lambda u: (FE(u)**3 + 28) / (FE(-3) * FE(u)**2))
+ TRIGGER_DIVZERO_A = fn_of(1, lambda u: FE(u) * C1)
+ TRIGGER_DIVZERO_B = fn_of(1, lambda u: FE(u) * C2)
+ TRIGGER_V_ZERO = fn_of(1, lambda u: FE(-7) / FE(u)**2)
+ TRIGGER_X_ZERO = fn_of(0, lambda x: FE(-2) * FE(x))
+
+ inputs = []
+ tuple_expand(inputs, [random_fe_int, random_fe_int], 0, b"uniform", 256)
+ tuple_expand(inputs, [TRIGGER_Q_ZERO, random_fe_int], 1, b"x=-(u^3+28)/(3*u^2)", 64)
+ tuple_expand(inputs, [TRIGGER_V_ZERO, random_fe_int], 1, b"x=-7/u^2", 512)
+ tuple_expand(inputs, [random_fe_int, fn_of(0, lambda x: x)], 2, b"u=x", 64)
+ tuple_expand(inputs, [random_fe_int, fn_of(0, lambda x: -FE(x))], 2, b"u=-x", 64)
+ # Unused.
+ tuple_expand(inputs, [TRIGGER_DIVZERO_A, random_fe_int], 3, b"x=u*(sqrt(-3)-1)/2", 64)
+ tuple_expand(inputs, [TRIGGER_DIVZERO_B, random_fe_int], 3, b"x=u*(-sqrt(-3)-1)/2", 64)
+ tuple_expand(inputs, [random_fe_int, TRIGGER_X_ZERO], 3, b"u=-2x", 64)
+
+ seen = set()
+ cases = []
+ for _prio, _cnt, _hash, vs, _seed in sorted(inputs):
+ x, u = FE(vs[0]), FE(vs[1])
+ if u == 0:
+ continue
+ if not GE.is_valid_x(x):
+ continue
+ ts, flags = xswiftec_inv_combo_flagged(x, u)
+ if flags not in seen:
+ cases.append((int(u), int(x), ts, flags))
+ seen.add(flags)
+
+ writer = csv.DictWriter(fil, ["u", "x"] + [f"case{c}_t" for c in range(8)] + ["comment"])
+ writer.writeheader()
+ for u, x, ts, flags in sorted(cases):
+ row = {"u": FE(u), "x": FE(x), "comment": flags}
+ for c in range(8):
+ if ts[c] is not None:
+ row[f"case{c}_t"] = FE(ts[c])
+ writer.writerow(row)
+
+def gen_packet_encoding_vector(case):
+ """Given a dict case with specs, construct a packet_encoding test vector as a CSV line."""
+ ikm = str(case).encode('utf-8')
+ in_initiating = case["init"]
+ in_ignore = int(case["ignore"])
+ in_priv_ours, in_ellswift_ours = ellswift_create_deterministic(ikm, case["features"])
+ mid_x_ours = (int.from_bytes(in_priv_ours, 'big') * SECP256K1_G).x.to_bytes()
+ assert mid_x_ours == ellswift_decode(in_ellswift_ours)
+ in_ellswift_theirs = case["theirs"]
+ in_contents = hkdf_sha256(case["contentlen"], ikm, b"contents", b"")
+ contents = in_contents * case["multiply"]
+ in_aad = hkdf_sha256(case["aadlen"], ikm, b"aad", b"")
+ mid_shared_secret = v2_ecdh(in_priv_ours, in_ellswift_theirs, in_ellswift_ours, in_initiating)
+
+ peer = initialize_v2_transport(mid_shared_secret, in_initiating)
+ for _ in range(case["idx"]):
+ v2_enc_packet(peer, b"")
+ ciphertext = v2_enc_packet(peer, contents, in_aad, case["ignore"])
+ long_msg = len(ciphertext) > 128
+
+ return {
+ "in_idx": case['idx'],
+ "in_priv_ours": in_priv_ours.hex(),
+ "in_ellswift_ours": in_ellswift_ours.hex(),
+ "in_ellswift_theirs": in_ellswift_theirs.hex(),
+ "in_initiating": int(in_initiating),
+ "in_contents": in_contents.hex(),
+ "in_multiply": case['multiply'],
+ "in_aad": in_aad.hex(),
+ "in_ignore": in_ignore,
+ "mid_x_ours": mid_x_ours.hex(),
+ "mid_x_theirs": ellswift_decode(in_ellswift_theirs).hex(),
+ "mid_x_shared": ellswift_ecdh_xonly(in_ellswift_theirs, in_priv_ours).hex(),
+ "mid_shared_secret": mid_shared_secret.hex(),
+ "mid_initiator_l": peer['initiator_L'].hex(),
+ "mid_initiator_p": peer['initiator_P'].hex(),
+ "mid_responder_l": peer['responder_L'].hex(),
+ "mid_responder_p": peer['responder_P'].hex(),
+ "mid_send_garbage_terminator": peer["send_garbage_terminator"].hex(),
+ "mid_recv_garbage_terminator": peer["recv_garbage_terminator"].hex(),
+ "out_session_id": peer["session_id"].hex(),
+ "out_ciphertext": "" if long_msg else ciphertext.hex(),
+ "out_ciphertext_endswith": ciphertext[-128:].hex() if long_msg else ""
+ }
+
+def gen_all_packet_encoding_vectors(fil):
+ """Return a list of CSV lines, one for each packet encoding vector."""
+
+ ellswift = gen_ellswift_decode_cases(b"simplified_", simplified=True)
+ ellswift.sort(key=lambda x: hashlib.sha256(b"simplified:" + x[0]).digest())
+
+ fields = [
+ "in_idx", "in_priv_ours", "in_ellswift_ours", "in_ellswift_theirs", "in_initiating",
+ "in_contents", "in_multiply", "in_aad", "in_ignore", "mid_x_ours", "mid_x_theirs",
+ "mid_x_shared", "mid_shared_secret", "mid_initiator_l", "mid_initiator_p",
+ "mid_responder_l", "mid_responder_p", "mid_send_garbage_terminator",
+ "mid_recv_garbage_terminator", "out_session_id", "out_ciphertext", "out_ciphertext_endswith"
+ ]
+
+ writer = csv.DictWriter(fil, fields)
+ writer.writeheader()
+ for case in [
+ {"init": True, "contentlen": 1, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 1,
+ "theirs": ellswift[0][0], "features": 0},
+ {"init": False, "contentlen": 17, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 999,
+ "theirs": ellswift[1][0], "features": 1},
+ {"init": True, "contentlen": 63, "multiply": 1, "aadlen": 4095, "ignore": False, "idx": 0,
+ "theirs": ellswift[2][0], "features": 2},
+ {"init": False, "contentlen": 128, "multiply": 1, "aadlen": 0, "ignore": True, "idx": 223,
+ "theirs": ellswift[3][0], "features": 3},
+ {"init": True, "contentlen": 193, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 448,
+ "theirs": ellswift[4][0], "features": 4},
+ {"init": False, "contentlen": 41, "multiply": 97561, "aadlen": 0, "ignore": False,
+ "idx": 673, "theirs": ellswift[5][0], "features": 5},
+ {"init": True, "contentlen": 241, "multiply": 69615, "aadlen": 0, "ignore": True,
+ "idx": 1024, "theirs": ellswift[6][0], "features": 6},
+ ]:
+ writer.writerow(gen_packet_encoding_vector(case))
+
+if __name__ == "__main__":
+ print(f"Generating {FILENAME_PACKET_TEST}...")
+ with open(FILENAME_PACKET_TEST, "w", encoding="utf-8") as fil_packet:
+ gen_all_packet_encoding_vectors(fil_packet)
+ print(f"Generating {FILENAME_XSWIFTEC_INV_TEST}...")
+ with open(FILENAME_XSWIFTEC_INV_TEST, "w", encoding="utf-8") as fil_xswiftec_inv:
+ gen_all_xswiftec_inv_vectors(fil_xswiftec_inv)
+ print(f"Generating {FILENAME_ELLSWIFT_DECODE_TEST}...")
+ with open(FILENAME_ELLSWIFT_DECODE_TEST, "w", encoding="utf-8") as fil_ellswift_decode:
+ gen_all_ellswift_decode_vectors(fil_ellswift_decode)