aboutsummaryrefslogtreecommitdiff
path: root/test/functional/test_framework/mininode.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/functional/test_framework/mininode.py')
-rwxr-xr-xtest/functional/test_framework/mininode.py111
1 files changed, 78 insertions, 33 deletions
diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py
index a9e669fea9..31cec66ee7 100755
--- a/test/functional/test_framework/mininode.py
+++ b/test/functional/test_framework/mininode.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) 2010 ArtForz -- public domain half-a-node
# Copyright (c) 2012 Jeff Garzik
-# Copyright (c) 2010-2019 The Bitcoin Core developers
+# Copyright (c) 2010-2020 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Bitcoin P2P network half-a-node.
@@ -12,7 +12,10 @@ found in the mini-node branch of http://github.com/jgarzik/pynode.
P2PConnection: A low-level connection object to a node's P2P interface
P2PInterface: A high-level interface object for communicating to a node over P2P
P2PDataStore: A p2p interface class that keeps a store of transactions and blocks
- and can respond correctly to getdata and getheaders messages"""
+ and can respond correctly to getdata and getheaders messages
+P2PTxInvStore: A p2p interface class that inherits from P2PDataStore, and keeps
+ a count of how many times each txid has been announced."""
+
import asyncio
from collections import defaultdict
from io import BytesIO
@@ -30,6 +33,9 @@ from test_framework.messages import (
msg_blocktxn,
msg_cmpctblock,
msg_feefilter,
+ msg_filteradd,
+ msg_filterclear,
+ msg_filterload,
msg_getaddr,
msg_getblocks,
msg_getblocktxn,
@@ -38,6 +44,7 @@ from test_framework.messages import (
msg_headers,
msg_inv,
msg_mempool,
+ msg_merkleblock,
msg_notfound,
msg_ping,
msg_pong,
@@ -62,6 +69,9 @@ MESSAGEMAP = {
b"blocktxn": msg_blocktxn,
b"cmpctblock": msg_cmpctblock,
b"feefilter": msg_feefilter,
+ b"filteradd": msg_filteradd,
+ b"filterclear": msg_filterclear,
+ b"filterload": msg_filterload,
b"getaddr": msg_getaddr,
b"getblocks": msg_getblocks,
b"getblocktxn": msg_getblocktxn,
@@ -70,6 +80,7 @@ MESSAGEMAP = {
b"headers": msg_headers,
b"inv": msg_inv,
b"mempool": msg_mempool,
+ b"merkleblock": msg_merkleblock,
b"notfound": msg_notfound,
b"ping": msg_ping,
b"pong": msg_pong,
@@ -109,8 +120,9 @@ class P2PConnection(asyncio.Protocol):
def is_connected(self):
return self._transport is not None
- def peer_connect(self, dstaddr, dstport, *, net):
+ def peer_connect(self, dstaddr, dstport, *, net, factor):
assert not self.is_connected
+ self.factor = factor
self.dstaddr = dstaddr
self.dstport = dstport
# The initial message to send after the connection was made:
@@ -172,7 +184,7 @@ class P2PConnection(asyncio.Protocol):
raise ValueError("magic bytes mismatch: {} != {}".format(repr(self.magic_bytes), repr(self.recvbuf)))
if len(self.recvbuf) < 4 + 12 + 4 + 4:
return
- command = self.recvbuf[4:4+12].split(b"\x00", 1)[0]
+ msgtype = self.recvbuf[4:4+12].split(b"\x00", 1)[0]
msglen = struct.unpack("<i", self.recvbuf[4+12:4+12+4])[0]
checksum = self.recvbuf[4+12+4:4+12+4+4]
if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen:
@@ -183,10 +195,10 @@ class P2PConnection(asyncio.Protocol):
if checksum != h[:4]:
raise ValueError("got bad checksum " + repr(self.recvbuf))
self.recvbuf = self.recvbuf[4+12+4+4+msglen:]
- if command not in MESSAGEMAP:
- raise ValueError("Received unknown command from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, command, repr(msg)))
+ if msgtype not in MESSAGEMAP:
+ raise ValueError("Received unknown msgtype from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, msgtype, repr(msg)))
f = BytesIO(msg)
- t = MESSAGEMAP[command]()
+ t = MESSAGEMAP[msgtype]()
t.deserialize(f)
self._log_message("receive", t)
self.on_message(t)
@@ -225,11 +237,11 @@ class P2PConnection(asyncio.Protocol):
def build_message(self, message):
"""Build a serialized P2P message"""
- command = message.command
+ msgtype = message.msgtype
data = message.serialize()
tmsg = self.magic_bytes
- tmsg += command
- tmsg += b"\x00" * (12 - len(command))
+ tmsg += msgtype
+ tmsg += b"\x00" * (12 - len(msgtype))
tmsg += struct.pack("<I", len(data))
th = sha256(data)
h = sha256(th)
@@ -296,10 +308,10 @@ class P2PInterface(P2PConnection):
and the most recent message of each type."""
with mininode_lock:
try:
- command = message.command.decode('ascii')
- self.message_count[command] += 1
- self.last_message[command] = message
- getattr(self, 'on_' + command)(message)
+ msgtype = message.msgtype.decode('ascii')
+ self.message_count[msgtype] += 1
+ self.last_message[msgtype] = message
+ getattr(self, 'on_' + msgtype)(message)
except:
print("ERROR delivering %s (%s)" % (repr(message), sys.exc_info()[0]))
raise
@@ -318,6 +330,9 @@ class P2PInterface(P2PConnection):
def on_blocktxn(self, message): pass
def on_cmpctblock(self, message): pass
def on_feefilter(self, message): pass
+ def on_filteradd(self, message): pass
+ def on_filterclear(self, message): pass
+ def on_filterload(self, message): pass
def on_getaddr(self, message): pass
def on_getblocks(self, message): pass
def on_getblocktxn(self, message): pass
@@ -325,9 +340,9 @@ class P2PInterface(P2PConnection):
def on_getheaders(self, message): pass
def on_headers(self, message): pass
def on_mempool(self, message): pass
+ def on_merkleblock(self, message): pass
def on_notfound(self, message): pass
def on_pong(self, message): pass
- def on_reject(self, message): pass
def on_sendcmpct(self, message): pass
def on_sendheaders(self, message): pass
def on_tx(self, message): pass
@@ -353,9 +368,12 @@ class P2PInterface(P2PConnection):
# Connection helper methods
+ def wait_until(self, test_function, timeout):
+ wait_until(test_function, timeout=timeout, lock=mininode_lock, factor=self.factor)
+
def wait_for_disconnect(self, timeout=60):
test_function = lambda: not self.is_connected
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
# Message receiving helper methods
@@ -366,14 +384,14 @@ class P2PInterface(P2PConnection):
return False
return self.last_message['tx'].tx.rehash() == txid
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
def wait_for_block(self, blockhash, timeout=60):
def test_function():
assert self.is_connected
return self.last_message.get("block") and self.last_message["block"].block.rehash() == blockhash
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
def wait_for_header(self, blockhash, timeout=60):
def test_function():
@@ -381,23 +399,33 @@ class P2PInterface(P2PConnection):
last_headers = self.last_message.get('headers')
if not last_headers:
return False
- return last_headers.headers[0].rehash() == blockhash
+ return last_headers.headers[0].rehash() == int(blockhash, 16)
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
- def wait_for_getdata(self, timeout=60):
+ def wait_for_merkleblock(self, blockhash, timeout=60):
+ def test_function():
+ assert self.is_connected
+ last_filtered_block = self.last_message.get('merkleblock')
+ if not last_filtered_block:
+ return False
+ return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16)
+
+ self.wait_until(test_function, timeout=timeout)
+
+ def wait_for_getdata(self, hash_list, timeout=60):
"""Waits for a getdata message.
- Receiving any getdata message will satisfy the predicate. the last_message["getdata"]
- value must be explicitly cleared before calling this method, or this will return
- immediately with success. TODO: change this method to take a hash value and only
- return true if the correct block/tx has been requested."""
+ The object hashes in the inventory vector must match the provided hash_list."""
def test_function():
assert self.is_connected
- return self.last_message.get("getdata")
+ last_data = self.last_message.get("getdata")
+ if not last_data:
+ return False
+ return [x.hash for x in last_data.inv] == hash_list
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
def wait_for_getheaders(self, timeout=60):
"""Waits for a getheaders message.
@@ -411,7 +439,7 @@ class P2PInterface(P2PConnection):
assert self.is_connected
return self.last_message.get("getheaders")
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
def wait_for_inv(self, expected_inv, timeout=60):
"""Waits for an INV message and checks that the first inv object in the message was as expected."""
@@ -424,13 +452,13 @@ class P2PInterface(P2PConnection):
self.last_message["inv"].inv[0].type == expected_inv[0].type and \
self.last_message["inv"].inv[0].hash == expected_inv[0].hash
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
def wait_for_verack(self, timeout=60):
def test_function():
return self.message_count["verack"]
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
# Message sending helper functions
@@ -446,7 +474,7 @@ class P2PInterface(P2PConnection):
assert self.is_connected
return self.last_message.get("pong") and self.last_message["pong"].nonce == self.ping_counter
- wait_until(test_function, timeout=timeout, lock=mininode_lock)
+ self.wait_until(test_function, timeout=timeout)
self.ping_counter += 1
@@ -562,7 +590,7 @@ class P2PDataStore(P2PInterface):
self.send_message(msg_block(block=b))
else:
self.send_message(msg_headers([CBlockHeader(block) for block in blocks]))
- wait_until(lambda: blocks[-1].sha256 in self.getdata_requests, timeout=timeout, lock=mininode_lock)
+ self.wait_until(lambda: blocks[-1].sha256 in self.getdata_requests, timeout=timeout)
if expect_disconnect:
self.wait_for_disconnect(timeout=timeout)
@@ -570,7 +598,7 @@ class P2PDataStore(P2PInterface):
self.sync_with_ping(timeout=timeout)
if success:
- wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout)
+ self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout)
else:
assert node.getbestblockhash() != blocks[-1].hash
@@ -606,3 +634,20 @@ class P2PDataStore(P2PInterface):
# Check that none of the txs are now in the mempool
for tx in txs:
assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash)
+
+class P2PTxInvStore(P2PInterface):
+ """A P2PInterface which stores a count of how many times each txid has been announced."""
+ def __init__(self):
+ super().__init__()
+ self.tx_invs_received = defaultdict(int)
+
+ def on_inv(self, message):
+ # Store how many times invs have been received for each tx.
+ for i in message.inv:
+ if i.type == MSG_TX:
+ # save txid
+ self.tx_invs_received[i.hash] += 1
+
+ def get_invs(self):
+ with mininode_lock:
+ return list(self.tx_invs_received.keys())