]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
qemu: Make notify handling code more generic
authorDaan De Meyer <daan.j.demeyer@gmail.com>
Mon, 7 Apr 2025 10:53:01 +0000 (12:53 +0200)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Tue, 8 Apr 2025 08:20:17 +0000 (10:20 +0200)
Let's pass in a queue to the target function of AsyncioThread() to which
it can write messages. Let's also keep track of all sent messages in
AsyncioThread for easy access later and add some utility functions to
process new messages.

We also split out notify() from vsock_notify_handler() and make it non-vsock
specific by passing in the socket as an argument.

mkosi/qemu.py
mkosi/run.py

index 0c60148dd0845e0f04762458d69aed98ae305a10..f138eca8ae04a2b2d90df76e8e837bf7a9a3cf26 100644 (file)
@@ -6,11 +6,13 @@ import dataclasses
 import enum
 import errno
 import fcntl
+import functools
 import hashlib
 import io
 import json
 import logging
 import os
+import queue
 import random
 import resource
 import shutil
@@ -449,59 +451,50 @@ def start_virtiofsd(
             proc.terminate()
 
 
-@contextlib.contextmanager
-def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]:
-    """
-    This yields a vsock address and a dict that will be filled in with the notifications from the VM. The
-    dict should only be accessed after the context manager has been finalized.
-    """
-    with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as vsock:
-        vsock.bind((socket.VMADDR_CID_ANY, socket.VMADDR_PORT_ANY))
-        vsock.listen()
-        vsock.setblocking(False)
+async def notify(messages: queue.SimpleQueue[tuple[str, str]], *, sock: socket.socket) -> None:
+    import asyncio
 
-        num_messages = 0
-        num_bytes = 0
-        messages = {}
+    loop = asyncio.get_running_loop()
+    num_messages = 0
+    num_bytes = 0
 
-        async def notify() -> None:
-            nonlocal num_messages
-            nonlocal num_bytes
-
-            import asyncio
+    try:
+        while True:
+            s, _ = await loop.sock_accept(sock)
 
-            loop = asyncio.get_running_loop()
+            num_messages += 1
 
-            while True:
-                s, _ = await loop.sock_accept(vsock)
+            with s:
+                data = []
+                try:
+                    while buf := await loop.sock_recv(s, 4096):
+                        data.append(buf)
+                except ConnectionResetError:
+                    logging.debug("notify listener connection reset by peer")
 
-                num_messages += 1
+            for msg in b"".join(data).decode().split("\n"):
+                if not msg:
+                    continue
 
-                with s:
-                    data = []
-                    try:
-                        while buf := await loop.sock_recv(s, 4096):
-                            data.append(buf)
-                    except ConnectionResetError:
-                        logging.debug("vsock notify listener connection reset by peer")
+                num_bytes += len(msg)
+                k, _, v = msg.partition("=")
+                messages.put((k, v))
+    except asyncio.CancelledError:
+        logging.debug(f"Received {num_messages} notify messages totalling {format_bytes(num_bytes)} bytes")
 
-                for msg in b"".join(data).decode().split("\n"):
-                    if not msg:
-                        continue
 
-                    num_bytes += len(msg)
-                    k, _, v = msg.partition("=")
-                    messages[k] = v
+@contextlib.contextmanager
+def vsock_notify_handler() -> Iterator[tuple[str, AsyncioThread[tuple[str, str]]]]:
+    """
+    This yields a vsock address and an object that will be filled in with the notifications from the VM.
+    """
+    with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as vsock:
+        vsock.bind((socket.VMADDR_CID_ANY, socket.VMADDR_PORT_ANY))
+        vsock.listen()
+        vsock.setblocking(False)
 
-        with AsyncioThread(notify()):
-            try:
-                yield f"vsock-stream:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages
-            finally:
-                logging.debug(
-                    f"Received {num_messages} notify messages totalling {format_bytes(num_bytes)} bytes"
-                )
-                for k, v in messages.items():
-                    logging.debug(f"- {k}={v}")
+        with AsyncioThread(functools.partial(notify, sock=vsock)) as thread:
+            yield f"vsock-stream:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", thread
 
 
 @contextlib.contextmanager
@@ -1306,7 +1299,8 @@ def run_qemu(args: Args, config: Config) -> None:
     if firmware.is_uefi():
         assert ovmf
         cmdline += ["-drive", f"if=pflash,format={ovmf.format},readonly=on,file={ovmf.firmware}"]
-    notifications: dict[str, str] = {}
+
+    notify: Optional[AsyncioThread[tuple[str, str]]] = None
 
     with contextlib.ExitStack() as stack:
         if firmware.is_uefi():
@@ -1495,7 +1489,7 @@ def run_qemu(args: Args, config: Config) -> None:
                 cmdline += ["-device", "tpm-tis-device,tpmdev=tpm0"]
 
         if QemuDeviceNode.vhost_vsock in qemu_device_fds:
-            addr, notifications = stack.enter_context(vsock_notify_handler())
+            addr, notify = stack.enter_context(vsock_notify_handler())
             credentials["vmm.notify_socket"] = addr
 
         if config.forward_journal:
@@ -1604,7 +1598,7 @@ def run_qemu(args: Args, config: Config) -> None:
 
             proc.send_signal(signal.SIGCONT)
 
-        if status := int(notifications.get("EXIT_STATUS", 0)):
+        if notify and (status := int({k: v for k, v in notify.process()}.get("EXIT_STATUS", "0"))) != 0:
             raise subprocess.CalledProcessError(status, cmdline)
 
 
index 422006d889802182d7e2f1734b2c342318583e7b..e20f82cc3812f6f1da4d61ff777632cb33381f1e 100644 (file)
@@ -16,7 +16,7 @@ from collections.abc import Awaitable, Collection, Iterator, Mapping, Sequence
 from contextlib import AbstractContextManager
 from pathlib import Path
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Protocol
+from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, Optional, Protocol, TypeVar
 
 from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, die
 from mkosi.sandbox import acquire_privileges, joinpath, umask
@@ -33,6 +33,9 @@ else:
     Popen = subprocess.Popen
 
 
+T = TypeVar("T")
+
+
 def ensure_exc_info() -> tuple[type[BaseException], BaseException, TracebackType]:
     exctype, exc, tb = sys.exc_info()
     assert exctype
@@ -310,7 +313,7 @@ def find_binary(
     return None
 
 
-class AsyncioThread(threading.Thread):
+class AsyncioThread(threading.Thread, Generic[T]):
     """
     The default threading.Thread() is not interruptible, so we make our own version by using the concurrency
     feature in python that is interruptible, namely asyncio.
@@ -319,12 +322,16 @@ class AsyncioThread(threading.Thread):
     exception was raised before.
     """
 
-    def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self, target: Callable[[queue.SimpleQueue[T]], Awaitable[Any]], *args: Any, **kwargs: Any
+    ) -> None:
         import asyncio
 
         self.target = target
         self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue()
         self.exc: queue.SimpleQueue[BaseException] = queue.SimpleQueue()
+        self.queue: queue.SimpleQueue[T] = queue.SimpleQueue()
+        self.messages: list[T] = []
         super().__init__(*args, **kwargs)
 
     def run(self) -> None:
@@ -332,7 +339,7 @@ class AsyncioThread(threading.Thread):
 
         async def wrapper() -> None:
             self.loop.put(asyncio.get_running_loop())
-            await self.target
+            await self.target(self.queue)
 
         try:
             asyncio.run(wrapper())
@@ -341,6 +348,16 @@ class AsyncioThread(threading.Thread):
         except BaseException as e:
             self.exc.put(e)
 
+    def process(self) -> list[T]:
+        while not self.queue.empty():
+            self.messages += [self.queue.get()]
+
+        return self.messages
+
+    def wait_for(self, expected: T) -> None:
+        while (message := self.queue.get()) != expected:
+            self.messages += [message]
+
     def cancel(self) -> None:
         import asyncio.tasks
 
@@ -349,7 +366,7 @@ class AsyncioThread(threading.Thread):
         for task in asyncio.tasks.all_tasks(loop):
             loop.call_soon_threadsafe(task.cancel)
 
-    def __enter__(self) -> "AsyncioThread":
+    def __enter__(self) -> "AsyncioThread[T]":
         self.start()
         return self
 
@@ -361,6 +378,7 @@ class AsyncioThread(threading.Thread):
     ) -> None:
         self.cancel()
         self.join()
+        self.process()
 
         if type is None:
             try: