diff options
Diffstat (limited to 'test/functional/test_framework/mininode.py')
-rwxr-xr-x | test/functional/test_framework/mininode.py | 34 |
1 files changed, 31 insertions, 3 deletions
diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py index 9e92a70da1..724d418099 100755 --- a/test/functional/test_framework/mininode.py +++ b/test/functional/test_framework/mininode.py @@ -18,7 +18,7 @@ import logging import socket import struct import sys -from threading import RLock, Thread +import threading from test_framework.messages import * from test_framework.util import wait_until @@ -69,6 +69,10 @@ class P2PConnection(asyncore.dispatcher): sub-classed and the on_message() callback overridden.""" def __init__(self): + # All P2PConnections must be created before starting the NetworkThread. + # assert that the network thread is not running. + assert not network_thread_running() + super().__init__(map=mininode_socket_map) def peer_connect(self, dstaddr, dstport, net="regtest"): @@ -397,9 +401,12 @@ mininode_socket_map = dict() # and whenever adding anything to the send buffer (in send_message()). This # lock should be acquired in the thread running the test logic to synchronize # access to any data shared with the P2PInterface or P2PConnection. -mininode_lock = RLock() +mininode_lock = threading.RLock() + +class NetworkThread(threading.Thread): + def __init__(self): + super().__init__(name="NetworkThread") -class NetworkThread(Thread): def run(self): while mininode_socket_map: # We check for whether to disconnect outside of the asyncore @@ -412,3 +419,24 @@ class NetworkThread(Thread): [obj.handle_close() for obj in disconnected] asyncore.loop(0.1, use_poll=True, map=mininode_socket_map, count=1) logger.debug("Network thread closing") + +def network_thread_start(): + """Start the network thread.""" + # Only one network thread may run at a time + assert not network_thread_running() + + NetworkThread().start() + +def network_thread_running(): + """Return whether the network thread is running.""" + return any([thread.name == "NetworkThread" for thread in threading.enumerate()]) + +def network_thread_join(timeout=10): + """Wait timeout seconds for the network thread to terminate. + + Throw if the network thread doesn't terminate in timeout seconds.""" + network_threads = [thread for thread in threading.enumerate() if thread.name == "NetworkThread"] + assert len(network_threads) <= 1 + for thread in network_threads: + thread.join(timeout) + assert not thread.is_alive() |