diff options
Diffstat (limited to 'test/functional/test_framework/mininode.py')
-rwxr-xr-x | test/functional/test_framework/mininode.py | 111 |
1 files changed, 78 insertions, 33 deletions
diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py index a9e669fea9..31cec66ee7 100755 --- a/test/functional/test_framework/mininode.py +++ b/test/functional/test_framework/mininode.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) 2010 ArtForz -- public domain half-a-node # Copyright (c) 2012 Jeff Garzik -# Copyright (c) 2010-2019 The Bitcoin Core developers +# Copyright (c) 2010-2020 The Bitcoin Core developers # Distributed under the MIT software license, see the accompanying # file COPYING or http://www.opensource.org/licenses/mit-license.php. """Bitcoin P2P network half-a-node. @@ -12,7 +12,10 @@ found in the mini-node branch of http://github.com/jgarzik/pynode. P2PConnection: A low-level connection object to a node's P2P interface P2PInterface: A high-level interface object for communicating to a node over P2P P2PDataStore: A p2p interface class that keeps a store of transactions and blocks - and can respond correctly to getdata and getheaders messages""" + and can respond correctly to getdata and getheaders messages +P2PTxInvStore: A p2p interface class that inherits from P2PDataStore, and keeps + a count of how many times each txid has been announced.""" + import asyncio from collections import defaultdict from io import BytesIO @@ -30,6 +33,9 @@ from test_framework.messages import ( msg_blocktxn, msg_cmpctblock, msg_feefilter, + msg_filteradd, + msg_filterclear, + msg_filterload, msg_getaddr, msg_getblocks, msg_getblocktxn, @@ -38,6 +44,7 @@ from test_framework.messages import ( msg_headers, msg_inv, msg_mempool, + msg_merkleblock, msg_notfound, msg_ping, msg_pong, @@ -62,6 +69,9 @@ MESSAGEMAP = { b"blocktxn": msg_blocktxn, b"cmpctblock": msg_cmpctblock, b"feefilter": msg_feefilter, + b"filteradd": msg_filteradd, + b"filterclear": msg_filterclear, + b"filterload": msg_filterload, b"getaddr": msg_getaddr, b"getblocks": msg_getblocks, b"getblocktxn": msg_getblocktxn, @@ -70,6 +80,7 @@ MESSAGEMAP = { b"headers": msg_headers, b"inv": msg_inv, b"mempool": msg_mempool, + b"merkleblock": msg_merkleblock, b"notfound": msg_notfound, b"ping": msg_ping, b"pong": msg_pong, @@ -109,8 +120,9 @@ class P2PConnection(asyncio.Protocol): def is_connected(self): return self._transport is not None - def peer_connect(self, dstaddr, dstport, *, net): + def peer_connect(self, dstaddr, dstport, *, net, factor): assert not self.is_connected + self.factor = factor self.dstaddr = dstaddr self.dstport = dstport # The initial message to send after the connection was made: @@ -172,7 +184,7 @@ class P2PConnection(asyncio.Protocol): raise ValueError("magic bytes mismatch: {} != {}".format(repr(self.magic_bytes), repr(self.recvbuf))) if len(self.recvbuf) < 4 + 12 + 4 + 4: return - command = self.recvbuf[4:4+12].split(b"\x00", 1)[0] + msgtype = self.recvbuf[4:4+12].split(b"\x00", 1)[0] msglen = struct.unpack("<i", self.recvbuf[4+12:4+12+4])[0] checksum = self.recvbuf[4+12+4:4+12+4+4] if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen: @@ -183,10 +195,10 @@ class P2PConnection(asyncio.Protocol): if checksum != h[:4]: raise ValueError("got bad checksum " + repr(self.recvbuf)) self.recvbuf = self.recvbuf[4+12+4+4+msglen:] - if command not in MESSAGEMAP: - raise ValueError("Received unknown command from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, command, repr(msg))) + if msgtype not in MESSAGEMAP: + raise ValueError("Received unknown msgtype from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, msgtype, repr(msg))) f = BytesIO(msg) - t = MESSAGEMAP[command]() + t = MESSAGEMAP[msgtype]() t.deserialize(f) self._log_message("receive", t) self.on_message(t) @@ -225,11 +237,11 @@ class P2PConnection(asyncio.Protocol): def build_message(self, message): """Build a serialized P2P message""" - command = message.command + msgtype = message.msgtype data = message.serialize() tmsg = self.magic_bytes - tmsg += command - tmsg += b"\x00" * (12 - len(command)) + tmsg += msgtype + tmsg += b"\x00" * (12 - len(msgtype)) tmsg += struct.pack("<I", len(data)) th = sha256(data) h = sha256(th) @@ -296,10 +308,10 @@ class P2PInterface(P2PConnection): and the most recent message of each type.""" with mininode_lock: try: - command = message.command.decode('ascii') - self.message_count[command] += 1 - self.last_message[command] = message - getattr(self, 'on_' + command)(message) + msgtype = message.msgtype.decode('ascii') + self.message_count[msgtype] += 1 + self.last_message[msgtype] = message + getattr(self, 'on_' + msgtype)(message) except: print("ERROR delivering %s (%s)" % (repr(message), sys.exc_info()[0])) raise @@ -318,6 +330,9 @@ class P2PInterface(P2PConnection): def on_blocktxn(self, message): pass def on_cmpctblock(self, message): pass def on_feefilter(self, message): pass + def on_filteradd(self, message): pass + def on_filterclear(self, message): pass + def on_filterload(self, message): pass def on_getaddr(self, message): pass def on_getblocks(self, message): pass def on_getblocktxn(self, message): pass @@ -325,9 +340,9 @@ class P2PInterface(P2PConnection): def on_getheaders(self, message): pass def on_headers(self, message): pass def on_mempool(self, message): pass + def on_merkleblock(self, message): pass def on_notfound(self, message): pass def on_pong(self, message): pass - def on_reject(self, message): pass def on_sendcmpct(self, message): pass def on_sendheaders(self, message): pass def on_tx(self, message): pass @@ -353,9 +368,12 @@ class P2PInterface(P2PConnection): # Connection helper methods + def wait_until(self, test_function, timeout): + wait_until(test_function, timeout=timeout, lock=mininode_lock, factor=self.factor) + def wait_for_disconnect(self, timeout=60): test_function = lambda: not self.is_connected - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) # Message receiving helper methods @@ -366,14 +384,14 @@ class P2PInterface(P2PConnection): return False return self.last_message['tx'].tx.rehash() == txid - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) def wait_for_block(self, blockhash, timeout=60): def test_function(): assert self.is_connected return self.last_message.get("block") and self.last_message["block"].block.rehash() == blockhash - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) def wait_for_header(self, blockhash, timeout=60): def test_function(): @@ -381,23 +399,33 @@ class P2PInterface(P2PConnection): last_headers = self.last_message.get('headers') if not last_headers: return False - return last_headers.headers[0].rehash() == blockhash + return last_headers.headers[0].rehash() == int(blockhash, 16) - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) - def wait_for_getdata(self, timeout=60): + def wait_for_merkleblock(self, blockhash, timeout=60): + def test_function(): + assert self.is_connected + last_filtered_block = self.last_message.get('merkleblock') + if not last_filtered_block: + return False + return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) + + self.wait_until(test_function, timeout=timeout) + + def wait_for_getdata(self, hash_list, timeout=60): """Waits for a getdata message. - Receiving any getdata message will satisfy the predicate. the last_message["getdata"] - value must be explicitly cleared before calling this method, or this will return - immediately with success. TODO: change this method to take a hash value and only - return true if the correct block/tx has been requested.""" + The object hashes in the inventory vector must match the provided hash_list.""" def test_function(): assert self.is_connected - return self.last_message.get("getdata") + last_data = self.last_message.get("getdata") + if not last_data: + return False + return [x.hash for x in last_data.inv] == hash_list - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) def wait_for_getheaders(self, timeout=60): """Waits for a getheaders message. @@ -411,7 +439,7 @@ class P2PInterface(P2PConnection): assert self.is_connected return self.last_message.get("getheaders") - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) def wait_for_inv(self, expected_inv, timeout=60): """Waits for an INV message and checks that the first inv object in the message was as expected.""" @@ -424,13 +452,13 @@ class P2PInterface(P2PConnection): self.last_message["inv"].inv[0].type == expected_inv[0].type and \ self.last_message["inv"].inv[0].hash == expected_inv[0].hash - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) def wait_for_verack(self, timeout=60): def test_function(): return self.message_count["verack"] - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) # Message sending helper functions @@ -446,7 +474,7 @@ class P2PInterface(P2PConnection): assert self.is_connected return self.last_message.get("pong") and self.last_message["pong"].nonce == self.ping_counter - wait_until(test_function, timeout=timeout, lock=mininode_lock) + self.wait_until(test_function, timeout=timeout) self.ping_counter += 1 @@ -562,7 +590,7 @@ class P2PDataStore(P2PInterface): self.send_message(msg_block(block=b)) else: self.send_message(msg_headers([CBlockHeader(block) for block in blocks])) - wait_until(lambda: blocks[-1].sha256 in self.getdata_requests, timeout=timeout, lock=mininode_lock) + self.wait_until(lambda: blocks[-1].sha256 in self.getdata_requests, timeout=timeout) if expect_disconnect: self.wait_for_disconnect(timeout=timeout) @@ -570,7 +598,7 @@ class P2PDataStore(P2PInterface): self.sync_with_ping(timeout=timeout) if success: - wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout) + self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout) else: assert node.getbestblockhash() != blocks[-1].hash @@ -606,3 +634,20 @@ class P2PDataStore(P2PInterface): # Check that none of the txs are now in the mempool for tx in txs: assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash) + +class P2PTxInvStore(P2PInterface): + """A P2PInterface which stores a count of how many times each txid has been announced.""" + def __init__(self): + super().__init__() + self.tx_invs_received = defaultdict(int) + + def on_inv(self, message): + # Store how many times invs have been received for each tx. + for i in message.inv: + if i.type == MSG_TX: + # save txid + self.tx_invs_received[i.hash] += 1 + + def get_invs(self): + with mininode_lock: + return list(self.tx_invs_received.keys()) |