From: Daan De Meyer Date: Mon, 7 Apr 2025 10:53:01 +0000 (+0200) Subject: qemu: Make notify handling code more generic X-Git-Tag: v26~268 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=988c855b060151595c97b9824c8217fa3b3f8fce;p=thirdparty%2Fmkosi.git qemu: Make notify handling code more generic Let's pass in a queue to the target function of AsyncioThread() to which it can write messages. Let's also keep track of all sent messages in AsyncioThread for easy access later and add some utility functions to process new messages. We also split out notify() from vsock_notify_handler() and make it non-vsock specific by passing in the socket as an argument. --- diff --git a/mkosi/qemu.py b/mkosi/qemu.py index 0c60148dd..f138eca8a 100644 --- a/mkosi/qemu.py +++ b/mkosi/qemu.py @@ -6,11 +6,13 @@ import dataclasses import enum import errno import fcntl +import functools import hashlib import io import json import logging import os +import queue import random import resource import shutil @@ -449,59 +451,50 @@ def start_virtiofsd( proc.terminate() -@contextlib.contextmanager -def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]: - """ - This yields a vsock address and a dict that will be filled in with the notifications from the VM. The - dict should only be accessed after the context manager has been finalized. - """ - with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as vsock: - vsock.bind((socket.VMADDR_CID_ANY, socket.VMADDR_PORT_ANY)) - vsock.listen() - vsock.setblocking(False) +async def notify(messages: queue.SimpleQueue[tuple[str, str]], *, sock: socket.socket) -> None: + import asyncio - num_messages = 0 - num_bytes = 0 - messages = {} + loop = asyncio.get_running_loop() + num_messages = 0 + num_bytes = 0 - async def notify() -> None: - nonlocal num_messages - nonlocal num_bytes - - import asyncio + try: + while True: + s, _ = await loop.sock_accept(sock) - loop = asyncio.get_running_loop() + num_messages += 1 - while True: - s, _ = await loop.sock_accept(vsock) + with s: + data = [] + try: + while buf := await loop.sock_recv(s, 4096): + data.append(buf) + except ConnectionResetError: + logging.debug("notify listener connection reset by peer") - num_messages += 1 + for msg in b"".join(data).decode().split("\n"): + if not msg: + continue - with s: - data = [] - try: - while buf := await loop.sock_recv(s, 4096): - data.append(buf) - except ConnectionResetError: - logging.debug("vsock notify listener connection reset by peer") + num_bytes += len(msg) + k, _, v = msg.partition("=") + messages.put((k, v)) + except asyncio.CancelledError: + logging.debug(f"Received {num_messages} notify messages totalling {format_bytes(num_bytes)} bytes") - for msg in b"".join(data).decode().split("\n"): - if not msg: - continue - num_bytes += len(msg) - k, _, v = msg.partition("=") - messages[k] = v +@contextlib.contextmanager +def vsock_notify_handler() -> Iterator[tuple[str, AsyncioThread[tuple[str, str]]]]: + """ + This yields a vsock address and an object that will be filled in with the notifications from the VM. + """ + with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as vsock: + vsock.bind((socket.VMADDR_CID_ANY, socket.VMADDR_PORT_ANY)) + vsock.listen() + vsock.setblocking(False) - with AsyncioThread(notify()): - try: - yield f"vsock-stream:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages - finally: - logging.debug( - f"Received {num_messages} notify messages totalling {format_bytes(num_bytes)} bytes" - ) - for k, v in messages.items(): - logging.debug(f"- {k}={v}") + with AsyncioThread(functools.partial(notify, sock=vsock)) as thread: + yield f"vsock-stream:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", thread @contextlib.contextmanager @@ -1306,7 +1299,8 @@ def run_qemu(args: Args, config: Config) -> None: if firmware.is_uefi(): assert ovmf cmdline += ["-drive", f"if=pflash,format={ovmf.format},readonly=on,file={ovmf.firmware}"] - notifications: dict[str, str] = {} + + notify: Optional[AsyncioThread[tuple[str, str]]] = None with contextlib.ExitStack() as stack: if firmware.is_uefi(): @@ -1495,7 +1489,7 @@ def run_qemu(args: Args, config: Config) -> None: cmdline += ["-device", "tpm-tis-device,tpmdev=tpm0"] if QemuDeviceNode.vhost_vsock in qemu_device_fds: - addr, notifications = stack.enter_context(vsock_notify_handler()) + addr, notify = stack.enter_context(vsock_notify_handler()) credentials["vmm.notify_socket"] = addr if config.forward_journal: @@ -1604,7 +1598,7 @@ def run_qemu(args: Args, config: Config) -> None: proc.send_signal(signal.SIGCONT) - if status := int(notifications.get("EXIT_STATUS", 0)): + if notify and (status := int({k: v for k, v in notify.process()}.get("EXIT_STATUS", "0"))) != 0: raise subprocess.CalledProcessError(status, cmdline) diff --git a/mkosi/run.py b/mkosi/run.py index 422006d88..e20f82cc3 100644 --- a/mkosi/run.py +++ b/mkosi/run.py @@ -16,7 +16,7 @@ from collections.abc import Awaitable, Collection, Iterator, Mapping, Sequence from contextlib import AbstractContextManager from pathlib import Path from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Protocol +from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, Optional, Protocol, TypeVar from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, die from mkosi.sandbox import acquire_privileges, joinpath, umask @@ -33,6 +33,9 @@ else: Popen = subprocess.Popen +T = TypeVar("T") + + def ensure_exc_info() -> tuple[type[BaseException], BaseException, TracebackType]: exctype, exc, tb = sys.exc_info() assert exctype @@ -310,7 +313,7 @@ def find_binary( return None -class AsyncioThread(threading.Thread): +class AsyncioThread(threading.Thread, Generic[T]): """ The default threading.Thread() is not interruptible, so we make our own version by using the concurrency feature in python that is interruptible, namely asyncio. @@ -319,12 +322,16 @@ class AsyncioThread(threading.Thread): exception was raised before. """ - def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, target: Callable[[queue.SimpleQueue[T]], Awaitable[Any]], *args: Any, **kwargs: Any + ) -> None: import asyncio self.target = target self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue() self.exc: queue.SimpleQueue[BaseException] = queue.SimpleQueue() + self.queue: queue.SimpleQueue[T] = queue.SimpleQueue() + self.messages: list[T] = [] super().__init__(*args, **kwargs) def run(self) -> None: @@ -332,7 +339,7 @@ class AsyncioThread(threading.Thread): async def wrapper() -> None: self.loop.put(asyncio.get_running_loop()) - await self.target + await self.target(self.queue) try: asyncio.run(wrapper()) @@ -341,6 +348,16 @@ class AsyncioThread(threading.Thread): except BaseException as e: self.exc.put(e) + def process(self) -> list[T]: + while not self.queue.empty(): + self.messages += [self.queue.get()] + + return self.messages + + def wait_for(self, expected: T) -> None: + while (message := self.queue.get()) != expected: + self.messages += [message] + def cancel(self) -> None: import asyncio.tasks @@ -349,7 +366,7 @@ class AsyncioThread(threading.Thread): for task in asyncio.tasks.all_tasks(loop): loop.call_soon_threadsafe(task.cancel) - def __enter__(self) -> "AsyncioThread": + def __enter__(self) -> "AsyncioThread[T]": self.start() return self @@ -361,6 +378,7 @@ class AsyncioThread(threading.Thread): ) -> None: self.cancel() self.join() + self.process() if type is None: try: