aboutsummaryrefslogtreecommitdiff
path: root/test/functional/test_framework
diff options
context:
space:
mode:
authorJohn Newbery <john@johnnewbery.com>2017-12-08 10:50:24 -0500
committerJohn Newbery <john@johnnewbery.com>2017-12-11 09:16:44 -0500
commit5fc6e71d1994d58c25edebd8063555998752349a (patch)
tree8cee595bacfc27036769d25258b191e1e0fcdbbb /test/functional/test_framework
parentf60b4ad57912b78a96af08046a503f7905610a8c (diff)
[tests] Add network_thread_ utility functions.
Add network thread_start(), network_thread_running() and network_thread_join() utility functions in mininode.py and use network_thread_running() in network thread assertions.
Diffstat (limited to 'test/functional/test_framework')
-rwxr-xr-xtest/functional/test_framework/mininode.py27
1 files changed, 24 insertions, 3 deletions
diff --git a/test/functional/test_framework/mininode.py b/test/functional/test_framework/mininode.py
index 9e92a70da1..a00fc3d43c 100755
--- a/test/functional/test_framework/mininode.py
+++ b/test/functional/test_framework/mininode.py
@@ -18,7 +18,7 @@ import logging
import socket
import struct
import sys
-from threading import RLock, Thread
+import threading
from test_framework.messages import *
from test_framework.util import wait_until
@@ -397,9 +397,12 @@ mininode_socket_map = dict()
# and whenever adding anything to the send buffer (in send_message()). This
# lock should be acquired in the thread running the test logic to synchronize
# access to any data shared with the P2PInterface or P2PConnection.
-mininode_lock = RLock()
+mininode_lock = threading.RLock()
+
+class NetworkThread(threading.Thread):
+ def __init__(self):
+ super().__init__(name="NetworkThread")
-class NetworkThread(Thread):
def run(self):
while mininode_socket_map:
# We check for whether to disconnect outside of the asyncore
@@ -412,3 +415,21 @@ class NetworkThread(Thread):
[obj.handle_close() for obj in disconnected]
asyncore.loop(0.1, use_poll=True, map=mininode_socket_map, count=1)
logger.debug("Network thread closing")
+
+def network_thread_start():
+ """Start the network thread."""
+ NetworkThread().start()
+
+def network_thread_running():
+ """Return whether the network thread is running."""
+ return any([thread.name == "NetworkThread" for thread in threading.enumerate()])
+
+def network_thread_join(timeout=10):
+ """Wait timeout seconds for the network thread to terminate.
+
+ Throw if the network thread doesn't terminate in timeout seconds."""
+ network_threads = [thread for thread in threading.enumerate() if thread.name == "NetworkThread"]
+ assert len(network_threads) <= 1
+ for thread in network_threads:
+ thread.join(timeout)
+ assert not thread.is_alive()