diff options
author | John Snow <jsnow@redhat.com> | 2021-09-15 12:29:37 -0400 |
---|---|---|
committer | John Snow <jsnow@redhat.com> | 2021-09-27 12:10:29 -0400 |
commit | 774c64a58d45da54a344947e7ed26814db04cc68 (patch) | |
tree | 6a4cf85c59ac9d6419ab3133813f9f5e0dad0fbd /python | |
parent | 50e533061f30e69d618643c9513b6797019023d1 (diff) |
python/aqmp: add AsyncProtocol.accept() method
It's a little messier than connect, because it wasn't designed to accept
*precisely one* connection. Such is life.
Signed-off-by: John Snow <jsnow@redhat.com>
Reviewed-by: Eric Blake <eblake@redhat.com>
Message-id: 20210915162955.333025-10-jsnow@redhat.com
Signed-off-by: John Snow <jsnow@redhat.com>
Diffstat (limited to 'python')
-rw-r--r-- | python/qemu/aqmp/protocol.py | 89 |
1 files changed, 85 insertions, 4 deletions
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py index 1dfd12895d..62c26ede5a 100644 --- a/python/qemu/aqmp/protocol.py +++ b/python/qemu/aqmp/protocol.py @@ -245,6 +245,24 @@ class AsyncProtocol(Generic[T]): @upper_half @require(Runstate.IDLE) + async def accept(self, address: Union[str, Tuple[str, int]], + ssl: Optional[SSLContext] = None) -> None: + """ + Accept a connection and begin processing message queues. + + If this call fails, `runstate` is guaranteed to be set back to `IDLE`. + + :param address: + Address to listen to; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise StateError: When the `Runstate` is not `IDLE`. + :raise ConnectError: If a connection could not be accepted. + """ + await self._new_session(address, ssl, accept=True) + + @upper_half + @require(Runstate.IDLE) async def connect(self, address: Union[str, Tuple[str, int]], ssl: Optional[SSLContext] = None) -> None: """ @@ -308,7 +326,8 @@ class AsyncProtocol(Generic[T]): @upper_half async def _new_session(self, address: Union[str, Tuple[str, int]], - ssl: Optional[SSLContext] = None) -> None: + ssl: Optional[SSLContext] = None, + accept: bool = False) -> None: """ Establish a new connection and initialize the session. @@ -317,9 +336,10 @@ class AsyncProtocol(Generic[T]): to be set back to `IDLE`. :param address: - Address to connect to; + Address to connect to/listen on; UNIX socket path or TCP address/port. :param ssl: SSL context to use, if any. + :param accept: Accept a connection instead of connecting when `True`. :raise ConnectError: When a connection or session cannot be established. @@ -333,7 +353,7 @@ class AsyncProtocol(Generic[T]): try: phase = "connection" - await self._establish_connection(address, ssl) + await self._establish_connection(address, ssl, accept) phase = "session" await self._establish_session() @@ -367,6 +387,7 @@ class AsyncProtocol(Generic[T]): self, address: Union[str, Tuple[str, int]], ssl: Optional[SSLContext] = None, + accept: bool = False ) -> None: """ Establish a new connection. @@ -375,6 +396,7 @@ class AsyncProtocol(Generic[T]): Address to connect to/listen on; UNIX socket path or TCP address/port. :param ssl: SSL context to use, if any. + :param accept: Accept a connection instead of connecting when `True`. """ assert self.runstate == Runstate.IDLE self._set_state(Runstate.CONNECTING) @@ -384,7 +406,66 @@ class AsyncProtocol(Generic[T]): # otherwise yield. await asyncio.sleep(0) - await self._do_connect(address, ssl) + if accept: + await self._do_accept(address, ssl) + else: + await self._do_connect(address, ssl) + + @upper_half + async def _do_accept(self, address: Union[str, Tuple[str, int]], + ssl: Optional[SSLContext] = None) -> None: + """ + Acting as the transport server, accept a single connection. + + :param address: + Address to listen on; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise OSError: For stream-related errors. + """ + self.logger.debug("Awaiting connection on %s ...", address) + connected = asyncio.Event() + server: Optional[asyncio.AbstractServer] = None + + async def _client_connected_cb(reader: asyncio.StreamReader, + writer: asyncio.StreamWriter) -> None: + """Used to accept a single incoming connection, see below.""" + nonlocal server + nonlocal connected + + # A connection has been accepted; stop listening for new ones. + assert server is not None + server.close() + await server.wait_closed() + server = None + + # Register this client as being connected + self._reader, self._writer = (reader, writer) + + # Signal back: We've accepted a client! + connected.set() + + if isinstance(address, tuple): + coro = asyncio.start_server( + _client_connected_cb, + host=address[0], + port=address[1], + ssl=ssl, + backlog=1, + ) + else: + coro = asyncio.start_unix_server( + _client_connected_cb, + path=address, + ssl=ssl, + backlog=1, + ) + + server = await coro # Starts listening + await connected.wait() # Waits for the callback to fire (and finish) + assert server is None + + self.logger.debug("Connection accepted.") @upper_half async def _do_connect(self, address: Union[str, Tuple[str, int]], |