summaryrefslogtreecommitdiff
path: root/bip-0324/reference.py
diff options
context:
space:
mode:
Diffstat (limited to 'bip-0324/reference.py')
-rw-r--r--bip-0324/reference.py649
1 files changed, 649 insertions, 0 deletions
diff --git a/bip-0324/reference.py b/bip-0324/reference.py
new file mode 100644
index 0000000..f02c44a
--- /dev/null
+++ b/bip-0324/reference.py
@@ -0,0 +1,649 @@
+"""Reference implementation for the cryptographic aspects of BIP-324"""
+
+import sys
+import random
+import hashlib
+import hmac
+
+### BIP-340 tagged hash
+
+def TaggedHash(tag, data):
+ """Compute BIP-340 tagged hash with specified tag string of data."""
+ ss = hashlib.sha256(tag.encode('utf-8')).digest()
+ ss += ss
+ ss += data
+ return hashlib.sha256(ss).digest()
+
+### HKDF-SHA256
+
+def hmac_sha256(key, data):
+ """Compute HMAC-SHA256 from specified byte arrays key and data."""
+ return hmac.new(key, data, hashlib.sha256).digest()
+
+def hkdf_sha256(length, ikm, salt, info):
+ """Derive a key using HKDF-SHA256."""
+ if len(salt) == 0:
+ salt = bytes([0] * 32)
+ prk = hmac_sha256(salt, ikm)
+ t = b""
+ okm = b""
+ for i in range((length + 32 - 1) // 32):
+ t = hmac_sha256(prk, t + info + bytes([i + 1]))
+ okm += t
+ return okm[:length]
+
+### secp256k1 field/group elements
+
+def modinv(a, n):
+ """Compute the modular inverse of a modulo n using the extended Euclidean
+ Algorithm. See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers.
+ """
+ a = a % n
+ if a == 0:
+ return 0
+ if sys.hexversion >= 0x3080000:
+ # More efficient version available in Python 3.8.
+ return pow(a, -1, n)
+ t1, t2 = 0, 1
+ r1, r2 = n, a
+ while r2 != 0:
+ q = r1 // r2
+ t1, t2 = t2, t1 - q * t2
+ r1, r2 = r2, r1 - q * r2
+ if r1 > 1:
+ return None
+ if t1 < 0:
+ t1 += n
+ return t1
+
+class FE:
+ """Objects of this class represent elements of the field GF(2**256 - 2**32 - 977).
+
+ They are represented internally in numerator / denominator form, in order to delay inversions.
+ """
+
+ SIZE = 2**256 - 2**32 - 977
+
+ def __init__(self, a=0, b=1):
+ """Initialize an FE as a/b; both a and b can be ints or field elements."""
+ if isinstance(b, FE):
+ if isinstance(a, FE):
+ self.num = (a.num * b.den) % FE.SIZE
+ self.den = (a.den * b.num) % FE.SIZE
+ else:
+ self.num = (a * b.den) % FE.SIZE
+ self.den = b.num
+ else:
+ b = b % FE.SIZE
+ assert b != 0
+ if isinstance(a, FE):
+ self.num = a.num
+ self.den = (a.den * b) % FE.SIZE
+ else:
+ self.num = a % FE.SIZE
+ self.den = b
+
+ def __add__(self, a):
+ """Compute the sum of two field elements (second may be int)."""
+ if isinstance(a, FE):
+ return FE(self.num * a.den + self.den * a.num, self.den * a.den)
+ return FE(self.num + self.den * a, self.den)
+
+ def __radd__(self, a):
+ """Compute the sum of an integer and a field element."""
+ return FE(self.num + self.den * a, self.den)
+
+ def __sub__(self, a):
+ """Compute the difference of two field elements (second may be int)."""
+ if isinstance(a, FE):
+ return FE(self.num * a.den - self.den * a.num, self.den * a.den)
+ return FE(self.num - self.den * a, self.den)
+
+ def __rsub__(self, a):
+ """Compute the difference between an integer and a field element."""
+ return FE(self.den * a - self.num, self.den)
+
+ def __mul__(self, a):
+ """Compute the product of two field elements (second may be int)."""
+ if isinstance(a, FE):
+ return FE(self.num * a.num, self.den * a.den)
+ return FE(self.num * a, self.den)
+
+ def __rmul__(self, a):
+ """Compute the product of an integer with a field element."""
+ return FE(self.num * a, self.den)
+
+ def __truediv__(self, a):
+ """Compute the ratio of two field elements (second may be int)."""
+ return FE(self, a)
+
+ def __rtruediv__(self, a):
+ """Compute the ratio of an integer and a field element."""
+ return FE(a, self)
+
+ def __pow__(self, a):
+ """Raise a field element to a (positive) integer power."""
+ return FE(pow(self.num, a, FE.SIZE), pow(self.den, a, FE.SIZE))
+
+ def __neg__(self):
+ """Negate a field element."""
+ return FE(-self.num, self.den)
+
+ def __int__(self):
+ """Convert a field element to an integer. The result is cached."""
+ if self.den != 1:
+ self.num = (self.num * modinv(self.den, FE.SIZE)) % FE.SIZE
+ self.den = 1
+ return self.num
+
+ def sqrt(self):
+ """Compute the square root of a field element.
+
+ Due to the fact that our modulus p is of the form p = 3 (mod 4), the
+ Tonelli-Shanks algorithm (https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm)
+ is simply raising the argument to the power (p + 1) / 4.
+
+ To see why: p-1 = 0 (mod 2), so 2 divides the order of the multiplicative group,
+ and thus only half of the non-zero field elements are squares. An element a is
+ a (nonzero) square when Euler's criterion, a^((p-1)/2) = 1 (mod p), holds. We're
+ looking for x such that x^2 = a (mod p). Given a^((p-1)/2) = 1 (mod p), that is
+ equivalent to x^2 = a^(1 + (p-1)/2) (mod p). As (1 + (p-1)/2) is even, this is
+ equivalent to x = a^((1 + (p-1)/2)/2) (mod p), or x = a^((p+1)/4) (mod p)."""
+ v = int(self)
+ s = pow(v, (FE.SIZE + 1) // 4, FE.SIZE)
+ if s**2 % FE.SIZE == v:
+ return FE(s)
+ return None
+
+ def sqrts(self):
+ """Compute all square roots of a field element, if any."""
+ s = self.sqrt()
+ if s is None:
+ return []
+ return [FE(s), -FE(s)]
+
+ # The cube roots of 1 (mod p).
+ CBRT1 = [
+ 1,
+ 0x851695d49a83f8ef919bb86153cbcb16630fb68aed0a766a3ec693d68e6afa40,
+ 0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee
+ ]
+
+
+ def cbrts(self):
+ """Compute all cube roots of a field element, if any.
+
+ Due to the fact that our modulus p is of the form p = 7 (mod 9), one cube root
+ can always be computed by raising to the power (p + 2) / 9. The other roots
+ (if any) can be found by multiplying with the two non-trivial cube roots of 1.
+
+ To see why: p-1 = 0 (mod 3), so 3 divides the order of the multiplicative group,
+ and thus only 1/3 of the non-zero field elements are cubes. An element a is a
+ (nonzero) cube when a^((p-1)/3) = 1 (mod p). We're looking for x such that
+ x^3 = a (mod p). Given a^((p-1)/3) = 1 (mod p), that is equivalent to
+ x^3 = a^(1 + (p-1)/3) (mod p). As (1 + (p-1)/3) is a multiple of 3, this is
+ equivalent to x = a^((1 + (p-1)/3)/3) (mod p), or x = a^((p+2)/9) (mod p)."""
+ v = int(self)
+ c = pow(v, (FE.SIZE + 2) // 9, FE.SIZE)
+
+ if pow(c, 3, FE.SIZE) == v:
+ return [FE(c * f) for f in FE.CBRT1]
+ return []
+
+ def is_square(self):
+ """Determine if this field element has a square root."""
+ # Compute the Jacobi symbol of (self / p). Since our modulus is prime, this
+ # is the same as the Legendre symbol, which determines quadratic residuosity.
+ # See https://en.wikipedia.org/wiki/Jacobi_symbol for the algorithm.
+ n, k, t = (self.num * self.den) % FE.SIZE, FE.SIZE, 0
+ if n == 0:
+ return True
+ while n != 0:
+ while n & 1 == 0:
+ n >>= 1
+ r = k & 7
+ t ^= (r in (3, 5))
+ n, k = k, n
+ t ^= (n & k & 3 == 3)
+ n = n % k
+ assert k == 1
+ return not t
+
+ def __eq__(self, a):
+ """Check whether two field elements are equal (second may be an int)."""
+ if isinstance(a, FE):
+ return (self.num * a.den - self.den * a.num) % FE.SIZE == 0
+ return (self.num - self.den * a) % FE.SIZE == 0
+
+ def to_bytes(self):
+ """Convert a field element to 32-byte big endian encoding."""
+ return int(self).to_bytes(32, 'big')
+
+ @staticmethod
+ def from_bytes(b):
+ """Convert a 32-byte big endian encoding of a field element to an FE."""
+ v = int.from_bytes(b, 'big')
+ if v >= FE.SIZE:
+ return None
+ return FE(v)
+
+ def __str__(self):
+ """Convert this field element to a string."""
+ return f"{int(self):064x}"
+
+ def __repr__(self):
+ """Get a string representation of this field element."""
+ return f"FE(0x{int(self):x})"
+
+assert all(pow(c, 3, FE.SIZE) == 1 for c in FE.CBRT1)
+
+class GE:
+ """Objects of this class represent points (group elements) on the secp256k1 curve.
+
+ The point at infinity is represented as None."""
+
+ ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
+ ORDER_HALF = ORDER // 2
+
+ def __init__(self, x, y):
+ """Initialize a group element with specified x and y coordinates (must be on curve)."""
+ fx = FE(x)
+ fy = FE(y)
+ assert fy**2 == fx**3 + 7
+ self.x = fx
+ self.y = fy
+
+ def double(self):
+ """Compute the double of a point."""
+ l = 3 * self.x**2 / (2 * self.y)
+ x3 = l**2 - 2 * self.x
+ y3 = l * (self.x - x3) - self.y
+ return GE(x3, y3)
+
+ def __add__(self, a):
+ """Add two points, or a point and infinity, together."""
+ if a is None:
+ # Adding point at infinity
+ return self
+ if self.x != a.x:
+ # Adding distinct x coordinates
+ l = (a.y - self.y) / (a.x - self.x)
+ x3 = l**2 - self.x - a.x
+ y3 = l * (self.x - x3) - self.y
+ return GE(x3, y3)
+ if self.y == a.y:
+ # Adding point to itself
+ return self.double()
+ # Adding point to its negation
+ return None
+
+ def __radd__(self, a):
+ """Add infinity to a point."""
+ assert a is None
+ return self
+
+ def __mul__(self, a):
+ """Multiply a point with an integer (scalar multiplication)."""
+ r = None
+ for i in range(a.bit_length() - 1, -1, -1):
+ if r is not None:
+ r = r.double()
+ if (a >> i) & 1:
+ r += self
+ return r
+
+ def __rmul__(self, a):
+ """Multiply an integer with a point (scalar multiplication)."""
+ return self * a
+
+ @staticmethod
+ def lift_x(x):
+ """Take an FE, and return the point with that as X coordinate, and square Y."""
+ y = (FE(x)**3 + 7).sqrt()
+ if y is None:
+ return None
+ return GE(x, y)
+
+ @staticmethod
+ def is_valid_x(x):
+ """Determine whether the provided field element is a valid X coordinate."""
+ return (FE(x)**3 + 7).is_square()
+
+ def __str__(self):
+ """Convert this group element to a string."""
+ return f"({self.x},{self.y})"
+
+ def __repr__(self):
+ """Get a string representation for this group element."""
+ return f"GE(0x{int(self.x)},0x{int(self.y)})"
+
+SECP256K1_G = GE(
+ 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
+ 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)
+
+### ElligatorSwift
+
+# Precomputed constant square root of -3 (mod p).
+MINUS_3_SQRT = FE(-3).sqrt()
+
+def xswiftec(u, t):
+ """Decode field elements (u, t) to an X coordinate on the curve."""
+ if u == 0:
+ u = FE(1)
+ if t == 0:
+ t = FE(1)
+ if u**3 + t**2 + 7 == 0:
+ t = 2 * t
+ X = (u**3 + 7 - t**2) / (2 * t)
+ Y = (X + t) / (MINUS_3_SQRT * u)
+ for x in (u + 4 * Y**2, (-X / Y - u) / 2, (X / Y - u) / 2):
+ if GE.is_valid_x(x):
+ return x
+ assert False
+
+def xswiftec_inv(x, u, case):
+ """Given x and u, find t such that xswiftec(u, t) = x, or return None.
+
+ Case selects which of the up to 8 results to return."""
+
+ if case & 2 == 0:
+ if GE.is_valid_x(-x - u):
+ return None
+ v = x
+ s = -(u**3 + 7) / (u**2 + u*v + v**2)
+ else:
+ s = x - u
+ if s == 0:
+ return None
+ r = (-s * (4 * (u**3 + 7) + 3 * s * u**2)).sqrt()
+ if r is None:
+ return None
+ if case & 1 and r == 0:
+ return None
+ v = (-u + r / s) / 2
+ w = s.sqrt()
+ if w is None:
+ return None
+ if case & 5 == 0: return -w * (u * (1 - MINUS_3_SQRT) / 2 + v)
+ if case & 5 == 1: return w * (u * (1 + MINUS_3_SQRT) / 2 + v)
+ if case & 5 == 4: return w * (u * (1 - MINUS_3_SQRT) / 2 + v)
+ if case & 5 == 5: return -w * (u * (1 + MINUS_3_SQRT) / 2 + v)
+
+def xelligatorswift(x):
+ """Given a field element X on the curve, find (u, t) that encode them."""
+ while True:
+ u = FE(random.randrange(1, GE.ORDER))
+ case = random.randrange(0, 8)
+ t = xswiftec_inv(x, u, case)
+ if t is not None:
+ return u, t
+
+def ellswift_create():
+ """Generate a (privkey, ellswift_pubkey) pair."""
+ priv = random.randrange(1, GE.ORDER)
+ u, t = xelligatorswift((priv * SECP256K1_G).x)
+ return priv.to_bytes(32, 'big'), u.to_bytes() + t.to_bytes()
+
+def ellswift_decode(ellswift):
+ """Convert ellswift encoded X coordinate to 32-byte xonly format."""
+ u = FE(int.from_bytes(ellswift[:32], 'big'))
+ t = FE(int.from_bytes(ellswift[32:], 'big'))
+ return xswiftec(u, t).to_bytes()
+
+def ellswift_ecdh_xonly(pubkey_theirs, privkey):
+ """Compute X coordinate of shared ECDH point between elswift pubkey and privkey."""
+ d = int.from_bytes(privkey, 'big')
+ pub = ellswift_decode(pubkey_theirs)
+ return (d * GE.lift_x(FE.from_bytes(pub))).x.to_bytes()
+
+### Poly1305
+
+class Poly1305:
+ """Class representing a running poly1305 computation."""
+ MODULUS = 2**130 - 5
+
+ def __init__(self, key):
+ self.r = int.from_bytes(key[:16], 'little') & 0xffffffc0ffffffc0ffffffc0fffffff
+ self.s = int.from_bytes(key[16:], 'little')
+ self.acc = 0
+
+ def add(self, msg, length=None, pad=False):
+ """Add a message of any length. Input so far must be a multiple of 16 bytes."""
+ length = len(msg) if length is None else length
+ for i in range((length + 15) // 16):
+ chunk = msg[i * 16:i * 16 + min(16, length - i * 16)]
+ val = int.from_bytes(chunk, 'little') + 256**(16 if pad else len(chunk))
+ self.acc = (self.r * (self.acc + val)) % Poly1305.MODULUS
+ return self
+
+ def tag(self):
+ """Compute the poly1305 tag."""
+ return ((self.acc + self.s) & 0xffffffffffffffffffffffffffffffff).to_bytes(16, 'little')
+
+### ChaCha20
+
+CHACHA20_INDICES = (
+ (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15),
+ (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14)
+)
+
+CHACHA20_CONSTANTS = (0x61707865, 0x3320646e, 0x79622d32, 0x6b206574)
+
+def rotl32(v, bits):
+ """Rotate the 32-bit value v left by bits bits."""
+ return ((v << bits) & 0xffffffff) | (v >> (32 - bits))
+
+def chacha20_doubleround(s):
+ """Apply a ChaCha20 double round to 16-element state array s.
+
+ See https://cr.yp.to/chacha/chacha-20080128.pdf and https://tools.ietf.org/html/rfc8439
+ """
+ for a, b, c, d in CHACHA20_INDICES:
+ s[a] = (s[a] + s[b]) & 0xffffffff
+ s[d] = rotl32(s[d] ^ s[a], 16)
+ s[c] = (s[c] + s[d]) & 0xffffffff
+ s[b] = rotl32(s[b] ^ s[c], 12)
+ s[a] = (s[a] + s[b]) & 0xffffffff
+ s[d] = rotl32(s[d] ^ s[a], 8)
+ s[c] = (s[c] + s[d]) & 0xffffffff
+ s[b] = rotl32(s[b] ^ s[c], 7)
+
+def chacha20_block(key, nonce, cnt):
+ """Compute the 64-byte output of the ChaCha20 block function.
+
+ Takes as input a 32-byte key, 12-byte nonce, and 32-bit integer counter.
+ """
+ # Initial state.
+ init = [0 for _ in range(16)]
+ for i in range(4):
+ init[i] = CHACHA20_CONSTANTS[i]
+ for i in range(8):
+ init[4 + i] = int.from_bytes(key[4 * i:4 * (i+1)], 'little')
+ init[12] = cnt
+ for i in range(3):
+ init[13 + i] = int.from_bytes(nonce[4 * i:4 * (i+1)], 'little')
+ # Perform 20 rounds.
+ state = list(init)
+ for _ in range(10):
+ chacha20_doubleround(state)
+ # Add initial values back into state.
+ for i in range(16):
+ state[i] = (state[i] + init[i]) & 0xffffffff
+ # Produce byte output
+ return b''.join(state[i].to_bytes(4, 'little') for i in range(16))
+
+### ChaCha20Poly1305
+
+def aead_chacha20_poly1305_encrypt(key, nonce, aad, plaintext):
+ """Encrypt a plaintext using ChaCha20Poly1305."""
+ ret = bytearray()
+ msg_len = len(plaintext)
+ for i in range((msg_len + 63) // 64):
+ now = min(64, msg_len - 64 * i)
+ keystream = chacha20_block(key, nonce, i + 1)
+ for j in range(now):
+ ret.append(plaintext[j + 64 * i] ^ keystream[j])
+ poly1305 = Poly1305(chacha20_block(key, nonce, 0)[:32])
+ poly1305.add(aad, pad=True).add(ret, pad=True)
+ poly1305.add(len(aad).to_bytes(8, 'little') + msg_len.to_bytes(8, 'little'))
+ ret += poly1305.tag()
+ return bytes(ret)
+
+def aead_chacha20_poly1305_decrypt(key, nonce, aad, ciphertext):
+ """Decrypt a ChaCha20Poly1305 ciphertext."""
+ if len(ciphertext) < 16:
+ return None
+ msg_len = len(ciphertext) - 16
+ poly1305 = Poly1305(chacha20_block(key, nonce, 0)[:32])
+ poly1305.add(aad, pad=True)
+ poly1305.add(ciphertext, length=msg_len, pad=True)
+ poly1305.add(len(aad).to_bytes(8, 'little') + msg_len.to_bytes(8, 'little'))
+ if ciphertext[-16:] != poly1305.tag():
+ return None
+ ret = bytearray()
+ for i in range((msg_len + 63) // 64):
+ now = min(64, msg_len - 64 * i)
+ keystream = chacha20_block(key, nonce, i + 1)
+ for j in range(now):
+ ret.append(ciphertext[j + 64 * i] ^ keystream[j])
+ return bytes(ret)
+
+### FSChaCha20{,Poly1305}
+
+REKEY_INTERVAL = 224 # packets
+
+class FSChaCha20Poly1305:
+ """Rekeying wrapper AEAD around ChaCha20Poly1305."""
+
+ def __init__(self, initial_key):
+ self.key = initial_key
+ self.packet_counter = 0
+
+ def crypt(self, aad, text, is_decrypt):
+ """Encrypt or decrypt the specified (plain/cipher)text."""
+ nonce = ((self.packet_counter % REKEY_INTERVAL).to_bytes(4, 'little') +
+ (self.packet_counter // REKEY_INTERVAL).to_bytes(8, 'little'))
+ if is_decrypt:
+ ret = aead_chacha20_poly1305_decrypt(self.key, nonce, aad, text)
+ else:
+ ret = aead_chacha20_poly1305_encrypt(self.key, nonce, aad, text)
+ if (self.packet_counter + 1) % REKEY_INTERVAL == 0:
+ rekey_nonce = b"\xFF\xFF\xFF\xFF" + nonce[4:]
+ newkey1 = aead_chacha20_poly1305_encrypt(self.key, rekey_nonce, b"", b"\x00" * 32)[:32]
+ newkey2 = chacha20_block(self.key, rekey_nonce, 1)[:32]
+ assert newkey1 == newkey2
+ self.key = newkey1
+ self.packet_counter += 1
+ return ret
+
+ def encrypt(self, aad, plaintext):
+ """Encrypt the specified plaintext with provided AAD."""
+ return self.crypt(aad, plaintext, False)
+
+ def decrypt(self, aad, ciphertext):
+ """Decrypt the specified ciphertext with provided AAD."""
+ return self.crypt(aad, ciphertext, True)
+
+
+class FSChaCha20:
+ """Rekeying wrapper stream cipher around ChaCha20."""
+
+ def __init__(self, initial_key):
+ self.key = initial_key
+ self.block_counter = 0
+ self.chunk_counter = 0
+ self.keystream = b''
+
+ def get_keystream_bytes(self, nbytes):
+ """Generate nbytes keystream bytes."""
+ while len(self.keystream) < nbytes:
+ nonce = ((0).to_bytes(4, 'little') +
+ (self.chunk_counter // REKEY_INTERVAL).to_bytes(8, 'little'))
+ self.keystream += chacha20_block(self.key, nonce, self.block_counter)
+ self.block_counter += 1
+ ret = self.keystream[:nbytes]
+ self.keystream = self.keystream[nbytes:]
+ return ret
+
+ def crypt(self, chunk):
+ """Encrypt or decypt chunk."""
+ ks = self.get_keystream_bytes(len(chunk))
+ ret = bytes([ks[i] ^ chunk[i] for i in range(len(chunk))])
+ if ((self.chunk_counter + 1) % REKEY_INTERVAL) == 0:
+ self.key = self.get_keystream_bytes(32)
+ self.block_counter = 0
+ self.chunk_counter += 1
+ return ret
+
+ def encrypt(self, chunk):
+ """Encrypt chunk."""
+ return self.crypt(chunk)
+
+ def decrypt(self, chunk):
+ """Decrypt chunk."""
+ return self.crypt(chunk)
+
+
+### Shared secret computation
+
+def v2_ecdh(priv, ellswift_theirs, ellswift_ours, initiating):
+ """Compute BIP324 shared secret."""
+
+ ecdh_point_x32 = ellswift_ecdh_xonly(ellswift_theirs, priv)
+ if initiating:
+ # Initiating, place our public key encoding first.
+ return TaggedHash("bip324_ellswift_xonly_ecdh",
+ ellswift_ours + ellswift_theirs + ecdh_point_x32)
+ # Responding, place their public key encoding first.
+ return TaggedHash("bip324_ellswift_xonly_ecdh",
+ ellswift_theirs + ellswift_ours + ecdh_point_x32)
+
+### Key derivation
+
+NETWORK_MAGIC = b'\xf9\xbe\xb4\xd9'
+
+def initialize_v2_transport(ecdh_secret, initiating):
+ """Return a peer object with various BIP324 derived keys and ciphers."""
+
+ peer = {}
+ salt = b'bitcoin_v2_shared_secret' + NETWORK_MAGIC
+ for name, length in (
+ ('initiator_L', 32), ('initiator_P', 32), ('responder_L', 32), ('responder_P', 32),
+ ('garbage_terminators', 32), ('session_id', 32)):
+ peer[name] = hkdf_sha256(
+ salt=salt, ikm=ecdh_secret, info=name.encode('utf-8'), length=length)
+ peer['initiator_garbage_terminator'] = peer['garbage_terminators'][:16]
+ peer['responder_garbage_terminator'] = peer['garbage_terminators'][16:]
+ del peer['garbage_terminators']
+ if initiating:
+ peer['send_L'] = FSChaCha20(peer['initiator_L'])
+ peer['send_P'] = FSChaCha20Poly1305(peer['initiator_P'])
+ peer['send_garbage_terminator'] = peer['initiator_garbage_terminator']
+ peer['recv_L'] = FSChaCha20(peer['responder_L'])
+ peer['recv_P'] = FSChaCha20Poly1305(peer['responder_P'])
+ peer['recv_garbage_terminator'] = peer['responder_garbage_terminator']
+ else:
+ peer['send_L'] = FSChaCha20(peer['responder_L'])
+ peer['send_P'] = FSChaCha20Poly1305(peer['responder_P'])
+ peer['send_garbage_terminator'] = peer['responder_garbage_terminator']
+ peer['recv_L'] = FSChaCha20(peer['initiator_L'])
+ peer['recv_P'] = FSChaCha20Poly1305(peer['initiator_P'])
+ peer['recv_garbage_terminator'] = peer['initiator_garbage_terminator']
+
+ return peer
+
+### Packet encryption
+
+LENGTH_FIELD_LEN = 3
+HEADER_LEN = 1
+IGNORE_BIT_POS = 7
+
+def v2_enc_packet(peer, contents, aad=b'', ignore=False):
+ """Encrypt a BIP324 packet."""
+
+ assert len(contents) <= 2**24 - 1
+ header = (ignore << IGNORE_BIT_POS).to_bytes(HEADER_LEN, 'little')
+ plaintext = header + contents
+ aead_ciphertext = peer['send_P'].encrypt(aad, plaintext)
+ enc_plaintext_len = peer['send_L'].encrypt(len(contents).to_bytes(LENGTH_FIELD_LEN, 'little'))
+ return enc_plaintext_len + aead_ciphertext