aboutsummaryrefslogtreecommitdiff
path: root/test/functional/test_framework
diff options
context:
space:
mode:
authorAmiti Uttarwar <amiti@uttarwar.org>2020-01-30 18:52:25 -0800
committerAmiti Uttarwar <amiti@uttarwar.org>2020-04-23 14:42:25 -0700
commit6851502472d3625416f0e7796e9f2a0379d14d49 (patch)
tree9dd95ba8351e778f1bf03aa3eb4b3629a098db40 /test/functional/test_framework
parentdc1da48dc5e5526215561311c184a8cbc345ecdc (diff)
[refactor/test] Extract P2PTxInvStore into test framework
Diffstat (limited to 'test/functional/test_framework')
-rwxr-xr-xtest/functional/test_framework/mininode.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py
index a9e669fea9..a76a25d128 100755
--- a/test/functional/test_framework/mininode.py
+++ b/test/functional/test_framework/mininode.py
@@ -12,7 +12,10 @@ found in the mini-node branch of http://github.com/jgarzik/pynode.
P2PConnection: A low-level connection object to a node's P2P interface
P2PInterface: A high-level interface object for communicating to a node over P2P
P2PDataStore: A p2p interface class that keeps a store of transactions and blocks
- and can respond correctly to getdata and getheaders messages"""
+ and can respond correctly to getdata and getheaders messages
+P2PTxInvStore: A p2p interface class that inherits from P2PDataStore, and keeps
+ a count of how many times each txid has been announced."""
+
import asyncio
from collections import defaultdict
from io import BytesIO
@@ -606,3 +609,20 @@ class P2PDataStore(P2PInterface):
# Check that none of the txs are now in the mempool
for tx in txs:
assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash)
+
+class P2PTxInvStore(P2PInterface):
+ """A P2PInterface which stores a count of how many times each txid has been announced."""
+ def __init__(self):
+ super().__init__()
+ self.tx_invs_received = defaultdict(int)
+
+ def on_inv(self, message):
+ # Store how many times invs have been received for each tx.
+ for i in message.inv:
+ if i.type == MSG_TX:
+ # save txid
+ self.tx_invs_received[i.hash] += 1
+
+ def get_invs(self):
+ with mininode_lock:
+ return list(self.tx_invs_received.keys())