diff options
Diffstat (limited to 'test/functional/test_framework/p2p.py')
-rwxr-xr-x | test/functional/test_framework/p2p.py | 681 |
1 files changed, 681 insertions, 0 deletions
diff --git a/test/functional/test_framework/p2p.py b/test/functional/test_framework/p2p.py new file mode 100755 index 0000000000..38c3c5551a --- /dev/null +++ b/test/functional/test_framework/p2p.py @@ -0,0 +1,681 @@ +#!/usr/bin/env python3 +# Copyright (c) 2010 ArtForz -- public domain half-a-node +# Copyright (c) 2012 Jeff Garzik +# Copyright (c) 2010-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Bitcoin P2P network half-a-node. + +This python code was modified from ArtForz' public domain half-a-node, as +found in the mini-node branch of http://github.com/jgarzik/pynode. + +P2PConnection: A low-level connection object to a node's P2P interface +P2PInterface: A high-level interface object for communicating to a node over P2P +P2PDataStore: A p2p interface class that keeps a store of transactions and blocks + and can respond correctly to getdata and getheaders messages +P2PTxInvStore: A p2p interface class that inherits from P2PDataStore, and keeps + a count of how many times each txid has been announced.""" + +import asyncio +from collections import defaultdict +from io import BytesIO +import logging +import struct +import sys +import threading + +from test_framework.messages import ( + CBlockHeader, + MAX_HEADERS_RESULTS, + MIN_VERSION_SUPPORTED, + msg_addr, + msg_block, + MSG_BLOCK, + msg_blocktxn, + msg_cfcheckpt, + msg_cfheaders, + msg_cfilter, + msg_cmpctblock, + msg_feefilter, + msg_filteradd, + msg_filterclear, + msg_filterload, + msg_getaddr, + msg_getblocks, + msg_getblocktxn, + msg_getdata, + msg_getheaders, + msg_headers, + msg_inv, + msg_mempool, + msg_merkleblock, + msg_notfound, + msg_ping, + msg_pong, + msg_sendcmpct, + msg_sendheaders, + msg_tx, + MSG_TX, + MSG_TYPE_MASK, + msg_verack, + msg_version, + MSG_WTX, + msg_wtxidrelay, + NODE_NETWORK, + NODE_WITNESS, + sha256, +) +from test_framework.util import wait_until + +logger = logging.getLogger("TestFramework.p2p") + +MESSAGEMAP = { + b"addr": msg_addr, + b"block": msg_block, + b"blocktxn": msg_blocktxn, + b"cfcheckpt": msg_cfcheckpt, + b"cfheaders": msg_cfheaders, + b"cfilter": msg_cfilter, + b"cmpctblock": msg_cmpctblock, + b"feefilter": msg_feefilter, + b"filteradd": msg_filteradd, + b"filterclear": msg_filterclear, + b"filterload": msg_filterload, + b"getaddr": msg_getaddr, + b"getblocks": msg_getblocks, + b"getblocktxn": msg_getblocktxn, + b"getdata": msg_getdata, + b"getheaders": msg_getheaders, + b"headers": msg_headers, + b"inv": msg_inv, + b"mempool": msg_mempool, + b"merkleblock": msg_merkleblock, + b"notfound": msg_notfound, + b"ping": msg_ping, + b"pong": msg_pong, + b"sendcmpct": msg_sendcmpct, + b"sendheaders": msg_sendheaders, + b"tx": msg_tx, + b"verack": msg_verack, + b"version": msg_version, + b"wtxidrelay": msg_wtxidrelay, +} + +MAGIC_BYTES = { + "mainnet": b"\xf9\xbe\xb4\xd9", # mainnet + "testnet3": b"\x0b\x11\x09\x07", # testnet3 + "regtest": b"\xfa\xbf\xb5\xda", # regtest +} + + +class P2PConnection(asyncio.Protocol): + """A low-level connection object to a node's P2P interface. + + This class is responsible for: + + - opening and closing the TCP connection to the node + - reading bytes from and writing bytes to the socket + - deserializing and serializing the P2P message header + - logging messages as they are sent and received + + This class contains no logic for handing the P2P message payloads. It must be + sub-classed and the on_message() callback overridden.""" + + def __init__(self): + # The underlying transport of the connection. + # Should only call methods on this from the NetworkThread, c.f. call_soon_threadsafe + self._transport = None + + @property + def is_connected(self): + return self._transport is not None + + def peer_connect(self, dstaddr, dstport, *, net, timeout_factor): + assert not self.is_connected + self.timeout_factor = timeout_factor + self.dstaddr = dstaddr + self.dstport = dstport + # The initial message to send after the connection was made: + self.on_connection_send_msg = None + self.recvbuf = b"" + self.magic_bytes = MAGIC_BYTES[net] + logger.debug('Connecting to Bitcoin Node: %s:%d' % (self.dstaddr, self.dstport)) + + loop = NetworkThread.network_event_loop + conn_gen_unsafe = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport) + conn_gen = lambda: loop.call_soon_threadsafe(loop.create_task, conn_gen_unsafe) + return conn_gen + + def peer_disconnect(self): + # Connection could have already been closed by other end. + NetworkThread.network_event_loop.call_soon_threadsafe(lambda: self._transport and self._transport.abort()) + + # Connection and disconnection methods + + def connection_made(self, transport): + """asyncio callback when a connection is opened.""" + assert not self._transport + logger.debug("Connected & Listening: %s:%d" % (self.dstaddr, self.dstport)) + self._transport = transport + if self.on_connection_send_msg: + self.send_message(self.on_connection_send_msg) + self.on_connection_send_msg = None # Never used again + self.on_open() + + def connection_lost(self, exc): + """asyncio callback when a connection is closed.""" + if exc: + logger.warning("Connection lost to {}:{} due to {}".format(self.dstaddr, self.dstport, exc)) + else: + logger.debug("Closed connection to: %s:%d" % (self.dstaddr, self.dstport)) + self._transport = None + self.recvbuf = b"" + self.on_close() + + # Socket read methods + + def data_received(self, t): + """asyncio callback when data is read from the socket.""" + if len(t) > 0: + self.recvbuf += t + self._on_data() + + def _on_data(self): + """Try to read P2P messages from the recv buffer. + + This method reads data from the buffer in a loop. It deserializes, + parses and verifies the P2P header, then passes the P2P payload to + the on_message callback for processing.""" + try: + while True: + if len(self.recvbuf) < 4: + return + if self.recvbuf[:4] != self.magic_bytes: + raise ValueError("magic bytes mismatch: {} != {}".format(repr(self.magic_bytes), repr(self.recvbuf))) + if len(self.recvbuf) < 4 + 12 + 4 + 4: + return + msgtype = self.recvbuf[4:4+12].split(b"\x00", 1)[0] + msglen = struct.unpack("<i", self.recvbuf[4+12:4+12+4])[0] + checksum = self.recvbuf[4+12+4:4+12+4+4] + if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen: + return + msg = self.recvbuf[4+12+4+4:4+12+4+4+msglen] + th = sha256(msg) + h = sha256(th) + if checksum != h[:4]: + raise ValueError("got bad checksum " + repr(self.recvbuf)) + self.recvbuf = self.recvbuf[4+12+4+4+msglen:] + if msgtype not in MESSAGEMAP: + raise ValueError("Received unknown msgtype from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, msgtype, repr(msg))) + f = BytesIO(msg) + t = MESSAGEMAP[msgtype]() + t.deserialize(f) + self._log_message("receive", t) + self.on_message(t) + except Exception as e: + logger.exception('Error reading message:', repr(e)) + raise + + def on_message(self, message): + """Callback for processing a P2P payload. Must be overridden by derived class.""" + raise NotImplementedError + + # Socket write methods + + def send_message(self, message): + """Send a P2P message over the socket. + + This method takes a P2P payload, builds the P2P header and adds + the message to the send buffer to be sent over the socket.""" + tmsg = self.build_message(message) + self._log_message("send", message) + return self.send_raw_message(tmsg) + + def send_raw_message(self, raw_message_bytes): + if not self.is_connected: + raise IOError('Not connected') + + def maybe_write(): + if not self._transport: + return + if self._transport.is_closing(): + return + self._transport.write(raw_message_bytes) + NetworkThread.network_event_loop.call_soon_threadsafe(maybe_write) + + # Class utility methods + + def build_message(self, message): + """Build a serialized P2P message""" + msgtype = message.msgtype + data = message.serialize() + tmsg = self.magic_bytes + tmsg += msgtype + tmsg += b"\x00" * (12 - len(msgtype)) + tmsg += struct.pack("<I", len(data)) + th = sha256(data) + h = sha256(th) + tmsg += h[:4] + tmsg += data + return tmsg + + def _log_message(self, direction, msg): + """Logs a message being sent or received over the connection.""" + if direction == "send": + log_message = "Send message to " + elif direction == "receive": + log_message = "Received message from " + log_message += "%s:%d: %s" % (self.dstaddr, self.dstport, repr(msg)[:500]) + if len(log_message) > 500: + log_message += "... (msg truncated)" + logger.debug(log_message) + + +class P2PInterface(P2PConnection): + """A high-level P2P interface class for communicating with a Bitcoin node. + + This class provides high-level callbacks for processing P2P message + payloads, as well as convenience methods for interacting with the + node over P2P. + + Individual testcases should subclass this and override the on_* methods + if they want to alter message handling behaviour.""" + def __init__(self): + super().__init__() + + # Track number of messages of each type received. + # Should be read-only in a test. + self.message_count = defaultdict(int) + + # Track the most recent message of each type. + # To wait for a message to be received, pop that message from + # this and use wait_until. + self.last_message = {} + + # A count of the number of ping messages we've sent to the node + self.ping_counter = 1 + + # The network services received from the peer + self.nServices = 0 + + def peer_connect(self, *args, services=NODE_NETWORK|NODE_WITNESS, send_version=True, **kwargs): + create_conn = super().peer_connect(*args, **kwargs) + + if send_version: + # Send a version msg + vt = msg_version() + vt.nServices = services + vt.addrTo.ip = self.dstaddr + vt.addrTo.port = self.dstport + vt.addrFrom.ip = "0.0.0.0" + vt.addrFrom.port = 0 + self.on_connection_send_msg = vt # Will be sent soon after connection_made + + return create_conn + + # Message receiving methods + + def on_message(self, message): + """Receive message and dispatch message to appropriate callback. + + We keep a count of how many of each message type has been received + and the most recent message of each type.""" + with p2p_lock: + try: + msgtype = message.msgtype.decode('ascii') + self.message_count[msgtype] += 1 + self.last_message[msgtype] = message + getattr(self, 'on_' + msgtype)(message) + except: + print("ERROR delivering %s (%s)" % (repr(message), sys.exc_info()[0])) + raise + + # Callback methods. Can be overridden by subclasses in individual test + # cases to provide custom message handling behaviour. + + def on_open(self): + pass + + def on_close(self): + pass + + def on_addr(self, message): pass + def on_block(self, message): pass + def on_blocktxn(self, message): pass + def on_cfcheckpt(self, message): pass + def on_cfheaders(self, message): pass + def on_cfilter(self, message): pass + def on_cmpctblock(self, message): pass + def on_feefilter(self, message): pass + def on_filteradd(self, message): pass + def on_filterclear(self, message): pass + def on_filterload(self, message): pass + def on_getaddr(self, message): pass + def on_getblocks(self, message): pass + def on_getblocktxn(self, message): pass + def on_getdata(self, message): pass + def on_getheaders(self, message): pass + def on_headers(self, message): pass + def on_mempool(self, message): pass + def on_merkleblock(self, message): pass + def on_notfound(self, message): pass + def on_pong(self, message): pass + def on_sendcmpct(self, message): pass + def on_sendheaders(self, message): pass + def on_tx(self, message): pass + def on_wtxidrelay(self, message): pass + + def on_inv(self, message): + want = msg_getdata() + for i in message.inv: + if i.type != 0: + want.inv.append(i) + if len(want.inv): + self.send_message(want) + + def on_ping(self, message): + self.send_message(msg_pong(message.nonce)) + + def on_verack(self, message): + pass + + def on_version(self, message): + assert message.nVersion >= MIN_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format(message.nVersion, MIN_VERSION_SUPPORTED) + if message.nVersion >= 70016: + self.send_message(msg_wtxidrelay()) + self.send_message(msg_verack()) + self.nServices = message.nServices + + # Connection helper methods + + def wait_until(self, test_function_in, *, timeout=60, check_connected=True): + def test_function(): + if check_connected: + assert self.is_connected + return test_function_in() + + wait_until(test_function, timeout=timeout, lock=p2p_lock, timeout_factor=self.timeout_factor) + + def wait_for_disconnect(self, timeout=60): + test_function = lambda: not self.is_connected + self.wait_until(test_function, timeout=timeout, check_connected=False) + + # Message receiving helper methods + + def wait_for_tx(self, txid, timeout=60): + def test_function(): + if not self.last_message.get('tx'): + return False + return self.last_message['tx'].tx.rehash() == txid + + self.wait_until(test_function, timeout=timeout) + + def wait_for_block(self, blockhash, timeout=60): + def test_function(): + return self.last_message.get("block") and self.last_message["block"].block.rehash() == blockhash + + self.wait_until(test_function, timeout=timeout) + + def wait_for_header(self, blockhash, timeout=60): + def test_function(): + last_headers = self.last_message.get('headers') + if not last_headers: + return False + return last_headers.headers[0].rehash() == int(blockhash, 16) + + self.wait_until(test_function, timeout=timeout) + + def wait_for_merkleblock(self, blockhash, timeout=60): + def test_function(): + last_filtered_block = self.last_message.get('merkleblock') + if not last_filtered_block: + return False + return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) + + self.wait_until(test_function, timeout=timeout) + + def wait_for_getdata(self, hash_list, timeout=60): + """Waits for a getdata message. + + The object hashes in the inventory vector must match the provided hash_list.""" + def test_function(): + last_data = self.last_message.get("getdata") + if not last_data: + return False + return [x.hash for x in last_data.inv] == hash_list + + self.wait_until(test_function, timeout=timeout) + + def wait_for_getheaders(self, timeout=60): + """Waits for a getheaders message. + + Receiving any getheaders message will satisfy the predicate. the last_message["getheaders"] + value must be explicitly cleared before calling this method, or this will return + immediately with success. TODO: change this method to take a hash value and only + return true if the correct block header has been requested.""" + def test_function(): + return self.last_message.get("getheaders") + + self.wait_until(test_function, timeout=timeout) + + def wait_for_inv(self, expected_inv, timeout=60): + """Waits for an INV message and checks that the first inv object in the message was as expected.""" + if len(expected_inv) > 1: + raise NotImplementedError("wait_for_inv() will only verify the first inv object") + + def test_function(): + return self.last_message.get("inv") and \ + self.last_message["inv"].inv[0].type == expected_inv[0].type and \ + self.last_message["inv"].inv[0].hash == expected_inv[0].hash + + self.wait_until(test_function, timeout=timeout) + + def wait_for_verack(self, timeout=60): + def test_function(): + return "verack" in self.last_message + + self.wait_until(test_function, timeout=timeout) + + # Message sending helper functions + + def send_and_ping(self, message, timeout=60): + self.send_message(message) + self.sync_with_ping(timeout=timeout) + + # Sync up with the node + def sync_with_ping(self, timeout=60): + self.send_message(msg_ping(nonce=self.ping_counter)) + + def test_function(): + return self.last_message.get("pong") and self.last_message["pong"].nonce == self.ping_counter + + self.wait_until(test_function, timeout=timeout) + self.ping_counter += 1 + + +# One lock for synchronizing all data access between the network event loop (see +# NetworkThread below) and the thread running the test logic. For simplicity, +# P2PConnection acquires this lock whenever delivering a message to a P2PInterface. +# This lock should be acquired in the thread running the test logic to synchronize +# access to any data shared with the P2PInterface or P2PConnection. +p2p_lock = threading.Lock() + + +class NetworkThread(threading.Thread): + network_event_loop = None + + def __init__(self): + super().__init__(name="NetworkThread") + # There is only one event loop and no more than one thread must be created + assert not self.network_event_loop + + NetworkThread.network_event_loop = asyncio.new_event_loop() + + def run(self): + """Start the network thread.""" + self.network_event_loop.run_forever() + + def close(self, timeout=10): + """Close the connections and network event loop.""" + self.network_event_loop.call_soon_threadsafe(self.network_event_loop.stop) + wait_until(lambda: not self.network_event_loop.is_running(), timeout=timeout) + self.network_event_loop.close() + self.join(timeout) + # Safe to remove event loop. + NetworkThread.network_event_loop = None + +class P2PDataStore(P2PInterface): + """A P2P data store class. + + Keeps a block and transaction store and responds correctly to getdata and getheaders requests.""" + + def __init__(self): + super().__init__() + # store of blocks. key is block hash, value is a CBlock object + self.block_store = {} + self.last_block_hash = '' + # store of txs. key is txid, value is a CTransaction object + self.tx_store = {} + self.getdata_requests = [] + + def on_getdata(self, message): + """Check for the tx/block in our stores and if found, reply with an inv message.""" + for inv in message.inv: + self.getdata_requests.append(inv.hash) + if (inv.type & MSG_TYPE_MASK) == MSG_TX and inv.hash in self.tx_store.keys(): + self.send_message(msg_tx(self.tx_store[inv.hash])) + elif (inv.type & MSG_TYPE_MASK) == MSG_BLOCK and inv.hash in self.block_store.keys(): + self.send_message(msg_block(self.block_store[inv.hash])) + else: + logger.debug('getdata message type {} received.'.format(hex(inv.type))) + + def on_getheaders(self, message): + """Search back through our block store for the locator, and reply with a headers message if found.""" + + locator, hash_stop = message.locator, message.hashstop + + # Assume that the most recent block added is the tip + if not self.block_store: + return + + headers_list = [self.block_store[self.last_block_hash]] + while headers_list[-1].sha256 not in locator.vHave: + # Walk back through the block store, adding headers to headers_list + # as we go. + prev_block_hash = headers_list[-1].hashPrevBlock + if prev_block_hash in self.block_store: + prev_block_header = CBlockHeader(self.block_store[prev_block_hash]) + headers_list.append(prev_block_header) + if prev_block_header.sha256 == hash_stop: + # if this is the hashstop header, stop here + break + else: + logger.debug('block hash {} not found in block store'.format(hex(prev_block_hash))) + break + + # Truncate the list if there are too many headers + headers_list = headers_list[:-MAX_HEADERS_RESULTS - 1:-1] + response = msg_headers(headers_list) + + if response is not None: + self.send_message(response) + + def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, reject_reason=None, expect_disconnect=False, timeout=60): + """Send blocks to test node and test whether the tip advances. + + - add all blocks to our block_store + - send a headers message for the final block + - the on_getheaders handler will ensure that any getheaders are responded to + - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will + ensure that any getdata messages are responded to. Otherwise send the full block unsolicited. + - if success is True: assert that the node's tip advances to the most recent block + - if success is False: assert that the node's tip doesn't advance + - if reject_reason is set: assert that the correct reject message is logged""" + + with p2p_lock: + for block in blocks: + self.block_store[block.sha256] = block + self.last_block_hash = block.sha256 + + reject_reason = [reject_reason] if reject_reason else [] + with node.assert_debug_log(expected_msgs=reject_reason): + if force_send: + for b in blocks: + self.send_message(msg_block(block=b)) + else: + self.send_message(msg_headers([CBlockHeader(block) for block in blocks])) + self.wait_until( + lambda: blocks[-1].sha256 in self.getdata_requests, + timeout=timeout, + check_connected=success, + ) + + if expect_disconnect: + self.wait_for_disconnect(timeout=timeout) + else: + self.sync_with_ping(timeout=timeout) + + if success: + self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout) + else: + assert node.getbestblockhash() != blocks[-1].hash + + def send_txs_and_test(self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None): + """Send txs to test node and test whether they're accepted to the mempool. + + - add all txs to our tx_store + - send tx messages for all txs + - if success is True/False: assert that the txs are/are not accepted to the mempool + - if expect_disconnect is True: Skip the sync with ping + - if reject_reason is set: assert that the correct reject message is logged.""" + + with p2p_lock: + for tx in txs: + self.tx_store[tx.sha256] = tx + + reject_reason = [reject_reason] if reject_reason else [] + with node.assert_debug_log(expected_msgs=reject_reason): + for tx in txs: + self.send_message(msg_tx(tx)) + + if expect_disconnect: + self.wait_for_disconnect() + else: + self.sync_with_ping() + + raw_mempool = node.getrawmempool() + if success: + # Check that all txs are now in the mempool + for tx in txs: + assert tx.hash in raw_mempool, "{} not found in mempool".format(tx.hash) + else: + # Check that none of the txs are now in the mempool + for tx in txs: + assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash) + +class P2PTxInvStore(P2PInterface): + """A P2PInterface which stores a count of how many times each txid has been announced.""" + def __init__(self): + super().__init__() + self.tx_invs_received = defaultdict(int) + + def on_inv(self, message): + super().on_inv(message) # Send getdata in response. + # Store how many times invs have been received for each tx. + for i in message.inv: + if (i.type == MSG_TX) or (i.type == MSG_WTX): + # save txid + self.tx_invs_received[i.hash] += 1 + + def get_invs(self): + with p2p_lock: + return list(self.tx_invs_received.keys()) + + def wait_for_broadcast(self, txns, timeout=60): + """Waits for the txns (list of txids) to complete initial broadcast. + The mempool should mark unbroadcast=False for these transactions. + """ + # Wait until invs have been received (and getdatas sent) for each txid. + self.wait_until(lambda: set(self.tx_invs_received.keys()) == set([int(tx, 16) for tx in txns]), timeout=timeout) + # Flush messages and wait for the getdatas to be processed + self.sync_with_ping() |