aboutsummaryrefslogtreecommitdiff
path: root/src/secp256k1/sage/group_prover.sage
blob: b200bfeae3d1c6c32fd3fc7de6160ebc6bbb12e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# This code supports verifying group implementations which have branches
# or conditional statements (like cmovs), by allowing each execution path
# to independently set assumptions on input or intermediary variables.
#
# The general approach is:
# * A constraint is a tuple of two sets of symbolic expressions:
#   the first of which are required to evaluate to zero, the second of which
#   are required to evaluate to nonzero.
#   - A constraint is said to be conflicting if any of its nonzero expressions
#     is in the ideal with basis the zero expressions (in other words: when the
#     zero expressions imply that one of the nonzero expressions is zero).
# * There is a list of laws that describe the intended behaviour, including
#   laws for addition and doubling. Each law is called with the symbolic point
#   coordinates as arguments, and returns:
#   - A constraint describing the assumptions under which it is applicable,
#     called "assumeLaw"
#   - A constraint describing the requirements of the law, called "require"
# * Implementations are transliterated into functions that operate as well on
#   algebraic input points, and are called once per combination of branches
#   executed. Each execution returns:
#   - A constraint describing the assumptions this implementation requires
#     (such as Z1=1), called "assumeFormula"
#   - A constraint describing the assumptions this specific branch requires,
#     but which is by construction guaranteed to cover the entire space by
#     merging the results from all branches, called "assumeBranch"
#   - The result of the computation
# * All combinations of laws with implementation branches are tried, and:
#   - If the combination of assumeLaw, assumeFormula, and assumeBranch results
#     in a conflict, it means this law does not apply to this branch, and it is
#     skipped.
#   - For others, we try to prove the require constraints hold, assuming the
#     information in assumeLaw + assumeFormula + assumeBranch, and if this does
#     not succeed, we fail.
#     + To prove an expression is zero, we check whether it belongs to the
#       ideal with the assumed zero expressions as basis. This test is exact.
#     + To prove an expression is nonzero, we check whether each of its
#       factors is contained in the set of nonzero assumptions' factors.
#       This test is not exact, so various combinations of original and
#       reduced expressions' factors are tried.
#   - If we succeed, we print out the assumptions from assumeFormula that
#     weren't implied by assumeLaw already. Those from assumeBranch are skipped,
#     as we assume that all constraints in it are complementary with each other.
#
# Based on the sage verification scripts used in the Explicit-Formulas Database
# by Tanja Lange and others, see https://hyperelliptic.org/EFD

class fastfrac:
  """A numerator/denominator pair over a ring R.

  The arithmetic operators below combine numerators and denominators
  directly; the only place the parts are rewritten is reduce(), which
  reduces each of them modulo an ideal.
  """

  def __init__(self,R,top,bot=1):
    """Build the fraction top/bot over the ring R.

    top may be an integer or an element of R, another fastfrac (whose
    denominator is then multiplied by bot), or any other object accepted by
    numerator() and denominator().
    """
    self.R = R
    top_parent = parent(top)
    if top_parent == ZZ or top_parent == R:
      # Integer or ring element: coerce both parts into R as given.
      self.top, self.bot = R(top), R(bot)
    elif top.__class__ == fastfrac:
      # Already a fastfrac: reuse its parts and fold in the extra denominator.
      self.top, self.bot = top.top, top.bot * bot
    else:
      # Anything else is split through numerator()/denominator().
      self.top, self.bot = R(numerator(top)), R(denominator(top)) * bot

  def iszero(self,I):
    """Return True when the numerator is in I and the denominator is not."""
    if self.top not in I:
      return False
    return self.bot not in I

  def reduce(self,assumeZero):
    """Return this fraction with both parts reduced modulo an ideal.

    The ideal is generated by the numerators of the entries of assumeZero.
    """
    zero = self.R.ideal([numerator(z) for z in assumeZero])
    reduced_top = fastfrac(self.R, zero.reduce(self.top))
    reduced_bot = fastfrac(self.R, zero.reduce(self.bot))
    return reduced_top / reduced_bot

  def __add__(self,other):
    """Return self + other, for an integer or fastfrac other."""
    if parent(other) == ZZ:
      top = self.top + self.bot * other
      bot = self.bot
    elif other.__class__ == fastfrac:
      top = self.top * other.bot + self.bot * other.top
      bot = self.bot * other.bot
    else:
      return NotImplemented
    return fastfrac(self.R,top,bot)

  def __sub__(self,other):
    """Return self - other, for an integer or fastfrac other."""
    if parent(other) == ZZ:
      top = self.top - self.bot * other
      bot = self.bot
    elif other.__class__ == fastfrac:
      top = self.top * other.bot - self.bot * other.top
      bot = self.bot * other.bot
    else:
      return NotImplemented
    return fastfrac(self.R,top,bot)

  def __neg__(self):
    """Return -self (only the numerator changes sign)."""
    return fastfrac(self.R,-self.top,self.bot)

  def __mul__(self,other):
    """Return self * other, for an integer or fastfrac other."""
    if parent(other) == ZZ:
      top = self.top * other
      bot = self.bot
    elif other.__class__ == fastfrac:
      top = self.top * other.top
      bot = self.bot * other.bot
    else:
      return NotImplemented
    return fastfrac(self.R,top,bot)

  def __rmul__(self,other):
    """Return other * self by delegating to __mul__."""
    return self.__mul__(other)

  def __truediv__(self,other):
    """Return self / other, for an integer or fastfrac other."""
    if parent(other) == ZZ:
      top = self.top
      bot = self.bot * other
    elif other.__class__ == fastfrac:
      top = self.top * other.bot
      bot = self.bot * other.top
    else:
      return NotImplemented
    return fastfrac(self.R,top,bot)

  # Python 2 based Sage versions look up __div__ for the / operator.
  def __div__(self,other):
    """Return self / other by delegating to __truediv__."""
    return self.__truediv__(other)

  def __pow__(self,other):
    """Return self raised to the integer power other."""
    if parent(other) == ZZ:
      if other < 0:
        # A negative exponent swaps numerator and denominator.
        exponent = -other
        return fastfrac(self.R,self.bot ^ exponent,self.top ^ exponent)
      return fastfrac(self.R,self.top ^ other,self.bot ^ other)
    return NotImplemented

  def __str__(self):
    """Return a readable "fastfrac((top) / (bot))" rendering."""
    return "fastfrac((%s) / (%s))" % (str(self.top), str(self.bot))

  def __repr__(self):
    """Return the same text as __str__."""
    return str(self)

  def numerator(self):
    """Return the numerator part of the fraction."""
    return self.top

class constraints:
  """A pair of expression maps: things that are zero, things that are nonzero.

  A constraints object can describe either what is known or what is required.

  zero and nonzero are both dicts from expression to a description string.
  Every key of zero must evaluate to zero, and every key of nonzero must
  evaluate to something other than zero.

  Since (a != 0) and (b != 0) is equivalent to (a*b != 0), the nonzero keys
  could in principle be collapsed into one product. Working with that product
  is usually far more expensive, so the keys are stored individually; callers
  can then run cheap per-element checks, or multiply them together when a
  stronger check is needed.

  The zero keys cannot be multiplied together in the same way: a product is
  zero as soon as one factor is, whereas all of them are required to be zero.
  They are typically turned into an ideal instead.
  """

  def __init__(self, **kwargs):
    """Accept optional zero= and nonzero= mappings; each one is copied."""
    self.zero = dict(kwargs['zero']) if 'zero' in kwargs else dict()
    self.nonzero = dict(kwargs['nonzero']) if 'nonzero' in kwargs else dict()

  def negate(self):
    """Return a new constraints object with zero and nonzero exchanged."""
    return constraints(zero=self.nonzero, nonzero=self.zero)

  def __add__(self, other):
    """Return the union of both constraints; other wins on duplicate keys."""
    # The constructor copies the dicts, so updating the result leaves self
    # untouched.
    combined = constraints(zero=self.zero, nonzero=self.nonzero)
    combined.zero.update(other.zero)
    combined.nonzero.update(other.nonzero)
    return combined

  def __str__(self):
    """Render both maps as "constraints(zero=...,nonzero=...)"."""
    return "constraints(zero=%s,nonzero=%s)" % (self.zero, self.nonzero)

  def __repr__(self):
    """Return the same text as __str__."""
    return str(self)


def conflicts(R, con):
  """Return True when the nonzero assumptions contradict the zero assumptions.

  R is the ring the expressions live in; con is a constraints object whose
  nonzero keys provide an iszero(ideal) method.
  """
  zero = R.ideal([numerator(z) for z in con.zero])
  # If 1 is in the ideal, the zero assumptions are contradictory on their own.
  if 1 in zero:
    return True
  # Cheap pass: does any single nonzero expression vanish under the zero
  # assumptions?
  if any(nz.iszero(zero) for nz in con.nonzero):
    return True
  # Individual entries may each be compatible with the zero set while their
  # combination is not. For instance, "x or y is zero" is expressed by having
  # x*y in the zero set; requiring only x, or only y, to be nonzero is fine,
  # but requiring both is a conflict. Testing the product of all nonzero
  # expressions catches that case.
  product = fastfrac(R, 1)
  for nz in con.nonzero:
    product = product * nz
  return product.iszero(zero)


def get_nonzero_set(R, assume):
  """Collect the factors of all expressions assumed to be nonzero.

  For every nonzero assumption, both the factors of its numerator and the
  factors of that numerator reduced modulo the zero assumptions are included.
  Returns a set of factors (multiplicities are discarded).
  """
  zero = R.ideal([numerator(z) for z in assume.zero])
  factors = set()
  for expr in assume.nonzero:
    nz = numerator(expr)
    factors.update(f for (f, _) in nz.factor())
    factors.update(f for (f, _) in zero.reduce(nz).factor())
  return factors


def prove_nonzero(R, exprs, assume):
  """Try to prove that every expression in exprs is nonzero under assume.

  exprs maps expressions to description strings. Returns (True, None) on
  success, or (False, descriptions) listing the entries that could not be
  proven nonzero.
  """
  zero = R.ideal(list(map(numerator, assume.zero)))
  nonzero = get_nonzero_set(R, assume)

  # Exact refutation: an expression inside the ideal is definitely zero.
  for expr in exprs:
    if numerator(expr) in zero:
      return (False, [exprs[expr]])

  # Product of all numerators, built up one expression at a time.
  allexprs = 1
  for expr in exprs:
    allexprs = numerator(allexprs) * numerator(expr)

  # Attempt 1: every factor of the product is a known nonzero factor.
  if all(f in nonzero for (f, n) in allexprs.factor()):
    return (True, None)

  # Attempt 2: the same, after reducing the product modulo the zero ideal.
  if all(f in nonzero for (f, n) in zero.reduce(numerator(allexprs)).factor()):
    return (True, None)

  # Attempt 3: every factor of every individual expression is known nonzero.
  if all(f in nonzero for expr in exprs for (f, n) in numerator(expr).factor()):
    return (True, None)

  # Attempt 4: factor each expression after reduction, and remember the
  # descriptions of those that still have an unknown factor.
  expl = set()
  for expr in exprs:
    for (f, n) in zero.reduce(numerator(expr)).factor():
      if f not in nonzero:
        expl.add(exprs[expr])
  if expl:
    return (False, list(expl))
  return (True, None)


def prove_zero(R, exprs, assume):
  """Check whether all of the passed expressions are provably zero, given assumptions

  exprs maps fastfrac expressions to description strings, and assume is the
  constraints object holding the assumptions.

  Returns (True, None) when every expression is provably zero, and otherwise
  (False, messages), where messages is a list of strings describing what
  failed.
  """
  # The denominators are first required to be provably nonzero.
  denominators = dict((fastfrac(R, x.bot, 1), exprs[x]) for x in exprs)
  r, e = prove_nonzero(R, denominators, assume)
  if not r:
    # Return a real list rather than a lazy map object: callers format this
    # value with str(), which on Python 3 would otherwise print
    # "<map object at ...>" instead of the messages.
    return (False, ["Possibly zero denominator: %s" % x for x in e])
  zero = R.ideal(list(map(numerator, assume.zero)))
  # Collect the description of every expression that is not in the ideal.
  expl = [exprs[expr] for expr in exprs if not expr.iszero(zero)]
  if not expl:
    return (True, None)
  return (False, expl)


def describe_extra(R, assume, assumeExtra):
  """Describe what assumptions are added, given existing assumptions

  assume holds the assumptions already in place and assumeExtra the
  additional ones. Returns a comma-separated string listing the extra zero
  assumptions (as "factors = 0 [description]") and extra nonzero assumptions
  (as "factor != 0") that are not already covered; the string is empty when
  nothing is added.
  """
  zerox = assume.zero.copy()
  zerox.update(assumeExtra.zero)
  zero = R.ideal(list(map(numerator, assume.zero)))
  zeroextra = R.ideal(list(map(numerator, zerox)))
  nonzero = get_nonzero_set(R, assume)
  ret = set()
  # Iterate over the extra zero expressions
  for base in assumeExtra.zero:
    # Test the numerator for ideal membership, like every other membership
    # test in this file does, so that extra assumptions already implied by
    # the existing zero assumptions are skipped.
    if numerator(base) not in zero:
      # Factors already known to be nonzero cannot be the reason the
      # expression vanishes, so only the remaining factors are reported.
      add = ["%s" % f for (f, n) in numerator(base).factor() if f not in nonzero]
      if add:
        ret.add((" * ".join(add)) + " = 0 [%s]" % assumeExtra.zero[base])
  # Iterate over the extra nonzero expressions
  for nz in assumeExtra.nonzero:
    nzr = zeroextra.reduce(numerator(nz))
    if nzr not in zeroextra:
      for (f,n) in nzr.factor():
        reduced = zeroextra.reduce(f)
        if reduced not in nonzero:
          ret.add("%s != 0" % reduced)
  return ", ".join(ret)


def check_symbolic(R, assumeLaw, assumeAssert, assumeBranch, require):
  """Try to prove the require constraints from the combined assumptions.

  Returns None when the combined assumptions conflict, a string starting
  with "FAIL" when a requirement cannot be proven, and otherwise a string
  starting with "OK". The strings mention the assumptions from assumeAssert
  that are not already covered by assumeLaw and assumeBranch.
  """
  combined = assumeLaw + assumeAssert + assumeBranch

  # Conflicting assumptions: this formula does not apply.
  if conflicts(R, combined):
    return None

  extra = describe_extra(R, assumeLaw + assumeBranch, assumeAssert)

  zero_ok, zero_msg = prove_zero(R, require.zero, combined)
  if not zero_ok:
    return "FAIL, %s fails (assuming %s)" % (str(zero_msg), extra)

  nonzero_ok, nonzero_msg = prove_nonzero(R, require.nonzero, combined)
  if not nonzero_ok:
    return "FAIL, %s fails (assuming %s)" % (str(nonzero_msg), extra)

  if extra == "":
    return "OK"
  return "OK (assuming %s)" % extra


def concrete_verify(c):
  """Check a constraints object whose keys are concrete values.

  Returns (True, None) when every key of c.zero equals 0 and no key of
  c.nonzero equals 0; otherwise returns (False, description) for the first
  violated entry, with the zero entries checked first.
  """
  for value, description in c.zero.items():
    if value != 0:
      return (False, description)
  for value, description in c.nonzero.items():
    if value == 0:
      return (False, description)
  return (True, None)