diff options
Diffstat (limited to 'python/tests/protocol.py')
-rw-r--r-- | python/tests/protocol.py | 45 |
1 files changed, 29 insertions, 16 deletions
diff --git a/python/tests/protocol.py b/python/tests/protocol.py index 5cd7938be3..d6849ad306 100644 --- a/python/tests/protocol.py +++ b/python/tests/protocol.py @@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]): self.trigger_input = asyncio.Event() await super()._establish_session() - async def _do_accept(self, address, ssl=None): - if not self.fake_session: - await super()._do_accept(address, ssl) + async def _do_start_server(self, address, ssl=None): + if self.fake_session: + self._accepted = asyncio.Event() + self._set_state(Runstate.CONNECTING) + await asyncio.sleep(0) + else: + await super()._do_start_server(address, ssl) + + async def _do_accept(self): + if self.fake_session: + self._accepted = None + else: + await super()._do_accept() async def _do_connect(self, address, ssl=None): - if not self.fake_session: + if self.fake_session: + self._set_state(Runstate.CONNECTING) + await asyncio.sleep(0) + else: await super()._do_connect(address, ssl) async def _do_recv(self) -> None: @@ -413,14 +426,14 @@ class Accept(Connect): assert family in ('INET', 'UNIX') if family == 'INET': - await self.proto.accept(('example.com', 1)) + await self.proto.start_server_and_accept(('example.com', 1)) elif family == 'UNIX': - await self.proto.accept('/dev/null') + await self.proto.start_server_and_accept('/dev/null') async def _hanging_connection(self): with TemporaryDirectory(suffix='.aqmp') as tmpdir: sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") - await self.proto.accept(sock) + await self.proto.start_server_and_accept(sock) class FakeSession(TestBase): @@ -449,13 +462,13 @@ class FakeSession(TestBase): @TestBase.async_test async def testFakeAccept(self): """Test the full state lifecycle (via accept) with a no-op session.""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') self.assertEqual(self.proto.runstate, Runstate.RUNNING) @TestBase.async_test async def testFakeRecv(self): """Test receiving a fake/null message.""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') logname = self.proto.logger.name with self.assertLogs(logname, level='DEBUG') as context: @@ -471,7 +484,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testFakeSend(self): """Test sending a fake/null message.""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') logname = self.proto.logger.name with self.assertLogs(logname, level='DEBUG') as context: @@ -493,7 +506,7 @@ class FakeSession(TestBase): ): with self.assertRaises(StateError) as context: if accept: - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') else: await self.proto.connect('/not/a/real/path') @@ -504,7 +517,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testAcceptRequireRunning(self): """Test that accept() cannot be called when Runstate=RUNNING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') await self._prod_session_api( Runstate.RUNNING, @@ -515,7 +528,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testConnectRequireRunning(self): """Test that connect() cannot be called when Runstate=RUNNING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') await self._prod_session_api( Runstate.RUNNING, @@ -526,7 +539,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testAcceptRequireDisconnecting(self): """Test that accept() cannot be called when Runstate=DISCONNECTING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') # Cheat: force a disconnect. await self.proto.simulate_disconnect() @@ -541,7 +554,7 @@ class FakeSession(TestBase): @TestBase.async_test async def testConnectRequireDisconnecting(self): """Test that connect() cannot be called when Runstate=DISCONNECTING""" - await self.proto.accept('/not/a/real/path') + await self.proto.start_server_and_accept('/not/a/real/path') # Cheat: force a disconnect. await self.proto.simulate_disconnect() @@ -576,7 +589,7 @@ class SimpleSession(TestBase): async def testSmoke(self): with TemporaryDirectory(suffix='.aqmp') as tmpdir: sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") - server_task = create_task(self.server.accept(sock)) + server_task = create_task(self.server.start_server_and_accept(sock)) # give the server a chance to start listening [...] await asyncio.sleep(0) |