diff options
Diffstat (limited to 'bip-0324/reference.py')
-rw-r--r-- | bip-0324/reference.py | 150 |
1 files changed, 112 insertions, 38 deletions
diff --git a/bip-0324/reference.py b/bip-0324/reference.py index e07731b..f02c44a 100644 --- a/bip-0324/reference.py +++ b/bip-0324/reference.py @@ -1,3 +1,5 @@ +"""Reference implementation for the cryptographic aspects of BIP-324""" + import sys import random import hashlib @@ -70,7 +72,7 @@ class FE: self.den = (a.den * b.num) % FE.SIZE else: self.num = (a * b.den) % FE.SIZE - self.den = a.num + self.den = b.num else: b = b % FE.SIZE assert b != 0 @@ -85,8 +87,7 @@ class FE: """Compute the sum of two field elements (second may be int).""" if isinstance(a, FE): return FE(self.num * a.den + self.den * a.num, self.den * a.den) - else: - return FE(self.num + self.den * a, self.den) + return FE(self.num + self.den * a, self.den) def __radd__(self, a): """Compute the sum of an integer and a field element.""" @@ -96,8 +97,7 @@ class FE: """Compute the difference of two field elements (second may be int).""" if isinstance(a, FE): return FE(self.num * a.den - self.den * a.num, self.den * a.den) - else: - return FE(self.num - self.den * a, self.den) + return FE(self.num - self.den * a, self.den) def __rsub__(self, a): """Compute the difference between an integer and a field element.""" @@ -107,8 +107,7 @@ class FE: """Compute the product of two field elements (second may be int).""" if isinstance(a, FE): return FE(self.num * a.num, self.den * a.den) - else: - return FE(self.num * a, self.den) + return FE(self.num * a, self.den) def __rmul__(self, a): """Compute the product of an integer with a field element.""" @@ -140,15 +139,57 @@ class FE: def sqrt(self): """Compute the square root of a field element. - Due to the fact that our modulus is of the form (p % 4) == 3, the Tonelli-Shanks - algorithm (https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm) is simply - raising the argument to the power (p + 3) / 4.""" + Due to the fact that our modulus p is of the form p = 3 (mod 4), the + Tonelli-Shanks algorithm (https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm) + is simply raising the argument to the power (p + 1) / 4. + + To see why: p-1 = 0 (mod 2), so 2 divides the order of the multiplicative group, + and thus only half of the non-zero field elements are squares. An element a is + a (nonzero) square when Euler's criterion, a^((p-1)/2) = 1 (mod p), holds. We're + looking for x such that x^2 = a (mod p). Given a^((p-1)/2) = 1 (mod p), that is + equivalent to x^2 = a^(1 + (p-1)/2) (mod p). As (1 + (p-1)/2) is even, this is + equivalent to x = a^((1 + (p-1)/2)/2) (mod p), or x = a^((p+1)/4) (mod p).""" v = int(self) s = pow(v, (FE.SIZE + 1) // 4, FE.SIZE) if s**2 % FE.SIZE == v: return FE(s) return None + def sqrts(self): + """Compute all square roots of a field element, if any.""" + s = self.sqrt() + if s is None: + return [] + return [FE(s), -FE(s)] + + # The cube roots of 1 (mod p). + CBRT1 = [ + 1, + 0x851695d49a83f8ef919bb86153cbcb16630fb68aed0a766a3ec693d68e6afa40, + 0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee + ] + + + def cbrts(self): + """Compute all cube roots of a field element, if any. + + Due to the fact that our modulus p is of the form p = 7 (mod 9), one cube root + can always be computed by raising to the power (p + 2) / 9. The other roots + (if any) can be found by multiplying with the two non-trivial cube roots of 1. + + To see why: p-1 = 0 (mod 3), so 3 divides the order of the multiplicative group, + and thus only 1/3 of the non-zero field elements are cubes. An element a is a + (nonzero) cube when a^((p-1)/3) = 1 (mod p). We're looking for x such that + x^3 = a (mod p). Given a^((p-1)/3) = 1 (mod p), that is equivalent to + x^3 = a^(1 + (p-1)/3) (mod p). As (1 + (p-1)/3) is a multiple of 3, this is + equivalent to x = a^((1 + (p-1)/3)/3) (mod p), or x = a^((p+2)/9) (mod p).""" + v = int(self) + c = pow(v, (FE.SIZE + 2) // 9, FE.SIZE) + + if pow(c, 3, FE.SIZE) == v: + return [FE(c * f) for f in FE.CBRT1] + return [] + def is_square(self): """Determine if this field element has a square root.""" # Compute the Jacobi symbol of (self / p). Since our modulus is prime, this @@ -161,7 +202,7 @@ class FE: while n & 1 == 0: n >>= 1 r = k & 7 - t ^= (r == 3 or r == 5) + t ^= (r in (3, 5)) n, k = k, n t ^= (n & k & 3 == 3) n = n % k @@ -172,8 +213,7 @@ class FE: """Check whether two field elements are equal (second may be an int).""" if isinstance(a, FE): return (self.num * a.den - self.den * a.num) % FE.SIZE == 0 - else: - return (self.num - self.den * a) % FE.SIZE == 0 + return (self.num - self.den * a) % FE.SIZE == 0 def to_bytes(self): """Convert a field element to 32-byte big endian encoding.""" @@ -187,6 +227,16 @@ class FE: return None return FE(v) + def __str__(self): + """Convert this field element to a string.""" + return f"{int(self):064x}" + + def __repr__(self): + """Get a string representation of this field element.""" + return f"FE(0x{int(self):x})" + +assert all(pow(c, 3, FE.SIZE) == 1 for c in FE.CBRT1) + class GE: """Objects of this class represent points (group elements) on the secp256k1 curve. @@ -221,12 +271,11 @@ class GE: x3 = l**2 - self.x - a.x y3 = l * (self.x - x3) - self.y return GE(x3, y3) - elif self.y == a.y: + if self.y == a.y: # Adding point to itself return self.double() - else: - # Adding point to its negation - return None + # Adding point to its negation + return None def __radd__(self, a): """Add infinity to a point.""" @@ -260,13 +309,21 @@ class GE: """Determine whether the provided field element is a valid X coordinate.""" return (FE(x)**3 + 7).is_square() + def __str__(self): + """Convert this group element to a string.""" + return f"({self.x},{self.y})" + + def __repr__(self): + """Get a string representation for this group element.""" + return f"GE(0x{int(self.x)},0x{int(self.y)})" + SECP256K1_G = GE( 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) ### ElligatorSwift -# Precomputed constant square root of -3 modulo p. +# Precomputed constant square root of -3 (mod p). MINUS_3_SQRT = FE(-3).sqrt() def xswiftec(u, t): @@ -292,7 +349,7 @@ def xswiftec_inv(x, u, case): if case & 2 == 0: if GE.is_valid_x(-x - u): return None - v = x if case & 1 == 0 else -x - u + v = x s = -(u**3 + 7) / (u**2 + u*v + v**2) else: s = x - u @@ -301,17 +358,16 @@ def xswiftec_inv(x, u, case): r = (-s * (4 * (u**3 + 7) + 3 * s * u**2)).sqrt() if r is None: return None - if case & 1: - if r == 0: - return None - r = -r + if case & 1 and r == 0: + return None v = (-u + r / s) / 2 w = s.sqrt() if w is None: return None - if case & 4: - w = -w - return w * (u * (MINUS_3_SQRT - 1) / 2 - v) + if case & 5 == 0: return -w * (u * (1 - MINUS_3_SQRT) / 2 + v) + if case & 5 == 1: return w * (u * (1 + MINUS_3_SQRT) / 2 + v) + if case & 5 == 4: return w * (u * (1 - MINUS_3_SQRT) / 2 + v) + if case & 5 == 5: return -w * (u * (1 + MINUS_3_SQRT) / 2 + v) def xelligatorswift(x): """Given a field element X on the curve, find (u, t) that encode them.""" @@ -328,12 +384,17 @@ def ellswift_create(): u, t = xelligatorswift((priv * SECP256K1_G).x) return priv.to_bytes(32, 'big'), u.to_bytes() + t.to_bytes() +def ellswift_decode(ellswift): + """Convert ellswift encoded X coordinate to 32-byte xonly format.""" + u = FE(int.from_bytes(ellswift[:32], 'big')) + t = FE(int.from_bytes(ellswift[32:], 'big')) + return xswiftec(u, t).to_bytes() + def ellswift_ecdh_xonly(pubkey_theirs, privkey): """Compute X coordinate of shared ECDH point between elswift pubkey and privkey.""" - u = FE(int.from_bytes(pubkey_theirs[:32], 'big')) - t = FE(int.from_bytes(pubkey_theirs[32:], 'big')) d = int.from_bytes(privkey, 'big') - return (d * GE.lift_x(xswiftec(u, t))).x.to_bytes() + pub = ellswift_decode(pubkey_theirs) + return (d * GE.lift_x(FE.from_bytes(pub))).x.to_bytes() ### Poly1305 @@ -402,7 +463,7 @@ def chacha20_block(key, nonce, cnt): for i in range(3): init[13 + i] = int.from_bytes(nonce[4 * i:4 * (i+1)], 'little') # Perform 20 rounds. - state = [v for v in init] + state = list(init) for _ in range(10): chacha20_doubleround(state) # Add initial values back into state. @@ -459,6 +520,7 @@ class FSChaCha20Poly1305: self.packet_counter = 0 def crypt(self, aad, text, is_decrypt): + """Encrypt or decrypt the specified (plain/cipher)text.""" nonce = ((self.packet_counter % REKEY_INTERVAL).to_bytes(4, 'little') + (self.packet_counter // REKEY_INTERVAL).to_bytes(8, 'little')) if is_decrypt: @@ -474,12 +536,14 @@ class FSChaCha20Poly1305: self.packet_counter += 1 return ret - def decrypt(self, aad, ciphertext): - return self.crypt(aad, ciphertext, True) - def encrypt(self, aad, plaintext): + """Encrypt the specified plaintext with provided AAD.""" return self.crypt(aad, plaintext, False) + def decrypt(self, aad, ciphertext): + """Decrypt the specified ciphertext with provided AAD.""" + return self.crypt(aad, ciphertext, True) + class FSChaCha20: """Rekeying wrapper stream cipher around ChaCha20.""" @@ -491,6 +555,7 @@ class FSChaCha20: self.keystream = b'' def get_keystream_bytes(self, nbytes): + """Generate nbytes keystream bytes.""" while len(self.keystream) < nbytes: nonce = ((0).to_bytes(4, 'little') + (self.chunk_counter // REKEY_INTERVAL).to_bytes(8, 'little')) @@ -501,6 +566,7 @@ class FSChaCha20: return ret def crypt(self, chunk): + """Encrypt or decypt chunk.""" ks = self.get_keystream_bytes(len(chunk)) ret = bytes([ks[i] ^ chunk[i] for i in range(len(chunk))]) if ((self.chunk_counter + 1) % REKEY_INTERVAL) == 0: @@ -509,6 +575,15 @@ class FSChaCha20: self.chunk_counter += 1 return ret + def encrypt(self, chunk): + """Encrypt chunk.""" + return self.crypt(chunk) + + def decrypt(self, chunk): + """Decrypt chunk.""" + return self.crypt(chunk) + + ### Shared secret computation def v2_ecdh(priv, ellswift_theirs, ellswift_ours, initiating): @@ -519,10 +594,9 @@ def v2_ecdh(priv, ellswift_theirs, ellswift_ours, initiating): # Initiating, place our public key encoding first. return TaggedHash("bip324_ellswift_xonly_ecdh", ellswift_ours + ellswift_theirs + ecdh_point_x32) - else: - # Responding, place their public key encoding first. - return TaggedHash("bip324_ellswift_xonly_ecdh", - ellswift_theirs + ellswift_ours + ecdh_point_x32) + # Responding, place their public key encoding first. + return TaggedHash("bip324_ellswift_xonly_ecdh", + ellswift_theirs + ellswift_ours + ecdh_point_x32) ### Key derivation @@ -571,5 +645,5 @@ def v2_enc_packet(peer, contents, aad=b'', ignore=False): header = (ignore << IGNORE_BIT_POS).to_bytes(HEADER_LEN, 'little') plaintext = header + contents aead_ciphertext = peer['send_P'].encrypt(aad, plaintext) - enc_plaintext_len = peer['send_L'].crypt(len(contents).to_bytes(LENGTH_FIELD_LEN, 'little')) + enc_plaintext_len = peer['send_L'].encrypt(len(contents).to_bytes(LENGTH_FIELD_LEN, 'little')) return enc_plaintext_len + aead_ciphertext |