]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Propagate systemd exit status from the VM
authorDaan De Meyer <daan.j.demeyer@gmail.com>
Thu, 11 May 2023 08:45:54 +0000 (10:45 +0200)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Thu, 11 May 2023 10:11:13 +0000 (12:11 +0200)
Let's make use of the new vmm.notify_socket credential to fetch
systemd's exit status from the VM and propagate it as our own exit
status, just like already happens automatically for containers with
systemd-nspawn.

mkosi/__init__.py
mkosi/run.py

index 4055530dbc975fc5938e11b24e9dff59a519a430..fa44f73ddf9624908bd97a4d49bb865be33ddd69 100644 (file)
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: LGPL-2.1+
 
+import asyncio
 import base64
 import contextlib
 import datetime
@@ -13,6 +14,7 @@ import os
 import re
 import resource
 import shutil
+import socket
 import subprocess
 import sys
 import tempfile
@@ -36,7 +38,14 @@ from mkosi.manifest import Manifest
 from mkosi.mounts import dissect_and_mount, mount_overlay, scandir_recursive
 from mkosi.pager import page
 from mkosi.remove import unlink_try_hard
-from mkosi.run import become_root, fork_and_wait, run, run_workspace_command, spawn
+from mkosi.run import (
+    MkosiAsyncioThread,
+    become_root,
+    fork_and_wait,
+    run,
+    run_workspace_command,
+    spawn,
+)
 from mkosi.state import MkosiState
 from mkosi.types import PathString
 from mkosi.util import (
@@ -2024,6 +2033,40 @@ def start_swtpm() -> Iterator[Optional[Path]]:
             swtpm_proc.wait()
 
 
+@contextlib.contextmanager
+def vsock_notify_handler() -> Iterator[tuple[str, dict[str, str]]]:
+    """
+    This yields a vsock address and a dict that will be filled in with the notifications from the VM. The
+    dict should only be accessed after the context manager has been finalized.
+    """
+    with socket.socket(socket.AF_VSOCK, socket.SOCK_SEQPACKET) as vsock:
+        vsock.bind((socket.VMADDR_CID_ANY, -1))
+        vsock.listen()
+        vsock.setblocking(False)
+
+        messages = {}
+
+        async def notify() -> None:
+            loop = asyncio.get_running_loop()
+
+            try:
+                while True:
+                    s, _ = await loop.sock_accept(vsock)
+
+                    for msg in (await loop.sock_recv(s, 4096)).decode().split("\n"):
+                        if not msg:
+                            continue
+
+                        k, _, v = msg.partition("=")
+                        messages[k] = v
+
+            except asyncio.CancelledError:
+                pass
+
+        with MkosiAsyncioThread(notify()):
+            yield f"vsock:{socket.VMADDR_CID_HOST}:{vsock.getsockname()[1]}", messages
+
+
 def run_qemu(args: MkosiArgs, config: MkosiConfig) -> None:
     accel = "kvm" if config.qemu_kvm else "tcg"
 
@@ -2075,6 +2118,8 @@ def run_qemu(args: MkosiArgs, config: MkosiConfig) -> None:
 
     cmdline += ["-drive", f"if=pflash,format=raw,readonly=on,file={firmware}"]
 
+    notifications: dict[str, str] = {}
+
     with contextlib.ExitStack() as stack:
         if fw_supports_sb:
             ovmf_vars = stack.enter_context(tempfile.NamedTemporaryFile(prefix=".mkosi-", dir=tmp_dir()))
@@ -2127,11 +2172,17 @@ def run_qemu(args: MkosiArgs, config: MkosiConfig) -> None:
             elif config.architecture == "aarch64":
                 cmdline += ["-device", "tpm-tis-device,tpmdev=tpm0"]
 
+        addr, notifications = stack.enter_context(vsock_notify_handler())
+        cmdline += ["-smbios", f"type=11,value=io.systemd.credential:vmm.notify_socket={addr}"]
+
         cmdline += config.qemu_args
         cmdline += args.cmdline
 
         run(cmdline, stdin=sys.stdin, stdout=sys.stdout, env=os.environ, log=False)
 
+    if "EXIT_STATUS" in notifications:
+        raise subprocess.CalledProcessError(int(notifications["EXIT_STATUS"]), cmdline)
+
 
 def run_ssh(args: MkosiArgs, config: MkosiConfig) -> None:
     cmd = [
index 4fc79ea250f199b5f91063b59aaeb548c368de94..c6db1515e0ea260167f27a637a071148e21dc354 100644 (file)
@@ -1,19 +1,23 @@
+import asyncio
+import asyncio.tasks
 import ctypes
 import ctypes.util
 import logging
 import multiprocessing
 import os
 import pwd
+import queue
 import shlex
 import shutil
 import signal
 import subprocess
 import sys
 import tempfile
+import threading
 import traceback
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Callable, Mapping, Optional, Sequence, Type, TypeVar
+from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Type, TypeVar
 
 from mkosi.log import ARG_DEBUG, ARG_DEBUG_SHELL, die
 from mkosi.types import _FILE, CompletedProcess, PathString, Popen
@@ -402,3 +406,44 @@ def run_workspace_command(
             if tmp.is_symlink():
                 resolve.unlink(missing_ok=True)
                 shutil.move(tmp, resolve)
+
+
+class MkosiAsyncioThread(threading.Thread):
+    """
+    The default threading.Thread() is not interruptable, so we make our own version by using the concurrency
+    feature in python that is interruptable, namely asyncio.
+
+    Additionally, we store the result of the coroutine in the result variable so it can be accessed easily
+    after the thread finishes.
+    """
+
+    def __init__(self, target: Awaitable[Any], *args: Any, **kwargs: Any) -> None:
+        self.target = target
+        self.loop: queue.SimpleQueue[asyncio.AbstractEventLoop] = queue.SimpleQueue()
+        super().__init__(*args, **kwargs)
+
+    def run(self) -> None:
+        async def wrapper() -> None:
+            self.loop.put(asyncio.get_running_loop())
+            await self.target
+
+        asyncio.run(wrapper())
+
+    def cancel(self) -> None:
+        loop = self.loop.get()
+
+        for task in asyncio.tasks.all_tasks(loop):
+            loop.call_soon_threadsafe(task.cancel)
+
+    def __enter__(self) -> "MkosiAsyncioThread":
+        self.start()
+        return self
+
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        self.cancel()
+        self.join()