diff options
author | fanquake <fanquake@gmail.com> | 2021-10-21 09:36:07 +0800 |
---|---|---|
committer | fanquake <fanquake@gmail.com> | 2021-10-21 09:36:07 +0800 |
commit | 07f0a61ef711a2f75ded3d73545bfabdf2a64fef (patch) | |
tree | ba26a74a12af78d5761464523e6c43dc21025c69 /src/minisketch/doc/gen_params.sage | |
parent | 4229f71bf8c25eb7a95b16a5f9ae1d6c362a3265 (diff) | |
parent | b6487dc4ef47ec9ea894eceac25f37d0b806f8aa (diff) |
Merge commit 'b6487dc4ef47ec9ea894eceac25f37d0b806f8aa' as 'src/minisketch'
Diffstat (limited to 'src/minisketch/doc/gen_params.sage')
-rwxr-xr-x | src/minisketch/doc/gen_params.sage | 333 |
1 files changed, 333 insertions, 0 deletions
diff --git a/src/minisketch/doc/gen_params.sage b/src/minisketch/doc/gen_params.sage new file mode 100755 index 0000000000..1cf036adb4 --- /dev/null +++ b/src/minisketch/doc/gen_params.sage @@ -0,0 +1,333 @@ +#!/usr/bin/env sage +r""" +Generate finite field parameters for minisketch. + +This script selects the finite fields used by minisketch + for various sizes and generates the required tables for + the implementation. + +The output (after formatting) can be found in src/fields/*.cpp. + +""" +B.<b> = GF(2) +P.<p> = B[] + +def apply_map(m, v): + r = 0 + i = 0 + while v != 0: + if (v & 1): + r ^^= m[i] + i += 1 + v >>= 1 + return r + +def recurse_moduli(acc, maxweight, maxdegree): + for pos in range(maxweight, maxdegree + 1, 1): + poly = acc + p^pos + if maxweight == 1: + if poly.is_irreducible(): + return (pos, poly) + else: + (deg, ret) = recurse_moduli(poly, maxweight - 1, pos - 1) + if ret is not None: + return (pos, ret) + return (None, None) + +def compute_moduli(bits): + # Return all optimal irreducible polynomials for GF(2^bits) + # The result is a list of tuples (weight, degree of second-highest nonzero coefficient, polynomial) + maxdegree = bits - 1 + result = [] + for weight in range(1, bits, 2): + deg, res = None, None + while True: + ret = recurse_moduli(p^bits + 1, weight, maxdegree) + if ret[0] is not None: + (deg, res) = ret + maxdegree = deg - 1 + else: + break + if res is not None: + result.append((weight + 2, deg, res)) + return result + +def bits_to_int(vals): + ret = 0 + base = 1 + for val in vals: + ret += Integer(val) * base + base *= 2 + return ret + +def sqr_table(f, bits, n=1): + ret = [] + for i in range(bits): + ret.append((f^(2^n*i)).integer_representation()) + return ret + +# Compute x**(2**n) +def pow2(x, n): + for i in range(n): + x = x**2 + return x + +def qrt_table(F, f, bits): + # Table for solving x2 + x = a + # This implements the technique from https://www.raco.cat/index.php/PublicacionsMatematiques/article/viewFile/37927/40412, Lemma 1 + for i in range(bits): + if (f**i).trace() != 0: + u = f**i + ret = [] + for i in range(0, bits): + d = f^i + y = sum(pow2(d, j) * sum(pow2(u, k) for k in range(j)) for j in range(1, bits)) + ret.append(y.integer_representation() ^^ (y.integer_representation() & 1)) + return ret + +def conv_tables(F, NF, bits): + # Generate a F(2) linear projection that maps elements from one field + # to an isomorphic field with a different modulus. + f = F.gen() + fp = f.minimal_polynomial() + assert(fp == F.modulus()) + nfp = fp.change_ring(NF) + nf = sorted(nfp.roots(multiplicities=False))[0] + ret = [] + matrepr = [[B(0) for x in range(bits)] for y in range(bits)] + for i in range(bits): + val = (nf**i).integer_representation() + ret.append(val) + for j in range(bits): + matrepr[j][i] = B((val >> j) & 1) + mat = Matrix(matrepr).inverse().transpose() + ret2 = [] + for i in range(bits): + ret2.append(bits_to_int(mat[i])) + + for t in range(100): + f1a = F.random_element() + f1b = F.random_element() + f1r = f1a * f1b + f2a = NF.fetch_int(apply_map(ret, f1a.integer_representation())) + f2b = NF.fetch_int(apply_map(ret, f1b.integer_representation())) + f2r = NF.fetch_int(apply_map(ret, f1r.integer_representation())) + f2s = f2a * f2b + assert(f2r == f2s) + + for t in range(100): + f2a = NF.random_element() + f2b = NF.random_element() + f2r = f2a * f2b + f1a = F.fetch_int(apply_map(ret2, f2a.integer_representation())) + f1b = F.fetch_int(apply_map(ret2, f2b.integer_representation())) + f1r = F.fetch_int(apply_map(ret2, f2r.integer_representation())) + f1s = f1a * f1b + assert(f1r == f1s) + + return (ret, ret2) + +def fmt(i,typ): + if i == 0: + return "0" + else: + return "0x%x" % i + +def lintranstype(typ, bits, maxtbl): + gsize = min(maxtbl, bits) + array_size = (bits + gsize - 1) // gsize + bits_list = [] + total = 0 + for i in range(array_size): + rsize = (bits - total + array_size - i - 1) // (array_size - i) + total += rsize + bits_list.append(rsize) + return "RecLinTrans<%s, %s>" % (typ, ", ".join("%i" % x for x in bits_list)) + +INT=0 +CLMUL=1 +CLMUL_TRI=2 +MD=3 + +def print_modulus_md(mod): + ret = "" + pos = mod.degree() + for c in reversed(list(mod)): + if c: + if ret: + ret += " + " + if pos == 0: + ret += "1" + elif pos == 1: + ret += "x" + else: + ret += "x<sup>%i</sup>" % pos + pos -= 1 + return ret + +def pick_modulus(bits, style): + # Choose the lexicographicly-first lowest-weight modulus + # optionally subject to implementation specific constraints. + moduli = compute_moduli(bits) + if style == INT or style == MD: + multi_sqr = False + need_trans = False + elif style == CLMUL: + # Fast CLMUL reduction requires that bits + the highest + # set bit are less than 66. + moduli = list(filter((lambda x: bits+x[1] <= 66), moduli)) + moduli + multi_sqr = True + need_trans = True + if not moduli or moduli[0][2].change_ring(ZZ)(2) == 3 + 2**bits: + # For modulus 3, CLMUL_TRI is obviously better. + return None + elif style == CLMUL_TRI: + moduli = list(filter(lambda x: bits+x[1] <= 66, moduli)) + moduli + moduli = list(filter(lambda x: x[0] == 3, moduli)) + multi_sqr = True + need_trans = True + else: + assert(False) + if not moduli: + return None + return moduli[0][2] + +def print_result(bits, style): + if style == INT: + multi_sqr = False + need_trans = False + table_id = "%i" % bits + elif style == MD: + pass + elif style == CLMUL: + multi_sqr = True + need_trans = True + table_id = "%i" % bits + elif style == CLMUL_TRI: + multi_sqr = True + need_trans = True + table_id = "TRI%i" % bits + else: + assert(False) + + nmodulus = pick_modulus(bits, INT) + modulus = pick_modulus(bits, style) + if modulus is None: + return + + if style == MD: + print("* *%s*" % print_modulus_md(modulus)) + return + + if bits > 32: + typ = "uint64_t" + elif bits > 16: + typ = "uint32_t" + elif bits > 8: + typ = "uint16_t" + else: + typ = "uint8_t" + + ttyp = lintranstype(typ, bits, 4) + rtyp = lintranstype(typ, bits, 6) + + F.<f> = GF(2**bits, modulus=modulus) + + include_table = True + if style != INT and style != CLMUL: + cmodulus = pick_modulus(bits, CLMUL) + if cmodulus == modulus: + include_table = False + table_id = "%i" % bits + + if include_table: + print("typedef %s StatTable%s;" % (rtyp, table_id)) + rtyp = "StatTable%s" % table_id + if (style == INT): + print("typedef %s DynTable%s;" % (ttyp, table_id)) + ttyp = "DynTable%s" % table_id + + if need_trans: + if modulus != nmodulus: + # If the bitstream modulus is not the best modulus for + # this implementation a conversion table will be needed. + ctyp = rtyp + NF.<nf> = GF(2**bits, modulus=nmodulus) + ctables = conv_tables(NF, F, bits) + loadtbl = "&LOAD_TABLE_%s" % table_id + savetbl = "&SAVE_TABLE_%s" % table_id + if include_table: + print("constexpr %s LOAD_TABLE_%s({%s});" % (ctyp, table_id, ", ".join([fmt(x,typ) for x in ctables[0]]))) + print("constexpr %s SAVE_TABLE_%s({%s});" % (ctyp, table_id, ", ".join([fmt(x,typ) for x in ctables[1]]))) + else: + ctyp = "IdTrans" + loadtbl = "&ID_TRANS" + savetbl = "&ID_TRANS" + else: + assert(modulus == nmodulus) + + if include_table: + print("constexpr %s SQR_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 1)]))) + if multi_sqr: + # Repeated squaring is a linearised polynomial so in F(2^n) it is + # F(2) linear and can be computed by a simple bit-matrix. + # Repeated squaring is especially useful in powering ladders such as + # for inversion. + # When certain repeated squaring tables are not in use, use the QRT + # table instead to make the C++ compiler happy (it always has the + # same type). + sqr2 = "&QRT_TABLE_%s" % table_id + sqr4 = "&QRT_TABLE_%s" % table_id + sqr8 = "&QRT_TABLE_%s" % table_id + sqr16 = "&QRT_TABLE_%s" % table_id + if ((bits - 1) >= 4): + if include_table: + print("constexpr %s SQR2_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 2)]))) + sqr2 = "&SQR2_TABLE_%s" % table_id + if ((bits - 1) >= 8): + if include_table: + print("constexpr %s SQR4_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 4)]))) + sqr4 = "&SQR4_TABLE_%s" % table_id + if ((bits - 1) >= 16): + if include_table: + print("constexpr %s SQR8_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 8)]))) + sqr8 = "&SQR8_TABLE_%s" % table_id + if ((bits - 1) >= 32): + if include_table: + print("constexpr %s SQR16_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in sqr_table(f, bits, 16)]))) + sqr16 = "&SQR16_TABLE_%s" % table_id + if include_table: + print("constexpr %s QRT_TABLE_%s({%s});" % (rtyp, table_id, ", ".join([fmt(x,typ) for x in qrt_table(F, f, bits)]))) + + modulus_weight = modulus.hamming_weight() + modulus_degree = (modulus - p**bits).degree() + modulus_int = (modulus - p**bits).change_ring(ZZ)(2) + + lfsr = "" + + if style == INT: + print("typedef Field<%s, %i, %i, %s, %s, &SQR_TABLE_%s, &QRT_TABLE_%s%s> Field%i;" % (typ, bits, modulus_int, rtyp, ttyp, table_id, table_id, lfsr, bits)) + elif style == CLMUL: + print("typedef Field<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s%s> Field%i;" % (typ, bits, modulus_int, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, lfsr, bits)) + elif style == CLMUL_TRI: + print("typedef FieldTri<%s, %i, %i, %s, &SQR_TABLE_%s, %s, %s, %s, %s, &QRT_TABLE_%s, %s, %s, %s> FieldTri%i;" % (typ, bits, modulus_degree, rtyp, table_id, sqr2, sqr4, sqr8, sqr16, table_id, ctyp, loadtbl, savetbl, bits)) + else: + assert(False) + +for bits in range(2, 65): + print("#ifdef ENABLE_FIELD_INT_%i" % bits) + print("// %i bit field" % bits) + print_result(bits, INT) + print("#endif") + print("") + +for bits in range(2, 65): + print("#ifdef ENABLE_FIELD_INT_%i" % bits) + print("// %i bit field" % bits) + print_result(bits, CLMUL) + print_result(bits, CLMUL_TRI) + print("#endif") + print("") + +for bits in range(2, 65): + print_result(bits, MD) |