aboutsummaryrefslogtreecommitdiff
path: root/src/minisketch/doc/gen_params.sage
diff options
context:
space:
mode:
authorfanquake <fanquake@gmail.com>2021-11-12 09:25:40 +0800
committerfanquake <fanquake@gmail.com>2021-11-12 10:00:49 +0800
commitc1fb30633b6dcbf32db7d53c9f538019af80d6c5 (patch)
tree0fb479ab0aa27218dc34db18cd3a31ac72b1359a /src/minisketch/doc/gen_params.sage
parentbc03823e26d1bf0cca679ce9330f9c6a58b33faf (diff)
parent29173d6c6ca0cc3be9fa6bf2409a509ffea1a02a (diff)
downloadbitcoin-c1fb30633b6dcbf32db7d53c9f538019af80d6c5.tar.xz
Merge bitcoin/bitcoin#23114: Add minisketch subtree and integrate into build/test
29173d6c6ca0cc3be9fa6bf2409a509ffea1a02a ubsan: add minisketch exceptions (Cory Fields) 54b5e1aeab73953c1f12ec2c041572038f6f59da Add thin Minisketch wrapper to pick best implementation (Pieter Wuille) ee9dc71c1bc16205494f2a0aebe575a3c062ff52 Add basic minisketch tests (Pieter Wuille) 0659f12b131fc5915fe7a493306af197f4fb838b Add minisketch dependency (Gleb Naumenko) 0eb7928ab8d9dcb840e4965bfa81deb752b00dfa Add MSVC build configuration for libminisketch (Pieter Wuille) 8bc166d5b179205fc56855e2b462aa273a6f8661 build: add minisketch build file and include it (Cory Fields) b2904ceb85b4d440b1f4bbd716fcb601411cc2c9 build: add configure checks for minisketch (Cory Fields) b6487dc4ef47ec9ea894eceac25f37d0b806f8aa Squashed 'src/minisketch/' content from commit 89629eb2c7 (fanquake) Pull request description: This takes over #21859, which has [recently switched](https://github.com/bitcoin/bitcoin/pull/21859#issuecomment-921899200) to my integration branch. A few more build issues came up (and have been fixed) since, and after discussing with sipa it was decided I would open a PR to shepherd any final changes through. > This adds a `src/minisketch` subtree, taken from the master branch of https://github.com/sipa/minisketch, to prepare for Erlay implementation (see #21515). It gets configured for just supporting 32-bit fields (the only ones we're interested in in the context of Erlay), and some code on top is added: > * A very basic unit test (just to make sure compilation & running works; actual correctness checking is done through minisketch's own tests). > * A wrapper in `minisketchwrapper.{cpp,h}` that runs a benchmark to determine which field implementation to use. Only changes since my last update to the branch in the previous PR have been rebasing on master and fixing an issue with a header in an introduced file. ACKs for top commit: naumenkogs: ACK 29173d6c6ca0cc3be9fa6bf2409a509ffea1a02a Tree-SHA512: 1217d3228db1dd0de12c2919314e1c3626c18a416cf6291fec99d37e34fb6eec8e28d9e9fb935f8590273b8836cbadac313a15f05b4fd9f9d3024c8ce2c80d02
Diffstat (limited to 'src/minisketch/doc/gen_params.sage')
-rwxr-xr-xsrc/minisketch/doc/gen_params.sage333
1 files changed, 333 insertions, 0 deletions
diff --git a/src/minisketch/doc/gen_params.sage b/src/minisketch/doc/gen_params.sage
new file mode 100755
index 0000000000..1cf036adb4
--- /dev/null
+++ b/src/minisketch/doc/gen_params.sage
@@ -0,0 +1,333 @@
+#!/usr/bin/env sage
+r"""
+Generate finite field parameters for minisketch.
+
+This script selects the finite fields used by minisketch
+ for various sizes and generates the required tables for
+ the implementation.
+
+The output (after formatting) can be found in src/fields/*.cpp.
+
+"""
+B.<b> = GF(2)
+P.<p> = B[]
+
+def apply_map(m, v):
+ r = 0
+ i = 0
+ while v != 0:
+ if (v & 1):
+ r ^^= m[i]
+ i += 1
+ v >>= 1
+ return r
+
+def recurse_moduli(acc, maxweight, maxdegree):
+ for pos in range(maxweight, maxdegree + 1, 1):
+ poly = acc + p^pos
+ if maxweight == 1:
+ if poly.is_irreducible():
+ return (pos, poly)
+ else:
+ (deg, ret) = recurse_moduli(poly, maxweight - 1, pos - 1)
+ if ret is not None:
+ return (pos, ret)
+ return (None, None)
+
+def compute_moduli(bits):
+ # Return all optimal irreducible polynomials for GF(2^bits)
+ # The result is a list of tuples (weight, degree of second-highest nonzero coefficient, polynomial)
+ maxdegree = bits - 1
+ result = []
+ for weight in range(1, bits, 2):
+ deg, res = None, None
+ while True:
+ ret = recurse_moduli(p^bits + 1, weight, maxdegree)
+ if ret[0] is not None:
+ (deg, res) = ret
+ maxdegree = deg - 1
+ else:
+ break
+ if res is not None:
+ result.append((weight + 2, deg, res))
+ return result
+
+def bits_to_int(vals):
+ ret = 0
+ base = 1
+ for val in vals:
+ ret += Integer(val) * base
+ base *= 2
+ return ret
+
+def sqr_table(f, bits, n=1):
+ ret = []
+ for i in range(bits):
+ ret.append((f^(2^n*i)).integer_representation())
+ return ret
+
+# Compute x**(2**n)
+def pow2(x, n):
+ for i in range(n):
+ x = x**2
+ return x
+
+def qrt_table(F, f, bits):
+ # Table for solving x2 + x = a
+ # This implements the technique from https://www.raco.cat/index.php/PublicacionsMatematiques/article/viewFile/37927/40412, Lemma 1
+ for i in range(bits):
+ if (f**i).trace() != 0:
+ u = f**i
+ ret = []
+ for i in range(0, bits):
+ d = f^i
+ y = sum(pow2(d, j) * sum(pow2(u, k) for k in range(j)) for j in range(1, bits))
+ ret.append(y.integer_representation() ^^ (y.integer_representation() & 1))
+ return ret
+
+def conv_tables(F, NF, bits):
+ # Generate a F(2) linear projection that maps elements from one field
+ # to an isomorphic field with a different modulus.
+ f = F.gen()
+ fp = f.minimal_polynomial()
+ assert(fp == F.modulus())
+ nfp = fp.change_ring(NF)
+ nf = sorted(nfp.roots(multiplicities=False))[0]
+ ret = []
+ matrepr = [[B(0) for x in range(bits)] for y in range(bits)]
+ for i in range(bits):
+ val = (nf**i).integer_representation()
+ ret.append(val)
+ for j in range(bits):
+ matrepr[j][i] = B((val >> j) & 1)
+ mat = Matrix(matrepr).inverse().transpose()
+ ret2 = []
+ for i in range(bits):
+ ret2.append(bits_to_int(mat[i]))
+
+ for t in range(100):
+ f1a = F.random_element()
+ f1b = F.random_element()
+ f1r = f1a * f1b
+ f2a = NF.fetch_int(apply_map(ret, f1a.integer_representation()))
+ f2b = NF.fetch_int(apply_map(ret, f1b.integer_representation()))
+ f2r = NF.fetch_int(apply_map(ret, f1r.integer_representation()))
+ f2s = f2a * f2b
+ assert(f2r == f2s)
+
+ for t in range(100):
+ f2a = NF.random_element()
+ f2b = NF.random_element()
+ f2r = f2a * f2b
+ f1a = F.fetch_int(apply_map(ret2, f2a.integer_representation()))
+ f1b = F.fetch_int(apply_map(ret2, f2b.integer_representation()))
+ f1r = F.fetch_int(apply_map(ret2, f2r.integer_representation()))
+ f1s = f1a * f1b
+ assert(f1r == f1s)
+
+ return (ret, ret2)
+
+def fmt(i,typ):
+ if i == 0:
+ return "0"
+ else:
+ return "0x%x" % i
+
+def lintranstype(typ, bits, maxtbl):
+ gsize = min(maxtbl, bits)
+ array_size = (bits + gsize - 1) // gsize
+ bits_list = []
+ total = 0
+ for i in range(array_size):
+ rsize = (bits - total + array_size - i - 1) // (array_size - i)
+ total += rsize
+ bits_list.append(rsize)
+ return "RecLinTrans<%s, %s>" % (typ, ", ".join("%i" % x for x in bits_list))
+
+INT=0
+CLMUL=1
+CLMUL_TRI=2
+MD=3
+
+def print_modulus_md(mod):
+ ret = ""
+ pos = mod.degree()
+ for c in reversed(list(mod)):
+ if c:
+ if ret:
+ ret += " + "
+ if pos == 0:
+ ret += "1"
+ elif pos == 1:
+ ret += "x"
+ else:
+ ret += "x<sup>%i</sup>" % pos
+ pos -= 1
+ return ret
+
+def pick_modulus(bits, style):
+ # Choose the lexicographicly-first lowest-weight modulus
+ # optionally subject to implementation specific constraints.
+ moduli = compute_moduli(bits)
+ if style == INT or style == MD:
+ multi_sqr = False
+ need_trans = False
+ elif style == CLMUL:
+ # Fast CLMUL reduction requires that bits + the highest
+ # set bit are less than 66.
+ moduli = list(filter((lambda x: bits+x[1] <= 66), moduli)) + moduli
+ multi_sqr = True
+ need_trans = True
+ if not moduli or moduli[0][2].change_ring(ZZ)(2) == 3 + 2**bits:
+ # For modulus 3, CLMUL_TRI is obviously better.
+ return None
+ elif style == CLMUL_TRI:
+ moduli = list(filter(lambda x: bits+x[1] <= 66, moduli)) + moduli
+ moduli = list(filter(lambda x: x[0] == 3, moduli))
+ multi_sqr = True
+ need_trans = True
+ else:
+ assert(False)
+ if not moduli:
+ return None
+ return moduli[0][2]
+
+def print_result(bits, style):
+ if style == INT:
+ multi_sqr = False
+ need_trans = False
+ table_id = "%i" % bits
+ elif style == MD:
+ pass
+ elif style == CLMUL:
+ multi_sqr = True
+ need_trans = True
+ table_id = "%i" % bits
+ elif style == CLMUL_TRI:
+ multi_sqr = True
+ need_trans = True
+ table_id = "TRI%i" % bits
+ else:
+ assert(False)
+
+ nmodulus = pick_modulus(bits, INT)
+ modulus = pick_modulus(bits, style)
+ if modulus is None:
+ return
+
+ if style == MD:
+ print("* *%s*" % print_modulus_md(modulus))
+ return
+
+ if bits > 32:
+ typ = "uint64_t"
+ elif bits > 16:
+ typ = "uint32_t"
+ elif bits > 8:
+ typ = "uint16_t"
+ else:
+ typ = "uint8_t"
+
+ ttyp = lintranstype(typ, bits, 4)
+ rtyp = lintranstype(typ, bits, 6)
+
+ F.<f> = GF(2**bits, modulus=modulus)
+
+ include_table = True
+ if style != INT and style != CLMUL:
+ cmodulus = pick_modulus(bits, CLMUL)
+ if cmodulus == modulus:
+ include_table = False
+ table_id = "%i" % bits
+
+ if include_table:
+ print("typedef %s StatTable%s;" % (rtyp, table_id))
+ rtyp = "StatTable%s" % table_id
+ if (style == INT):
+ print("typedef %s DynTable%s;" % (ttyp, table_id))
+ ttyp = "DynTable%s" % table_id
+
+ if need_trans:
+ if modulus != nmodulus:
+ # If the bitstream modulus is not the best modulus for
+ # this implementation a conversion table will be needed.
+ ctyp = rtyp
+ NF.<nf> = GF(2**bits, modulus=nmodulus)
+ ctables = conv_tables(NF, F, bits)
+ loadtbl = "&LOAD_TABLE_%s" % table_id
+ savetbl = "&SAVE_TABLE_%s" % table_id
+ if include_table:
+ print("constexpr %s LOAD_TABLE_%s({%s});" % (ctyp, table_id, ", ".join([fmt(x,typ) for x in ctables[0]])))
+ print("constexpr %s SAVE_TABLE_%s({%s});" % (ctyp, table_id, ", ".join([fmt(x,typ) for x in ctables[1]])))
+ else:
+ ctyp = "IdTrans"
+ loadtbl = "&ID_TRANS"
+ savetbl = "&ID_TRANS"
+ else:
+ assert(modulus == nmodulus)
+
+ if include_table:
+ print("constexpr %s SQR_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 1)])))
+ if multi_sqr:
+ # Repeated squaring is a linearised polynomial so in F(2^n) it is
+ # F(2) linear and can be computed by a simple bit-matrix.
+ # Repeated squaring is especially useful in powering ladders such as
+ # for inversion.
+ # When certain repeated squaring tables are not in use, use the QRT
+ # table instead to make the C++ compiler happy (it always has the
+ # same type).
+ sqr2 = "&QRT_TABLE_%s" % table_id
+ sqr4 = "&QRT_TABLE_%s" % table_id
+ sqr8 = "&QRT_TABLE_%s" % table_id
+ sqr16 = "&QRT_TABLE_%s" % table_id
+ if ((bits - 1) >= 4):
+ if include_table:
+ print("constexpr %s SQR2_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 2)])))
+ sqr2 = "&SQR2_TABLE_%s" % table_id
+ if ((bits - 1) >= 8):
+ if include_table:
+ print("constexpr %s SQR4_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 4)])))
+ sqr4 = "&SQR4_TABLE_%s" % table_id
+ if ((bits - 1) >= 16):
+ if include_table:
+ print("constexpr %s SQR8_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 8)])))
+ sqr8 = "&SQR8_TABLE_%s" % table_id
+ if ((bits - 1) >= 32):
+ if include_table:
+ print("constexpr %s SQR16_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 16)])))
+ sqr16 = "&SQR16_TABLE_%s" % table_id
+ if include_table:
+ print("constexpr %s QRT_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in qrt_table(F, f, bits)])))
+
+ modulus_weight = modulus.hamming_weight()
+ modulus_degree = (modulus - p**bits).degree()
+ modulus_int = (modulus - p**bits).change_ring(ZZ)(2)
+
+ lfsr = ""
+
+ if style == INT:
+ print("typedef Field<%s, %i, %i, %s, %s, &SQR_TABLE_%s, &QRT_TABLE_%s%s> Field%i;" % (typ, bits, modulus_int, rtyp, ttyp, table_id, table_id, lfsr, bits))
+ elif style == CLMUL:
+ print("typedef Field<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s%s> Field%i;" % (typ, bits, modulus_int, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, lfsr, bits))
+ elif style == CLMUL_TRI:
+ print("typedef FieldTri<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s> FieldTri%i;" % (typ, bits, modulus_degree, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, bits))
+ else:
+ assert(False)
+
+for bits in range(2, 65):
+ print("#ifdef ENABLE_FIELD_INT_%i" % bits)
+ print("// %i bit field" % bits)
+ print_result(bits, INT)
+ print("#endif")
+ print("")
+
+for bits in range(2, 65):
+ print("#ifdef ENABLE_FIELD_INT_%i" % bits)
+ print("// %i bit field" % bits)
+ print_result(bits, CLMUL)
+ print_result(bits, CLMUL_TRI)
+ print("#endif")
+ print("")
+
+for bits in range(2, 65):
+ print_result(bits, MD)