import enum
import errno
import fcntl
+import functools
import hashlib
import io
import json
import logging
import os
+import queue
import random
import resource
import shutil
proc.terminate()
-@contextlib.contextmanager
-def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]:
- """
- This yields a vsock address and a dict that will be filled in with the notifications from the VM. The
- dict should only be accessed after the context manager has been finalized.
- """
- with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as vsock:
- vsock.bind((socket.VMADDR_CID_ANY, socket.VMADDR_PORT_ANY))
- vsock.listen()
- vsock.setblocking(False)
+async def notify(messages: queue.SimpleQueue[tuple[str, str]], *, sock: socket.socket) -> None:
+ import asyncio
- num_messages = 0
- num_bytes = 0
- messages = {}
+ loop = asyncio.get_running_loop()
+ num_messages = 0
+ num_bytes = 0
- async def notify() -> None:
- nonlocal num_messages
- nonlocal num_bytes
-
- import asyncio
+ try:
+ while True:
+ s, _ = await loop.sock_accept(sock)
- loop = asyncio.get_running_loop()
+ num_messages += 1
- while True:
- s, _ = await loop.sock_accept(vsock)
+ with s:
+ data = []
+ try:
+ while buf := await loop.sock_recv(s, 4096):
+ data.append(buf)
+ except ConnectionResetError:
+ logging.debug("notify listener connection reset by peer")
- num_messages += 1
+ for msg in b"".join(data).decode().split("\n"):
+ if not msg:
+ continue
- with s:
- data = []
- try:
- while buf := await loop.sock_recv(s, 4096):
- data.append(buf)
- except ConnectionResetError:
- logging.debug("vsock notify listener connection reset by peer")
+ num_bytes += len(msg)
+ k, _, v = msg.partition("=")
+ messages.put((k, v))
+ except asyncio.CancelledError:
+ logging.debug(f"Received {num_messages} notify messages totalling {format_bytes(num_bytes)} bytes")
- for msg in b"".join(data).decode().split("\n"):
- if not msg:
- continue
- num_bytes += len(msg)
- k, _, v = msg.partition("=")
- messages[k] = v
+@contextlib.contextmanager
+def vsock_notify_handler() -> Iterator[tuple[str, AsyncioThread[tuple[str, str]]]]:
+ """
+ This yields a vsock address and an object that will be filled in with the notifications from the VM.
+ """
+ with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) as vsock:
+ vsock.bind((socket.VMADDR_CID_ANY, socket.VMADDR_PORT_ANY))
+ vsock.listen()
+ vsock.setblocking(False)
- with AsyncioThread(notify()):
- try:
- yield f"vsock-stream:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages
- finally:
- logging.debug(
- f"Received {num_messages} notify messages totalling {format_bytes(num_bytes)} bytes"
- )
- for k, v in messages.items():
- logging.debug(f"- {k}={v}")
+ with AsyncioThread(functools.partial(notify, sock=vsock)) as thread:
+ yield f"vsock-stream:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", thread
@contextlib.contextmanager
if firmware.is_uefi():
assert ovmf
cmdline += ["-drive", f"if=pflash,format={ovmf.format},readonly=on,file={ovmf.firmware}"]
- notifications: dict[str, str] = {}
+
+ notify: Optional[AsyncioThread[tuple[str, str]]] = None
with contextlib.ExitStack() as stack:
if firmware.is_uefi():
cmdline += ["-device", "tpm-tis-device,tpmdev=tpm0"]
if QemuDeviceNode.vhost_vsock in qemu_device_fds:
- addr, notifications = stack.enter_context(vsock_notify_handler())
+ addr, notify = stack.enter_context(vsock_notify_handler())
credentials["vmm.notify_socket"] = addr
if config.forward_journal:
proc.send_signal(signal.SIGCONT)
- if status := int(notifications.get("EXIT_STATUS", 0)):
+ if notify and (status := int({k: v for k, v in notify.process()}.get("EXIT_STATUS", "0"))) != 0:
raise subprocess.CalledProcessError(status, cmdline)
from contextlib import AbstractContextManager
from pathlib import Path
from types import TracebackType
-from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Protocol
+from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, Optional, Protocol, TypeVar
from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, die
from mkosi.sandbox import acquire_privileges, joinpath, umask
Popen = subprocess.Popen
+T = TypeVar("T")
+
+
def ensure_exc_info() -> tuple[type[BaseException], BaseException, TracebackType]:
exctype, exc, tb = sys.exc_info()
assert exctype
return None
-class AsyncioThread(threading.Thread):
+class AsyncioThread(threading.Thread, Generic[T]):
"""
The default threading.Thread() is not interruptible, so we make our own version by using the concurrency
feature in python that is interruptible, namely asyncio.
exception was raised before.
"""
- def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None:
+ def __init__(
+ self, target: Callable[[queue.SimpleQueue[T]], Awaitable[Any]], *args: Any, **kwargs: Any
+ ) -> None:
import asyncio
self.target = target
self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue()
self.exc: queue.SimpleQueue[BaseException] = queue.SimpleQueue()
+ self.queue: queue.SimpleQueue[T] = queue.SimpleQueue()
+ self.messages: list[T] = []
super().__init__(*args, **kwargs)
def run(self) -> None:
async def wrapper() -> None:
self.loop.put(asyncio.get_running_loop())
- await self.target
+ await self.target(self.queue)
try:
asyncio.run(wrapper())
except BaseException as e:
self.exc.put(e)
+ def process(self) -> list[T]:
+ while not self.queue.empty():
+ self.messages += [self.queue.get()]
+
+ return self.messages
+
+ def wait_for(self, expected: T) -> None:
+ while (message := self.queue.get()) != expected:
+ self.messages += [message]
+
def cancel(self) -> None:
import asyncio.tasks
for task in asyncio.tasks.all_tasks(loop):
loop.call_soon_threadsafe(task.cancel)
- def __enter__(self) -> "AsyncioThread":
+ def __enter__(self) -> "AsyncioThread[T]":
self.start()
return self
) -> None:
self.cancel()
self.join()
+ self.process()
if type is None:
try: