# SPDX-License-Identifier: LGPL-2.1+
+import asyncio
import base64
import contextlib
import datetime
import re
import resource
import shutil
+import socket
import subprocess
import sys
import tempfile
from mkosi.mounts import dissect_and_mount, mount_overlay, scandir_recursive
from mkosi.pager import page
from mkosi.remove import unlink_try_hard
-from mkosi.run import become_root, fork_and_wait, run, run_workspace_command, spawn
+from mkosi.run import (
+ MkosiAsyncioThread,
+ become_root,
+ fork_and_wait,
+ run,
+ run_workspace_command,
+ spawn,
+)
from mkosi.state import MkosiState
from mkosi.types import PathString
from mkosi.util import (
swtpm_proc.wait()
+@contextlib.contextmanager
+def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]:
+ """
+ This yields a vsock address and a dict that will be filled in with the notifications from the VM. The
+ dict should only be accessed after the context manager has been finalized.
+ """
+ with socket.socket(socket.AF_VSOCK, socket.SOCK_SEQPACKET) as vsock:
+ vsock.bind((socket.VMADDR_CID_ANY, -1))
+ vsock.listen()
+ vsock.setblocking(False)
+
+ messages = {}
+
+ async def notify() -> None:
+ loop = asyncio.get_running_loop()
+
+ try:
+ while True:
+ s, _ = await loop.sock_accept(vsock)
+
+ for msg in (await loop.sock_recv(s, 4096)).decode().split("\n"):
+ if not msg:
+ continue
+
+ k, _, v = msg.partition("=")
+ messages[k] = v
+
+ except asyncio.CancelledError:
+ pass
+
+ with MkosiAsyncioThread(notify()):
+ yield f"vsock:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages
+
+
def run_qemu(args: MkosiArgs, config: MkosiConfig) -> None:
accel = "kvm" if config.qemu_kvm else "tcg"
cmdline += ["-drive", f"if=pflash,format=raw,readonly=on,file={firmware}"]
+ notifications: dict[str, str] = {}
+
with contextlib.ExitStack() as stack:
if fw_supports_sb:
ovmf_vars = stack.enter_context(tempfile.NamedTemporaryFile(prefix=".mkosi-", dir=tmp_dir()))
elif config.architecture == "aarch64":
cmdline += ["-device", "tpm-tis-device,tpmdev=tpm0"]
+ addr, notifications = stack.enter_context(vsock_notify_handler())
+ cmdline += ["-smbios", f"type=11,value=io.systemd.credential:vmm.notify_socket={addr}"]
+
cmdline += config.qemu_args
cmdline += args.cmdline
run(cmdline, stdin=sys.stdin, stdout=sys.stdout, env=os.environ, log=False)
+ if "EXIT_STATUS" in notifications:
+ raise subprocess.CalledProcessError(int(notifications["EXIT_STATUS"]), cmdline)
+
def run_ssh(args: MkosiArgs, config: MkosiConfig) -> None:
cmd = [
+import asyncio
+import asyncio.tasks
import ctypes
import ctypes.util
import logging
import multiprocessing
import os
import pwd
+import queue
import shlex
import shutil
import signal
import subprocess
import sys
import tempfile
+import threading
import traceback
from pathlib import Path
from types import TracebackType
-from typing import Any, Callable, Mapping, Optional, Sequence, Type, TypeVar
+from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Type, TypeVar
from mkosi.log import ARG_DEBUG, ARG_DEBUG_SHELL, die
from mkosi.types import _FILE, CompletedProcess, PathString, Popen
if tmp.is_symlink():
resolve.unlink(missing_ok=True)
shutil.move(tmp, resolve)
+
+
+class MkosiAsyncioThread(threading.Thread):
+ """
+ The default threading.Thread() is not interruptable, so we make our own version by using the concurrency
+ feature in python that is interruptable, namely asyncio.
+
+ Additionally, we store the result of the coroutine in the result variable so it can be accessed easily
+ after the thread finishes.
+ """
+
+ def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None:
+ self.target = target
+ self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue()
+ super().__init__(*args, **kwargs)
+
+ def run(self) -> None:
+ async def wrapper() -> None:
+ self.loop.put(asyncio.get_running_loop())
+ await self.target
+
+ asyncio.run(wrapper())
+
+ def cancel(self) -> None:
+ loop = self.loop.get()
+
+ for task in asyncio.tasks.all_tasks(loop):
+ loop.call_soon_threadsafe(task.cancel)
+
+ def __enter__(self) -> "MkosiAsyncioThread":
+ self.start()
+ return self
+
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
+ self.cancel()
+ self.join()