# Source: contrib/seeds/asmap.py (blob 7a605d0b9e2dff1250d3e9fbb516944e0862cab1)
# Copyright (c) 2022 Pieter Wuille
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.

"""
This module provides the ASNEntry and ASMap classes.
"""

import copy
import ipaddress
import random
import unittest
from enum import Enum
from functools import total_ordering
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, overload

def net_to_prefix(net: Union[ipaddress.IPv4Network,ipaddress.IPv6Network]) -> List[bool]:
    """
    Turn an IPv4 or IPv6 network into its prefix, expressed as a list of bits.

    The result holds one bool per prefix bit, most significant bit first.
    IPv4 networks are first moved into the IPv4-mapped IPv6 range
    (::ffff:0:0/96), so their prefixes are 96 bits longer than net.prefixlen.
    """
    width = net.prefixlen
    value = int.from_bytes(net.network_address.packed, 'big')

    if isinstance(net, ipaddress.IPv4Network):
        # Embed the IPv4 prefix below the fixed ::ffff:0:0/96 header.
        width = width + 96
        value = value + 0xffff00000000

    # Everything below the prefix must already be zero.
    host_mask = (1 << (128 - width)) - 1
    assert (value & host_mask) == 0

    # Peel off the top `width` bits of the 128-bit value, highest bit first.
    prefix = []
    for pos in range(width):
        prefix.append(bool((value >> (127 - pos)) & 1))
    return prefix

def prefix_to_net(prefix: List[bool]) -> Union[ipaddress.IPv4Network,ipaddress.IPv6Network]:
    """
    Inverse of net_to_prefix: rebuild a network object from a list of prefix bits.

    Prefixes of at least 96 bits that lie inside ::ffff:0:0/96 are returned as
    IPv4Network objects; everything else is returned as an IPv6Network.
    """
    # Place every prefix bit at its position in a 128-bit integer (bit 0 of the
    # prefix is the most significant bit).
    value = 0
    for pos, bit in enumerate(prefix):
        value += bit << (127 - pos)
    width = len(prefix)
    assert width <= 128

    # Anything inside the IPv4-mapped range is reported as plain IPv4.
    is_mapped_ipv4 = width >= 96 and (value >> 32) == 0xffff
    if is_mapped_ipv4:
        ipv4_spec = (value & 0xffffffff, width - 96)
        return ipaddress.IPv4Network(ipv4_spec, True)

    ipv6_spec = (value, width)
    return ipaddress.IPv6Network(ipv6_spec, True)

# Type alias for a (prefix, ASN) pair, where prefix is a list of bits as produced by
# net_to_prefix.
ASNEntry = Tuple[List[bool], int]

# Type alias for a (prefix, old ASN, new ASN) triple, as produced by ASMap.diff.
ASNDiff = Tuple[List[bool], int, int]

class _VarLenCoder:
    """
    Custom variable-length binary encoder/decoder for integers.

    A coder is parameterised by minval (the smallest encodable value) and clsbits
    (one entry per size class). Class k covers 2^clsbits[k] consecutive integers;
    class 0 starts at minval and every further class starts right after the
    previous one ends.

    A value is encoded as:
    - k "1"-bits, where k is the index of the class the value belongs to;
    - a single "0"-bit, except when k is the last class (then nothing is emitted);
    - clsbits[k] bits holding, in big endian, the value's offset within class k.

    Example with minval=100 and clsbits=[4,2,2,3]:
    - x in [100..115]: [0] + [4-bit BE encoding of (x-100)].
    - x in [116..119]: [1,0] + [2-bit BE encoding of (x-116)].
    - x in [120..123]: [1,1,0] + [2-bit BE encoding of (x-120)].
    - x in [124..131]: [1,1,1] + [3-bit BE encoding of (x-124)].
    """

    def __init__(self, minval: int, clsbits: List[int]):
        """Create a coder for the given minimum value and per-class bit widths."""
        self._minval = minval
        self._clsbits = clsbits
        # Total number of encodable values is the sum of all class sizes.
        span = 0
        for width in clsbits:
            span += 1 << width
        self._maxval = minval + span - 1

    def can_encode(self, val: int) -> bool:
        """Return True iff val lies within the range this coder can represent."""
        return val >= self._minval and val <= self._maxval

    def encode(self, val: int, ret: List[int]) -> None:
        """Append the bits (as 0/1 integers) that encode val to the list ret."""
        assert self._minval <= val <= self._maxval
        offset = val - self._minval
        class_count = len(self._clsbits)
        width = 0
        for cls, width in enumerate(self._clsbits):
            if not offset >> width:
                # The value fits in this class. Terminate the unary class marker
                # with a "0" bit, which is implied (omitted) for the last class.
                if cls + 1 < class_count:
                    ret.append(0)
                break
            # Too large for this class: skip past it and emit a "1" bit.
            offset -= 1 << width
            ret.append(1)
        # Finally the offset within the selected class, most significant bit first.
        for shift in range(width - 1, -1, -1):
            ret.append((offset >> shift) & 1)

    def encode_size(self, val: int) -> int:
        """Return the number of bits encode() would emit for val."""
        assert self._minval <= val <= self._maxval
        offset = val - self._minval
        class_count = len(self._clsbits)
        marker_bits = 0
        width = 0
        for cls, width in enumerate(self._clsbits):
            if not offset >> width:
                # Terminating "0" bit, absent for the last class.
                if cls + 1 < class_count:
                    marker_bits += 1
                break
            offset -= 1 << width
            marker_bits += 1
        return marker_bits + width

    def decode(self, stream, bitpos) -> Tuple[int,int]:
        """
        Decode one value from stream (an indexable sequence of 0/1 bits), starting at
        index bitpos. Returns a (value, new bitpos) tuple.
        """
        result = self._minval
        width = 0
        final_cls = len(self._clsbits) - 1
        for cls, width in enumerate(self._clsbits):
            if cls == final_cls:
                # The last class has no terminator bit; it is selected implicitly.
                break
            marker = stream[bitpos]
            bitpos += 1
            if not marker:
                break
            # A "1" marker means the value lies beyond this class.
            result += 1 << width
        # Read the big endian offset within the selected class.
        for shift in reversed(range(width)):
            result += stream[bitpos] << shift
            bitpos += 1
        return result, bitpos

# Variable-length encoders used in the binary asmap format.
# Instruction opcodes 0..3, encoded as [0], [1,0], [1,1,0] and [1,1,1] respectively.
_CODER_INS = _VarLenCoder(0, [0, 0, 1])
# AS numbers: values 1..33521664, in classes of 15 up to 24 bits.
_CODER_ASN = _VarLenCoder(1, list(range(15, 25)))
# MATCH arguments: values 2..511, in classes of 1 up to 8 bits.
_CODER_MATCH = _VarLenCoder(2, list(range(1, 9)))
# JUMP offsets (in bits): values 17..2147483632, in classes of 5 up to 30 bits.
_CODER_JUMP = _VarLenCoder(17, list(range(5, 31)))

class _Instruction(Enum):
    """One instruction in the binary asmap format.

    The member values 0..3 are the opcodes written with the instruction coder; END (4)
    is never written to the binary encoding.
    """
    # A return instruction, encoded as [0], returns a constant ASN. It is followed by
    # an integer using the ASN encoding.
    RETURN = 0
    # A jump instruction, encoded as [1,0], inspects the next unused bit in the input
    # and either continues execution (if 0), or skips a specified number of bits (if 1).
    # It is followed by an integer, and then two subprograms. The integer uses jump encoding
    # and corresponds to the length of the first subprogram (so it can be skipped).
    JUMP = 1
    # A match instruction, encoded as [1,1,0], inspects 1 or more of the next unused bits
    # in the input with its argument. If they all match, execution continues. If they do
    # not, failure is returned. If a default instruction has been executed before, instead
    # of failure the default instruction's argument is returned. It is followed by an
    # integer in match encoding, and a subprogram. That value is at least 2 bits and at
    # most 9 bits. An n-bit value signifies matching (n-1) bits in the input with the lower
    # (n-1) bits in the match value.
    MATCH = 2
    # A default instruction, encoded as [1,1,1], sets the default variable to its argument,
    # and continues execution. It is followed by an integer in ASN encoding, and a subprogram.
    DEFAULT = 3
    # Not an actual instruction, but a way to encode the empty program that fails. In the
    # encoder, it is used more generally to represent the failure case inside MATCH instructions,
    # which may (if used inside the context of a DEFAULT instruction) actually correspond to
    # a successful return. In this usage, they're always converted to an actual MATCH or RETURN
    # before the top level is reached (see make_default below).
    END = 4

class _BinNode:
    """
    One node of the instruction tree that mirrors the binary asmap format.

    Every node stores its instruction (ins), up to two arguments (arg1, arg2) and the
    number of bits its encoding occupies (size).
    """

    @overload
    def __init__(self, ins: _Instruction): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: "_BinNode", arg2: "_BinNode"): ...
    @overload
    def __init__(self, ins: _Instruction, arg1: int, arg2: "_BinNode"): ...

    def __init__(self, ins: _Instruction, arg1=None, arg2=None):
        """
        Build a node. The accepted argument combinations are:
        - _BinNode(_Instruction.RETURN, asn)
        - _BinNode(_Instruction.JUMP, node_0, node_1)
        - _BinNode(_Instruction.MATCH, val, node)
        - _BinNode(_Instruction.DEFAULT, asn, node)
        - _BinNode(_Instruction.END)
        """
        self.ins = ins
        self.arg1 = arg1
        self.arg2 = arg2

        def opcode_bits() -> int:
            # Number of bits taken by the instruction opcode itself.
            return _CODER_INS.encode_size(ins.value)

        if ins == _Instruction.END:
            # The empty program: no arguments and no encoded bits.
            assert arg1 is None
            assert arg2 is None
            self.size = 0
        elif ins == _Instruction.RETURN:
            assert isinstance(arg1, int)
            assert arg2 is None
            self.size = opcode_bits() + _CODER_ASN.encode_size(arg1)
        elif ins == _Instruction.DEFAULT:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = opcode_bits() + _CODER_ASN.encode_size(arg1) + arg2.size
        elif ins == _Instruction.MATCH:
            assert isinstance(arg1, int)
            assert isinstance(arg2, _BinNode)
            self.size = opcode_bits() + _CODER_MATCH.encode_size(arg1) + arg2.size
        elif ins == _Instruction.JUMP:
            assert isinstance(arg1, _BinNode)
            assert isinstance(arg2, _BinNode)
            # The jump offset is the length of the first subprogram.
            first_len = arg1.size
            self.size = (opcode_bits() + _CODER_JUMP.encode_size(first_len)
                         + first_len + arg2.size)
        else:
            assert False

    @staticmethod
    def make_end() -> "_BinNode":
        """Return a node consisting of only an END instruction."""
        return _BinNode(_Instruction.END)

    @staticmethod
    def make_leaf(val: int) -> "_BinNode":
        """Return a node consisting of only a RETURN instruction for ASN val."""
        assert val is not None and val > 0
        return _BinNode(_Instruction.RETURN, val)

    @staticmethod
    def make_branch(node0: "_BinNode", node1: "_BinNode") -> "_BinNode":
        """
        Return a node that runs node0 or node1 depending on the next input bit
        (0 selects node0, 1 selects node1).

        The cheapest available form is used: a JUMP when both sides are real programs,
        a MATCH when exactly one side is END, and END itself when both sides are END.
        """
        end0 = node0.ins == _Instruction.END
        end1 = node1.ins == _Instruction.END
        if not end0 and not end1:
            return _BinNode(_Instruction.JUMP, node0, node1)
        if end0 and end1:
            return node0

        # Exactly one side is END: the input bit must equal `bit` to reach `live`.
        if end0:
            live, bit = node1, 1
        else:
            live, bit = node0, 0
        if live.ins == _Instruction.MATCH and live.arg1 <= 0xFF:
            # Fold the extra bit into the existing MATCH. Its argument is a leading
            # 1 marker followed by the bits to match; adding 1 << (top + bit) moves
            # the marker one position up and leaves `bit` right below it.
            top = live.arg1.bit_length() - 1
            return _BinNode(live.ins, live.arg1 + (1 << (top + bit)), live.arg2)
        # Otherwise emit a fresh single-bit MATCH: 2 matches a 0 bit, 3 matches a 1 bit.
        return _BinNode(_Instruction.MATCH, 2 + bit, live)

    @staticmethod
    def make_default(val: int, sub: "_BinNode") -> "_BinNode":
        """
        Return a node equivalent to running subprogram sub with default value val.

        Shortcuts in the encoding are used where possible: an END subprogram becomes a
        plain RETURN of val, a subprogram that already starts with RETURN or DEFAULT is
        returned unchanged, and anything else is wrapped in a DEFAULT instruction.
        """
        assert val is not None and val > 0
        kind = sub.ins
        if kind == _Instruction.RETURN or kind == _Instruction.DEFAULT:
            return sub
        if kind == _Instruction.END:
            return _BinNode(_Instruction.RETURN, val)
        return _BinNode(_Instruction.DEFAULT, val, sub)

@total_ordering
class ASMap:
    """
    A class whose objects represent a mapping from subnets to ASNs.

    Internally the mapping is stored as a binary trie, but can be converted
    from/to a list of ASNEntry objects, and from/to the binary asmap file format.

    In the trie representation, nodes are represented as bare lists for efficiency
    and ease of manipulation:
    - [0] means an unassigned subnet (no ASN mapping for it is present)
    - [int] means a subnet mapped entirely to the specified ASN.
    - [node,node] means a subnet whose lower half and upper half have different
      mappings, represented by new trie nodes.
    """

    def update(self, prefix: List[bool], asn: int) -> None:
        """
        Update this ASMap object to map prefix to the specified asn.

        Arguments:
            prefix: The subnet to (re)assign, as a list of bits.
            asn:    The ASN to map it to, or 0 to mark the subnet as unassigned.
        """
        assert asn == 0 or _CODER_ASN.can_encode(asn)

        def recurse(node: List, offset: int) -> None:
            if offset == len(prefix):
                # Reached the end of prefix; overwrite this node.
                node.clear()
                node.append(asn)
                return
            if len(node) == 1:
                # Need to descend into a leaf node; split it up.
                oldasn = node[0]
                node.clear()
                node.append([oldasn])
                node.append([oldasn])
            # Descend into the node.
            recurse(node[prefix[offset]], offset + 1)
            # If the result is two identical leaf children, merge them.
            if len(node[0]) == 1 and len(node[1]) == 1 and node[0] == node[1]:
                oldasn = node[0][0]
                node.clear()
                node.append(oldasn)
        recurse(self._trie, 0)

    def update_multi(self, entries: Iterable[Tuple[List[bool], int]]) -> None:
        """
        Apply multiple update operations, where longer prefixes take precedence.

        The entries are applied from shortest to longest prefix. The passed-in
        collection itself is left untouched (it is not sorted in place).
        """
        # sorted() is stable, so entries with equally long prefixes are applied in the
        # order in which they were provided.
        for prefix, asn in sorted(entries, key=lambda entry: len(entry[0])):
            self.update(prefix, asn)

    def _set_trie(self, trie) -> None:
        """
        Set trie directly. Internal use only.

        The passed trie is normalized in place (sibling leaves with identical contents
        are merged into their parent) and then adopted by this object without copying.
        """
        def recurse(node: List) -> None:
            if len(node) < 2:
                return
            recurse(node[0])
            recurse(node[1])
            # An inner node as left child can never be merged into a leaf.
            if len(node[0]) == 2:
                return
            if node[0] == node[1]:
                if len(node[0]) == 0:
                    # Both children are empty lists; the merged node is empty as well.
                    node.clear()
                else:
                    asn = node[0][0]
                    node.clear()
                    node.append(asn)
        recurse(trie)
        self._trie = trie

    def __init__(self, entries: Optional[Iterable[ASNEntry]] = None) -> None:
        """
        Construct an ASMap object from an optional list of entries.

        Entries are applied from shortest to longest prefix, so that more specific
        entries override less specific ones.
        """
        self._trie = [0]
        if entries is not None:
            def entry_key(entry):
                """Sort function that places shorter prefixes first."""
                prefix, asn = entry
                return len(prefix), prefix, asn
            for prefix, asn in sorted(entries, key=entry_key):
                self.update(prefix, asn)

    def lookup(self, prefix: List[bool]) -> Optional[int]:
        """
        Look up a prefix. Returns ASN, or 0 if unassigned, or None if indeterminate.

        The result is indeterminate when the prefix ends at an inner node of the trie,
        i.e. when different parts of the queried subnet have different mappings.
        """
        node = self._trie
        for bit in prefix:
            if len(node) == 1:
                # Reached a leaf; the remaining bits of the prefix do not matter.
                break
            node = node[bit]
        if len(node) == 1:
            return node[0]
        return None

    def _to_entries_flat(self, fill: bool = False) -> List[ASNEntry]:
        """
        Convert an ASMap object to a list of non-overlapping (prefix, asn) objects.

        With fill set, a subtree whose assigned leaves all share one ASN is reported as a
        single entry for the whole subtree, covering its unassigned parts too.
        """
        # Path from the root to the node currently being visited.
        prefix : List[bool] = []

        def recurse(node: List) -> List[ASNEntry]:
            ret = []
            if len(node) == 1:
                # Leaf: only assigned subnets yield an entry.
                if node[0] > 0:
                    ret = [(list(prefix), node[0])]
            elif len(node) == 2:
                prefix.append(False)
                ret = recurse(node[0])
                prefix[-1] = True
                ret += recurse(node[1])
                prefix.pop()
                if fill and len(ret) > 1:
                    # Collapse the subtree into one entry if only a single ASN occurs.
                    asns = set(x[1] for x in ret)
                    if len(asns) == 1:
                        ret = [(list(prefix), list(asns)[0])]
            return ret
        return recurse(self._trie)

    def _to_entries_minimal(self, fill: bool = False) -> List[ASNEntry]:
        """
        Convert a trie to a minimal list of ASNEntry objects, exploiting overlap.

        For every subtree, recurse returns a (results, hole) pair:
        - results maps a context to the shortest entry list found for that context.
          An integer context is the ASN the subtree is mapped to by entries for shorter
          prefixes (0 meaning unassigned); the None context holds entries that do not
          rely on any such outer mapping.
        - hole tells whether the subtree contains an unassigned range that has to stay
          unassigned (always False when fill is set).
        """
        # Path from the root to the node currently being visited.
        prefix : List[bool] = []

        def recurse(node: List) -> (Tuple[Dict[Optional[int], List[ASNEntry]], bool]):
            if len(node) == 1 and node[0] == 0:
                # Unassigned leaf: no entries needed. Without fill this is only correct
                # when the context is unassigned too.
                return {None if fill else 0: []}, True
            if len(node) == 1:
                # Assigned leaf: free if the context already maps to this ASN, otherwise
                # one explicit entry.
                return {node[0]: [], None: [(list(prefix), node[0])]}, False
            ret: Dict[Optional[int], List[ASNEntry]] = {}
            prefix.append(False)
            left, lhole = recurse(node[0])
            prefix[-1] = True
            right, rhole = recurse(node[1])
            prefix.pop()
            hole = not fill and (lhole or rhole)
            def candidate(ctx: Optional[int], res0: Optional[List[ASNEntry]],
                    res1: Optional[List[ASNEntry]]):
                # Record res0 + res1 for ctx if both exist and are strictly shorter than
                # what is known for ctx so far.
                if res0 is not None and res1 is not None:
                    if ctx not in ret or len(res0) + len(res1) < len(ret[ctx]):
                        ret[ctx] = res0 + res1
            for ctx in set(left) | set(right):
                # Each side may use either its result for ctx or its context-free result.
                candidate(ctx, left.get(ctx), right.get(ctx))
                candidate(ctx, left.get(None), right.get(ctx))
                candidate(ctx, left.get(ctx), right.get(None))
            if not hole:
                # Without holes, a context can be established here by emitting an entry
                # for this whole subnet, followed by the entries relying on it.
                for ctx in list(ret):
                    if ctx is not None:
                        candidate(None, [(list(prefix), ctx)], ret[ctx])
            if None in ret:
                # Drop context-specific results that do not beat the context-free one.
                ret = {ctx:entries for ctx, entries in ret.items()
                       if ctx is None or len(entries) < len(ret[None])}
            if hole:
                # With a hole present, only the unassigned context remains usable.
                ret = {ctx:entries for ctx, entries in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        # The context of the root is "unassigned".
        return res[0] if 0 in res else res[None]

    def __str__(self) -> str:
        """Convert this ASMap object to a string containing Python code constructing it."""
        return f"ASMap({self._trie})"

    def to_entries(self, overlapping: bool = True, fill: bool = False) -> List[ASNEntry]:
        """
        Convert the mappings in this ASMap object to a list of ASNEntry objects.

        Arguments:
            overlapping: Permit the subnets in the resulting ASNEntry to overlap.
                         Setting this can result in a shorter list.
            fill:        Permit the resulting ASNEntry objects to cover subnets that
                         are unassigned in this ASMap object. Setting this can
                         result in a shorter list.
        """
        if overlapping:
            return self._to_entries_minimal(fill)
        return self._to_entries_flat(fill)

    @staticmethod
    def from_random(num_leaves: int = 10, max_asn: int = 6,
                    unassigned_prob: float = 0.5) -> "ASMap":
        """
        Construct a random ASMap object, with specified:
         - Number of leaves in its trie (at least 1)
         - Maximum ASN value (at least 1)
         - Probability for leaf nodes to be unassigned

        The number of leaves in the resulting object may be less than what is
        requested. This method is mostly intended for testing.
        """
        assert num_leaves >= 1
        assert max_asn >= 1 or unassigned_prob == 1
        assert _CODER_ASN.can_encode(max_asn)
        assert 0.0 <= unassigned_prob <= 1.0
        trie: List = []
        leaves = [trie]
        ret = ASMap()
        for i in range(1, num_leaves):
            # Pick a random leaf and split it into two new (still empty) leaves.
            idx = random.randrange(i)
            leaf = leaves[idx]
            # Remove the chosen leaf from the list by moving the last leaf into its slot.
            lastleaf = leaves.pop()
            if idx + 1 < i:
                leaves[idx] = lastleaf
            leaf.append([])
            leaf.append([])
            leaves.append(leaf[0])
            leaves.append(leaf[1])
        for leaf in leaves:
            # Give every leaf either a random ASN or the unassigned marker 0.
            if random.random() >= unassigned_prob:
                leaf.append(random.randrange(1, max_asn + 1))
            else:
                leaf.append(0)
        # _set_trie merges identical sibling leaves, which is why fewer leaves than
        # requested may remain.
        #pylint: disable=protected-access
        ret._set_trie(trie)
        return ret

    def _to_binnode(self, fill: bool = False) -> _BinNode:
        """
        Convert a trie to a _BinNode object.

        This follows the same scheme as _to_entries_minimal, but minimizes the encoded
        size in bits of a _BinNode program rather than the number of entries. For every
        subtree, recurse returns a (results, hole) pair, where results maps a context
        (a default ASN, 0 for unassigned, or None for "no default relied upon") to the
        smallest program found for it.
        """
        def recurse(node: List) -> Tuple[Dict[Optional[int], _BinNode], bool]:
            if len(node) == 1 and node[0] == 0:
                return {(None if fill else 0): _BinNode.make_end()}, True
            if len(node) == 1:
                # Either return the ASN explicitly, or rely on it being the default.
                return {None: _BinNode.make_leaf(node[0]), node[0]: _BinNode.make_end()}, False
            ret: Dict[Optional[int], _BinNode] = {}
            left, lhole = recurse(node[0])
            right, rhole = recurse(node[1])
            hole = (lhole or rhole) and not fill

            def candidate(ctx: Optional[int], arg1, arg2, func: Callable):
                # Record func(arg1, arg2) for ctx if both arguments exist and the result is
                # strictly smaller than what is known for ctx so far.
                if arg1 is not None and arg2 is not None:
                    cand = func(arg1, arg2)
                    if ctx not in ret or cand.size < ret[ctx].size:
                        ret[ctx] = cand

            for ctx in set(left) | set(right):
                candidate(ctx, left.get(ctx), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(None), right.get(ctx), _BinNode.make_branch)
                candidate(ctx, left.get(ctx), right.get(None), _BinNode.make_branch)
            if not hole:
                # Without holes, a default can be set here to establish a context.
                for ctx in set(ret) - set([None]):
                    candidate(None, ctx, ret[ctx], _BinNode.make_default)
            if None in ret:
                # Drop context-specific programs that do not beat the context-free one.
                ret = {ctx:enc for ctx, enc in ret.items()
                       if ctx is None or enc.size < ret[None].size}
            if hole:
                ret = {ctx:enc for ctx, enc in ret.items() if ctx is None or ctx == 0}
            return ret, hole
        res, _ = recurse(self._trie)
        return res[0] if 0 in res else res[None]

    @staticmethod
    def _from_binnode(binnode: _BinNode) -> "ASMap":
        """Construct an ASMap object from a _BinNode. Internal use only."""
        def recurse(node: _BinNode, default: int) -> List:
            if node.ins == _Instruction.RETURN:
                return [node.arg1]
            if node.ins == _Instruction.JUMP:
                return [recurse(node.arg1, default), recurse(node.arg2, default)]
            if node.ins == _Instruction.MATCH:
                val = node.arg1
                sub = recurse(node.arg2, default)
                # Consume the match bits from least to most significant (stopping at the
                # leading 1 marker), wrapping sub in one trie level per bit. The side that
                # does not match maps to the current default.
                while val >= 2:
                    bit = val & 1
                    val >>= 1
                    if bit:
                        sub = [[default], sub]
                    else:
                        sub = [sub, [default]]
                return sub
            assert node.ins == _Instruction.DEFAULT
            return recurse(node.arg2, node.arg1)
        ret = ASMap()
        if binnode.ins != _Instruction.END:
            #pylint: disable=protected-access
            ret._set_trie(recurse(binnode, 0))
        return ret

    def to_binary(self, fill: bool = False) -> bytes:
        """
        Convert this ASMap object to binary.

        Argument:
            fill: permit the resulting binary encoder to contain mappers for
                  unassigned subnets in this ASMap object. Doing so may
                  reduce the size of the encoding.
        Returns:
            A bytes object with the encoding of this ASMap object.
        """
        bits: List[int] = []

        def recurse(node: _BinNode) -> None:
            _CODER_INS.encode(node.ins.value, bits)
            if node.ins == _Instruction.RETURN:
                _CODER_ASN.encode(node.arg1, bits)
            elif node.ins == _Instruction.JUMP:
                _CODER_JUMP.encode(node.arg1.size, bits)
                recurse(node.arg1)
                recurse(node.arg2)
            elif node.ins == _Instruction.DEFAULT:
                _CODER_ASN.encode(node.arg1, bits)
                recurse(node.arg2)
            else:
                assert node.ins == _Instruction.MATCH
                _CODER_MATCH.encode(node.arg1, bits)
                recurse(node.arg2)

        binnode = self._to_binnode(fill)
        # An END program (a map without any assignment) encodes as zero bytes.
        if binnode.ins != _Instruction.END:
            recurse(binnode)

        # Pack the bits into bytes, filling each byte from its least significant bit up.
        # A final partial byte is padded with zero bits.
        val = 0
        nbits = 0
        ret = []
        for bit in bits:
            val += (bit << nbits)
            nbits += 1
            if nbits == 8:
                ret.append(val)
                val = 0
                nbits = 0
        if nbits:
            ret.append(val)
        return bytes(ret)

    @staticmethod
    def from_binary(bindata: bytes) -> Optional["ASMap"]:
        """
        Decode an ASMap object from the provided binary encoding.

        Returns None if bindata is not a valid encoding: when it is truncated, contains an
        inconsistent jump offset, or has more than just zero padding bits after the
        program.
        """

        # Unpack the bytes into bits, least significant bit of each byte first.
        bits: List[int] = []
        for byte in bindata:
            bits.extend((byte >> i) & 1 for i in range(8))

        def recurse(bitpos: int) -> Tuple[_BinNode, int]:
            insval, bitpos = _CODER_INS.decode(bits, bitpos)
            ins = _Instruction(insval)
            if ins == _Instruction.RETURN:
                asn, bitpos = _CODER_ASN.decode(bits, bitpos)
                return _BinNode(ins, asn), bitpos
            if ins == _Instruction.JUMP:
                jump, bitpos = _CODER_JUMP.decode(bits, bitpos)
                left, bitpos1 = recurse(bitpos)
                # The jump offset must equal the length of the first subprogram.
                if bitpos1 != bitpos + jump:
                    raise ValueError("Inconsistent jump")
                right, bitpos = recurse(bitpos1)
                return _BinNode(ins, left, right), bitpos
            if ins == _Instruction.MATCH:
                match, bitpos = _CODER_MATCH.decode(bits, bitpos)
                sub, bitpos = recurse(bitpos)
                return _BinNode(ins, match, sub), bitpos
            assert ins == _Instruction.DEFAULT
            asn, bitpos = _CODER_ASN.decode(bits, bitpos)
            sub, bitpos = recurse(bitpos)
            return _BinNode(ins, asn, sub), bitpos

        if len(bits) == 0:
            binnode = _BinNode(_Instruction.END)
        else:
            try:
                binnode, bitpos = recurse(0)
            except (ValueError, IndexError):
                # IndexError signals reading past the end of the data.
                return None
            # At most 7 padding bits may follow the program...
            if bitpos < len(bits) - 7:
                return None
            # ...and all of them must be zero.
            if not all(bit == 0 for bit in bits[bitpos:]):
                return None

        return ASMap._from_binnode(binnode)

    def __lt__(self, other: "ASMap") -> bool:
        """
        Order ASMap objects by comparing their tries as nested lists.

        NOTE(review): list comparison raises TypeError when a leaf value (int) ends up
        being compared against an inner node (list), so not every pair of maps is
        orderable.
        """
        if not isinstance(other, ASMap):
            # Let Python try the reflected operation, or raise TypeError itself.
            return NotImplemented
        return self._trie < other._trie

    def __eq__(self, other: object) -> bool:
        """Two ASMap objects are equal iff their tries are equal."""
        if isinstance(other, ASMap):
            return self._trie == other._trie
        return False

    def extends(self, req: "ASMap") -> bool:
        """Determine whether this matches req for all subranges where req is assigned."""
        def recurse(actual: List, require: List) -> bool:
            if len(require) == 1 and require[0] == 0:
                # Nothing is required for this range.
                return True
            if len(require) == 1:
                if len(actual) == 1:
                    return bool(require[0] == actual[0])
                # The requirement is uniform, but actual is split: check both halves.
                return recurse(actual[0], require) and recurse(actual[1], require)
            if len(actual) == 2:
                return recurse(actual[0], require[0]) and recurse(actual[1], require[1])
            # actual is uniform, but the requirement is split: check against both halves.
            return recurse(actual, require[0]) and recurse(actual, require[1])
        assert isinstance(req, ASMap)
        #pylint: disable=protected-access
        return recurse(self._trie, req._trie)

    def diff(self, other: "ASMap") -> List[ASNDiff]:
        """
        Compute the diff from self to other.

        Returns a list of (prefix, old ASN, new ASN) tuples, one for every subnet on
        which the two maps differ; old is the value in self and new the value in other
        (0 meaning unassigned).
        """
        prefix: List[bool] = []
        ret: List[ASNDiff] = []

        def recurse(old_node: List, new_node: List):
            if len(old_node) == 1 and len(new_node) == 1:
                if old_node[0] != new_node[0]:
                    ret.append((list(prefix), old_node[0], new_node[0]))
            else:
                # At least one side is split here; a leaf on the other side stands in
                # for both of its halves.
                old_left: List = old_node if len(old_node) == 1 else old_node[0]
                old_right: List = old_node if len(old_node) == 1 else old_node[1]
                new_left: List = new_node if len(new_node) == 1 else new_node[0]
                new_right: List = new_node if len(new_node) == 1 else new_node[1]
                prefix.append(False)
                recurse(old_left, new_left)
                prefix[-1] = True
                recurse(old_right, new_right)
                prefix.pop()
        assert isinstance(other, ASMap)
        #pylint: disable=protected-access
        recurse(self._trie, other._trie)
        return ret

    def __copy__(self) -> "ASMap":
        """Construct a copy of this ASMap object. Its state will not be shared."""
        ret = ASMap()
        #pylint: disable=protected-access
        ret._set_trie(copy.deepcopy(self._trie))
        return ret

    def __deepcopy__(self, _) -> "ASMap":
        """Deep copy; identical to __copy__, which already copies the whole trie."""
        # ASMap objects do not allow sharing of the _trie member, so we don't need the memoization.
        return self.__copy__()


class TestASMap(unittest.TestCase):
    """Unit tests for this module.

    All tests draw their inputs from the random module without fixing a seed, so the
    comparisons use the specific unittest assertion helpers, which report the operands
    involved when a check fails.
    """

    def test_ipv6_prefix_roundtrips(self) -> None:
        """Test that random IPv6 network ranges roundtrip through prefix encoding."""
        for _ in range(20):
            net_bits = random.getrandbits(128)
            for prefix_len in range(0, 129):
                # Keep only the top prefix_len bits; the remaining (host) bits are zeroed.
                shift = 128 - prefix_len
                masked_bits = (net_bits >> shift) << shift
                net = ipaddress.IPv6Network((masked_bits.to_bytes(16, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                self.assertLessEqual(len(prefix), 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_ipv4_prefix_roundtrips(self) -> None:
        """Test that random IPv4 network ranges roundtrip through prefix encoding."""
        for _ in range(100):
            net_bits = random.getrandbits(32)
            for prefix_len in range(0, 33):
                # Keep only the top prefix_len bits; the remaining (host) bits are zeroed.
                shift = 32 - prefix_len
                masked_bits = (net_bits >> shift) << shift
                net = ipaddress.IPv4Network((masked_bits.to_bytes(4, 'big'), prefix_len))
                prefix = net_to_prefix(net)
                # Equivalent to 32 <= len(prefix) <= 128, but reports the length on failure.
                self.assertGreaterEqual(len(prefix), 32)
                self.assertLessEqual(len(prefix), 128)
                net2 = prefix_to_net(prefix)
                self.assertEqual(net, net2)

    def test_asmap_roundtrips(self) -> None:
        """Test case that verifies random ASMap objects roundtrip to/from entries/binary."""
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 24):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(0, 101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Run tests for to_entries and construction from those entries, both
                    # for overlapping and non-overlapping ones.
                    for overlapping in [False, True]:
                        # Without fill, the reconstructed map must be identical.
                        entries = asmap.to_entries(overlapping=overlapping, fill=False)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertEqual(asmap2, asmap)
                        # With fill, the reconstructed map only needs to extend the original.
                        entries = asmap.to_entries(overlapping=overlapping, fill=True)
                        random.shuffle(entries)
                        asmap2 = ASMap(entries)
                        assert asmap2 is not None
                        self.assertTrue(asmap2.extends(asmap))

                    # Run tests for to_binary and construction from binary.
                    enc = asmap.to_binary(fill=False)
                    asmap3 = ASMap.from_binary(enc)
                    # NOTE(review): presumably a plain assert so that type checkers narrow an
                    # optional return value of from_binary - confirm against its signature.
                    assert asmap3 is not None
                    self.assertEqual(asmap3, asmap)
                    enc = asmap.to_binary(fill=True)
                    asmap3 = ASMap.from_binary(enc)
                    assert asmap3 is not None
                    self.assertTrue(asmap3.extends(asmap))

    def test_patching(self) -> None:
        """Test behavior of update, lookup, extends, and diff."""
        #pylint: disable=too-many-locals,too-many-nested-blocks
        # Iterate over the number of leaves the random test ASMap objects have.
        for leaves in range(1, 20):
            # Iterate over the number of bits in the AS numbers used.
            for asnbits in range(0, 10):
                # Iterate over the probability that leaves are unassigned.
                for pct in range(0, 101):
                    # Construct a random ASMap object according to the above parameters.
                    asmap = ASMap.from_random(num_leaves=leaves, max_asn=1 + (1 << asnbits),
                                              unassigned_prob=0.01 * pct)
                    # Make a copy of that asmap object to which patches will be applied.
                    # It starts off being equal to asmap.
                    patched = copy.copy(asmap)
                    # Keep a list of patches performed, most recent first.
                    patches: List[ASNEntry] = []
                    # Initially there cannot be any difference.
                    self.assertEqual(asmap.diff(patched), [])
                    # Make 5 patches, each building on top of the previous ones.
                    for _ in range(0, 5):
                        # Construct a random path and new ASN to assign it to, apply it to patched,
                        # and remember it in patches.
                        pathlen = random.randrange(5)
                        path = [random.getrandbits(1) != 0 for _ in range(pathlen)]
                        newasn = random.randrange(1 + (1 << asnbits))
                        patched.update(path, newasn)
                        patches.insert(0, (path, newasn))

                        # Compute the diff, and whether asmap extends patched, and the other way
                        # around.
                        diff = asmap.diff(patched)
                        self.assertEqual(asmap == patched, not diff)
                        extends = asmap.extends(patched)
                        back_extends = patched.extends(asmap)
                        # Determine whether those extends results are consistent with the diff
                        # result.
                        self.assertEqual(extends, all(new == 0 for _, _, new in diff))
                        self.assertEqual(back_extends, all(old == 0 for _, old, _ in diff))
                        # For every diff found (diff_path is kept distinct from the patch path
                        # above, so the two are never confused):
                        for diff_path, old_asn, new_asn in diff:
                            # Verify asmap and patched actually differ there.
                            self.assertNotEqual(old_asn, new_asn)
                            self.assertEqual(asmap.lookup(diff_path), old_asn)
                            self.assertEqual(patched.lookup(diff_path), new_asn)
                            for _ in range(2):
                                # Extend the path far enough that it's smaller than any mapped
                                # range, and check the lookup holds there too.
                                spec_path = list(diff_path)
                                while len(spec_path) < 32:
                                    spec_path.append(random.getrandbits(1) != 0)
                                self.assertEqual(asmap.lookup(spec_path), old_asn)
                                self.assertEqual(patched.lookup(spec_path), new_asn)
                                # Search through the list of performed patches to find the last one
                                # applying to the extended path (note that patches is in reverse
                                # order, so the first match should work).
                                found = False
                                for patch_path, patch_asn in patches:
                                    if spec_path[:len(patch_path)] == patch_path:
                                        # When found, it must match whatever the result was patched
                                        # to.
                                        self.assertEqual(new_asn, patch_asn)
                                        found = True
                                        break
                                # And such a patch must exist.
                                self.assertTrue(found, msg=f"no patch covers path {spec_path}")

# Run the unit tests defined in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()