aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/qemu/aqmp/legacy.py7
-rw-r--r--python/qemu/aqmp/protocol.py381
-rw-r--r--python/tests/protocol.py45
-rwxr-xr-xscripts/qmp/qmp-shell-wrap2
4 files changed, 268 insertions, 167 deletions
diff --git a/python/qemu/aqmp/legacy.py b/python/qemu/aqmp/legacy.py
index 6baa5f3409..46026e9fdc 100644
--- a/python/qemu/aqmp/legacy.py
+++ b/python/qemu/aqmp/legacy.py
@@ -57,7 +57,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
self._timeout: Optional[float] = None
if server:
- self._aqmp._bind_hack(address) # pylint: disable=protected-access
+ self._sync(self._aqmp.start_server(self._address))
_T = TypeVar('_T')
@@ -90,10 +90,7 @@ class QEMUMonitorProtocol(qemu.qmp.QEMUMonitorProtocol):
self._aqmp.await_greeting = True
self._aqmp.negotiate = True
- self._sync(
- self._aqmp.accept(self._address),
- timeout
- )
+ self._sync(self._aqmp.accept(), timeout)
ret = self._get_greeting()
assert ret is not None
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py
index 33358f5cd7..36fae57f27 100644
--- a/python/qemu/aqmp/protocol.py
+++ b/python/qemu/aqmp/protocol.py
@@ -10,12 +10,14 @@ In this package, it is used as the implementation for the `QMPClient`
class.
"""
+# It's all the docstrings ... ! It's long for a good reason ^_^;
+# pylint: disable=too-many-lines
+
import asyncio
from asyncio import StreamReader, StreamWriter
from enum import Enum
from functools import wraps
import logging
-import socket
from ssl import SSLContext
from typing import (
Any,
@@ -239,8 +241,9 @@ class AsyncProtocol(Generic[T]):
self._runstate = Runstate.IDLE
self._runstate_changed: Optional[asyncio.Event] = None
- # Workaround for bind()
- self._sock: Optional[socket.socket] = None
+ # Server state for start_server() and _incoming()
+ self._server: Optional[asyncio.AbstractServer] = None
+ self._accepted: Optional[asyncio.Event] = None
def __repr__(self) -> str:
cls_name = type(self).__name__
@@ -265,21 +268,90 @@ class AsyncProtocol(Generic[T]):
@upper_half
@require(Runstate.IDLE)
- async def accept(self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None) -> None:
+ async def start_server_and_accept(
+ self, address: SocketAddrT,
+ ssl: Optional[SSLContext] = None
+ ) -> None:
"""
Accept a connection and begin processing message queues.
If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
+ This method is precisely equivalent to calling `start_server()`
+ followed by `accept()`.
+
+ :param address:
+ Address to listen on; UNIX socket path or TCP address/port.
+ :param ssl: SSL context to use, if any.
+
+ :raise StateError: When the `Runstate` is not `IDLE`.
+ :raise ConnectError:
+ When a connection or session cannot be established.
+
+ This exception will wrap a more concrete one. In most cases,
+ the wrapped exception will be `OSError` or `EOFError`. If a
+ protocol-level failure occurs while establishing a new
+ session, the wrapped error may also be an `QMPError`.
+ """
+ await self.start_server(address, ssl)
+ await self.accept()
+ assert self.runstate == Runstate.RUNNING
+
+ @upper_half
+ @require(Runstate.IDLE)
+ async def start_server(self, address: SocketAddrT,
+ ssl: Optional[SSLContext] = None) -> None:
+ """
+ Start listening for an incoming connection, but do not wait for a peer.
+
+ This method starts listening for an incoming connection, but
+ does not block waiting for a peer. This call will return
+ immediately after binding and listening on a socket. A later
+ call to `accept()` must be made in order to finalize the
+ incoming connection.
:param address:
- Address to listen to; UNIX socket path or TCP address/port.
+ Address to listen on; UNIX socket path or TCP address/port.
:param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`.
- :raise ConnectError: If a connection could not be accepted.
+ :raise ConnectError:
+ When the server could not start listening on this address.
+
+ This exception will wrap a more concrete one. In most cases,
+ the wrapped exception will be `OSError`.
+ """
+ await self._session_guard(
+ self._do_start_server(address, ssl),
+ 'Failed to establish connection')
+ assert self.runstate == Runstate.CONNECTING
+
+ @upper_half
+ @require(Runstate.CONNECTING)
+ async def accept(self) -> None:
+ """
+ Accept an incoming connection and begin processing message queues.
+
+ If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
+
+ :raise StateError: When the `Runstate` is not `CONNECTING`.
+ :raise QMPError: When `start_server()` was not called yet.
+ :raise ConnectError:
+ When a connection or session cannot be established.
+
+ This exception will wrap a more concrete one. In most cases,
+ the wrapped exception will be `OSError` or `EOFError`. If a
+ protocol-level failure occurs while establishing a new
+ session, the wrapped error may also be an `QMPError`.
"""
- await self._new_session(address, ssl, accept=True)
+ if self._accepted is None:
+ raise QMPError("Cannot call accept() before start_server().")
+ await self._session_guard(
+ self._do_accept(),
+ 'Failed to establish connection')
+ await self._session_guard(
+ self._establish_session(),
+ 'Failed to establish session')
+ assert self.runstate == Runstate.RUNNING
@upper_half
@require(Runstate.IDLE)
@@ -295,9 +367,21 @@ class AsyncProtocol(Generic[T]):
:param ssl: SSL context to use, if any.
:raise StateError: When the `Runstate` is not `IDLE`.
- :raise ConnectError: If a connection cannot be made to the server.
+ :raise ConnectError:
+ When a connection or session cannot be established.
+
+ This exception will wrap a more concrete one. In most cases,
+ the wrapped exception will be `OSError` or `EOFError`. If a
+ protocol-level failure occurs while establishing a new
+ session, the wrapped error may also be an `QMPError`.
"""
- await self._new_session(address, ssl)
+ await self._session_guard(
+ self._do_connect(address, ssl),
+ 'Failed to establish connection')
+ await self._session_guard(
+ self._establish_session(),
+ 'Failed to establish session')
+ assert self.runstate == Runstate.RUNNING
@upper_half
async def disconnect(self) -> None:
@@ -317,153 +401,146 @@ class AsyncProtocol(Generic[T]):
# Section: Session machinery
# --------------------------
- @property
- def _runstate_event(self) -> asyncio.Event:
- # asyncio.Event() objects should not be created prior to entrance into
- # an event loop, so we can ensure we create it in the correct context.
- # Create it on-demand *only* at the behest of an 'async def' method.
- if not self._runstate_changed:
- self._runstate_changed = asyncio.Event()
- return self._runstate_changed
-
- @upper_half
- @bottom_half
- def _set_state(self, state: Runstate) -> None:
- """
- Change the `Runstate` of the protocol connection.
-
- Signals the `runstate_changed` event.
- """
- if state == self._runstate:
- return
-
- self.logger.debug("Transitioning from '%s' to '%s'.",
- str(self._runstate), str(state))
- self._runstate = state
- self._runstate_event.set()
- self._runstate_event.clear()
-
- @upper_half
- async def _new_session(self,
- address: SocketAddrT,
- ssl: Optional[SSLContext] = None,
- accept: bool = False) -> None:
+ async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
"""
- Establish a new connection and initialize the session.
+ Async guard function used to roll back to `IDLE` on any error.
- Connect or accept a new connection, then begin the protocol
- session machinery. If this call fails, `runstate` is guaranteed
- to be set back to `IDLE`.
+ On any Exception, the state machine will be reset back to
+ `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
+ `BaseException` events will be left alone (This includes
+ asyncio.CancelledError, even prior to Python 3.8).
- :param address:
- Address to connect to/listen on;
- UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
- :param accept: Accept a connection instead of connecting when `True`.
+ :param error_message:
+ Human-readable string describing what connection phase failed.
+ :raise BaseException:
+ When `BaseException` occurs in the guarded block.
:raise ConnectError:
- When a connection or session cannot be established.
-
- This exception will wrap a more concrete one. In most cases,
- the wrapped exception will be `OSError` or `EOFError`. If a
- protocol-level failure occurs while establishing a new
- session, the wrapped error may also be an `QMPError`.
+ When any other error is encountered in the guarded block.
"""
- assert self.runstate == Runstate.IDLE
-
+ # Note: After Python 3.6 support is removed, this should be an
+ # @asynccontextmanager instead of accepting a callback.
try:
- phase = "connection"
- await self._establish_connection(address, ssl, accept)
-
- phase = "session"
- await self._establish_session()
-
+ await coro
except BaseException as err:
- emsg = f"Failed to establish {phase}"
self.logger.error("%s: %s", emsg, exception_summary(err))
self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
try:
- # Reset from CONNECTING back to IDLE.
+ # Reset the runstate back to IDLE.
await self.disconnect()
except:
- emsg = "Unexpected bottom half exception"
+ # We don't expect any Exceptions from the disconnect function
+ # here, because we failed to connect in the first place.
+ # The disconnect() function is intended to perform
+ # only cannot-fail cleanup here, but you never know.
+ emsg = (
+ "Unexpected bottom half exception. "
+ "This is a bug in the QMP library. "
+ "Please report it to <qemu-devel@nongnu.org> and "
+ "CC: John Snow <jsnow@redhat.com>."
+ )
self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
raise
+ # CancelledError is an Exception with special semantic meaning;
+ # We do NOT want to wrap it up under ConnectError.
# NB: CancelledError is not a BaseException before Python 3.8
if isinstance(err, asyncio.CancelledError):
raise
+ # Any other kind of error can be treated as some kind of connection
+ # failure broadly. Inspect the 'exc' field to explore the root
+ # cause in greater detail.
if isinstance(err, Exception):
raise ConnectError(emsg, err) from err
# Raise BaseExceptions un-wrapped, they're more important.
raise
- assert self.runstate == Runstate.RUNNING
+ @property
+ def _runstate_event(self) -> asyncio.Event:
+ # asyncio.Event() objects should not be created prior to entrance into
+ # an event loop, so we can ensure we create it in the correct context.
+ # Create it on-demand *only* at the behest of an 'async def' method.
+ if not self._runstate_changed:
+ self._runstate_changed = asyncio.Event()
+ return self._runstate_changed
@upper_half
- async def _establish_connection(
- self,
- address: SocketAddrT,
- ssl: Optional[SSLContext] = None,
- accept: bool = False
- ) -> None:
+ @bottom_half
+ def _set_state(self, state: Runstate) -> None:
"""
- Establish a new connection.
+ Change the `Runstate` of the protocol connection.
- :param address:
- Address to connect to/listen on;
- UNIX socket path or TCP address/port.
- :param ssl: SSL context to use, if any.
- :param accept: Accept a connection instead of connecting when `True`.
+ Signals the `runstate_changed` event.
"""
- assert self.runstate == Runstate.IDLE
- self._set_state(Runstate.CONNECTING)
-
- # Allow runstate watchers to witness 'CONNECTING' state; some
- # failures in the streaming layer are synchronous and will not
- # otherwise yield.
- await asyncio.sleep(0)
+ if state == self._runstate:
+ return
- if accept:
- await self._do_accept(address, ssl)
- else:
- await self._do_connect(address, ssl)
+ self.logger.debug("Transitioning from '%s' to '%s'.",
+ str(self._runstate), str(state))
+ self._runstate = state
+ self._runstate_event.set()
+ self._runstate_event.clear()
- def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
+ @bottom_half
+ async def _stop_server(self) -> None:
+ """
+ Stop listening for / accepting new incoming connections.
"""
- Used to create a socket in advance of accept().
+ if self._server is None:
+ return
- This is a workaround to ensure that we can guarantee timing of
- precisely when a socket exists to avoid a connection attempt
- bouncing off of nothing.
+ try:
+ self.logger.debug("Stopping server.")
+ self._server.close()
+ await self._server.wait_closed()
+ self.logger.debug("Server stopped.")
+ finally:
+ self._server = None
- Python 3.7+ adds a feature to separate the server creation and
- listening phases instead, and should be used instead of this
- hack.
+ @bottom_half # However, it does not run from the R/W tasks.
+ async def _incoming(self,
+ reader: asyncio.StreamReader,
+ writer: asyncio.StreamWriter) -> None:
"""
- if isinstance(address, tuple):
- family = socket.AF_INET
- else:
- family = socket.AF_UNIX
+ Accept an incoming connection and signal the upper_half.
- sock = socket.socket(family, socket.SOCK_STREAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ This method does the minimum necessary to accept a single
+ incoming connection. It signals back to the upper_half ASAP so
+ that any errors during session initialization can occur
+ naturally in the caller's stack.
- try:
- sock.bind(address)
- except:
- sock.close()
- raise
+ :param reader: Incoming `asyncio.StreamReader`
+ :param writer: Incoming `asyncio.StreamWriter`
+ """
+ peer = writer.get_extra_info('peername', 'Unknown peer')
+ self.logger.debug("Incoming connection from %s", peer)
+
+ if self._reader or self._writer:
+ # Sadly, we can have more than one pending connection
+ # because of https://bugs.python.org/issue46715
+ # Close any extra connections we don't actually want.
+ self.logger.warning("Extraneous connection inadvertently accepted")
+ writer.close()
+ return
- self._sock = sock
+ # A connection has been accepted; stop listening for new ones.
+ assert self._accepted is not None
+ await self._stop_server()
+ self._reader, self._writer = (reader, writer)
+ self._accepted.set()
@upper_half
- async def _do_accept(self, address: SocketAddrT,
- ssl: Optional[SSLContext] = None) -> None:
+ async def _do_start_server(self, address: SocketAddrT,
+ ssl: Optional[SSLContext] = None) -> None:
"""
- Acting as the transport server, accept a single connection.
+ Start listening for an incoming connection, but do not wait for a peer.
+
+ This method starts listening for an incoming connection, but does not
+ block waiting for a peer. This call will return immediately after
+ binding and listening to a socket. A later call to accept() must be
+ made in order to finalize the incoming connection.
:param address:
Address to listen on; UNIX socket path or TCP address/port.
@@ -471,52 +548,54 @@ class AsyncProtocol(Generic[T]):
:raise OSError: For stream-related errors.
"""
- self.logger.debug("Awaiting connection on %s ...", address)
- connected = asyncio.Event()
- server: Optional[asyncio.AbstractServer] = None
-
- async def _client_connected_cb(reader: asyncio.StreamReader,
- writer: asyncio.StreamWriter) -> None:
- """Used to accept a single incoming connection, see below."""
- nonlocal server
- nonlocal connected
-
- # A connection has been accepted; stop listening for new ones.
- assert server is not None
- server.close()
- await server.wait_closed()
- server = None
-
- # Register this client as being connected
- self._reader, self._writer = (reader, writer)
+ assert self.runstate == Runstate.IDLE
+ self._set_state(Runstate.CONNECTING)
- # Signal back: We've accepted a client!
- connected.set()
+ self.logger.debug("Awaiting connection on %s ...", address)
+ self._accepted = asyncio.Event()
if isinstance(address, tuple):
coro = asyncio.start_server(
- _client_connected_cb,
- host=None if self._sock else address[0],
- port=None if self._sock else address[1],
+ self._incoming,
+ host=address[0],
+ port=address[1],
ssl=ssl,
backlog=1,
limit=self._limit,
- sock=self._sock,
)
else:
coro = asyncio.start_unix_server(
- _client_connected_cb,
- path=None if self._sock else address,
+ self._incoming,
+ path=address,
ssl=ssl,
backlog=1,
limit=self._limit,
- sock=self._sock,
)
- server = await coro # Starts listening
- await connected.wait() # Waits for the callback to fire (and finish)
- assert server is None
- self._sock = None
+ # Allow runstate watchers to witness 'CONNECTING' state; some
+ # failures in the streaming layer are synchronous and will not
+ # otherwise yield.
+ await asyncio.sleep(0)
+
+ # This will start the server (bind(2), listen(2)). It will also
+ # call accept(2) if we yield, but we don't block on that here.
+ self._server = await coro
+ self.logger.debug("Server listening on %s", address)
+
+ @upper_half
+ async def _do_accept(self) -> None:
+ """
+ Wait for and accept an incoming connection.
+
+ Requires that we have not yet accepted an incoming connection
+ from the upper_half, but it's OK if the server is no longer
+ running because the bottom_half has already accepted the
+ connection.
+ """
+ assert self._accepted is not None
+ await self._accepted.wait()
+ assert self._server is None
+ self._accepted = None
self.logger.debug("Connection accepted.")
@@ -532,6 +611,14 @@ class AsyncProtocol(Generic[T]):
:raise OSError: For stream-related errors.
"""
+ assert self.runstate == Runstate.IDLE
+ self._set_state(Runstate.CONNECTING)
+
+ # Allow runstate watchers to witness 'CONNECTING' state; some
+ # failures in the streaming layer are synchronous and will not
+ # otherwise yield.
+ await asyncio.sleep(0)
+
self.logger.debug("Connecting to %s ...", address)
if isinstance(address, tuple):
@@ -644,6 +731,7 @@ class AsyncProtocol(Generic[T]):
self._reader = None
self._writer = None
+ self._accepted = None
# NB: _runstate_changed cannot be cleared because we still need it to
# send the final runstate changed event ...!
@@ -667,6 +755,9 @@ class AsyncProtocol(Generic[T]):
def _done(task: Optional['asyncio.Future[Any]']) -> bool:
return task is not None and task.done()
+ # If the server is running, stop it.
+ await self._stop_server()
+
# Are we already in an error pathway? If either of the tasks are
# already done, or if we have no tasks but a reader/writer; we
# must be.
diff --git a/python/tests/protocol.py b/python/tests/protocol.py
index 5cd7938be3..d6849ad306 100644
--- a/python/tests/protocol.py
+++ b/python/tests/protocol.py
@@ -41,12 +41,25 @@ class NullProtocol(AsyncProtocol[None]):
self.trigger_input = asyncio.Event()
await super()._establish_session()
- async def _do_accept(self, address, ssl=None):
- if not self.fake_session:
- await super()._do_accept(address, ssl)
+ async def _do_start_server(self, address, ssl=None):
+ if self.fake_session:
+ self._accepted = asyncio.Event()
+ self._set_state(Runstate.CONNECTING)
+ await asyncio.sleep(0)
+ else:
+ await super()._do_start_server(address, ssl)
+
+ async def _do_accept(self):
+ if self.fake_session:
+ self._accepted = None
+ else:
+ await super()._do_accept()
async def _do_connect(self, address, ssl=None):
- if not self.fake_session:
+ if self.fake_session:
+ self._set_state(Runstate.CONNECTING)
+ await asyncio.sleep(0)
+ else:
await super()._do_connect(address, ssl)
async def _do_recv(self) -> None:
@@ -413,14 +426,14 @@ class Accept(Connect):
assert family in ('INET', 'UNIX')
if family == 'INET':
- await self.proto.accept(('example.com', 1))
+ await self.proto.start_server_and_accept(('example.com', 1))
elif family == 'UNIX':
- await self.proto.accept('/dev/null')
+ await self.proto.start_server_and_accept('/dev/null')
async def _hanging_connection(self):
with TemporaryDirectory(suffix='.aqmp') as tmpdir:
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
- await self.proto.accept(sock)
+ await self.proto.start_server_and_accept(sock)
class FakeSession(TestBase):
@@ -449,13 +462,13 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testFakeAccept(self):
"""Test the full state lifecycle (via accept) with a no-op session."""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
self.assertEqual(self.proto.runstate, Runstate.RUNNING)
@TestBase.async_test
async def testFakeRecv(self):
"""Test receiving a fake/null message."""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
logname = self.proto.logger.name
with self.assertLogs(logname, level='DEBUG') as context:
@@ -471,7 +484,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testFakeSend(self):
"""Test sending a fake/null message."""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
logname = self.proto.logger.name
with self.assertLogs(logname, level='DEBUG') as context:
@@ -493,7 +506,7 @@ class FakeSession(TestBase):
):
with self.assertRaises(StateError) as context:
if accept:
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
else:
await self.proto.connect('/not/a/real/path')
@@ -504,7 +517,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testAcceptRequireRunning(self):
"""Test that accept() cannot be called when Runstate=RUNNING"""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
await self._prod_session_api(
Runstate.RUNNING,
@@ -515,7 +528,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testConnectRequireRunning(self):
"""Test that connect() cannot be called when Runstate=RUNNING"""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
await self._prod_session_api(
Runstate.RUNNING,
@@ -526,7 +539,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testAcceptRequireDisconnecting(self):
"""Test that accept() cannot be called when Runstate=DISCONNECTING"""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
# Cheat: force a disconnect.
await self.proto.simulate_disconnect()
@@ -541,7 +554,7 @@ class FakeSession(TestBase):
@TestBase.async_test
async def testConnectRequireDisconnecting(self):
"""Test that connect() cannot be called when Runstate=DISCONNECTING"""
- await self.proto.accept('/not/a/real/path')
+ await self.proto.start_server_and_accept('/not/a/real/path')
# Cheat: force a disconnect.
await self.proto.simulate_disconnect()
@@ -576,7 +589,7 @@ class SimpleSession(TestBase):
async def testSmoke(self):
with TemporaryDirectory(suffix='.aqmp') as tmpdir:
sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
- server_task = create_task(self.server.accept(sock))
+ server_task = create_task(self.server.start_server_and_accept(sock))
# give the server a chance to start listening [...]
await asyncio.sleep(0)
diff --git a/scripts/qmp/qmp-shell-wrap b/scripts/qmp/qmp-shell-wrap
index 9e94da114f..66846e36d1 100755
--- a/scripts/qmp/qmp-shell-wrap
+++ b/scripts/qmp/qmp-shell-wrap
@@ -4,7 +4,7 @@ import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python'))
-from qemu.qmp import qmp_shell
+from qemu.aqmp import qmp_shell
if __name__ == '__main__':