]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Optionally return the inner pid from spawn()
authorDaan De Meyer <daan.j.demeyer@gmail.com>
Fri, 12 Apr 2024 14:21:27 +0000 (16:21 +0200)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Sat, 13 Apr 2024 17:22:02 +0000 (19:22 +0200)
bubblewrap does not support forwarding signals yet,
see https://github.com/containers/bubblewrap/pull/586. As a workaround,
we need to make sure we send our signals to the inner process. To
make this work, we create a pipe, pass it through to the subprocess,
and prefix with a bash command that writes its pid to the pipe before
exec-ing the actual command.

The other thing we get from this is that we can register the inner pid
as a scope which makes the systemctl status output for the scopes we
create a lot more useful.

mkosi/qemu.py
mkosi/run.py
mkosi/user.py

index 61326052468d4cbc229c3e7448d550bc920f406a..00ea410ac97c9326f8d8f1ae18582cc425faaa6c 100644 (file)
@@ -12,6 +12,7 @@ import logging
 import os
 import random
 import shutil
+import signal
 import socket
 import struct
 import subprocess
@@ -38,7 +39,7 @@ from mkosi.config import (
 from mkosi.log import ARG_DEBUG, die
 from mkosi.mounts import finalize_source_mounts
 from mkosi.partition import finalize_root, find_partitions
-from mkosi.run import SD_LISTEN_FDS_START, AsyncioThread, find_binary, fork_and_wait, run, spawn
+from mkosi.run import SD_LISTEN_FDS_START, AsyncioThread, find_binary, fork_and_wait, kill, run, spawn
 from mkosi.sandbox import Mount
 from mkosi.tree import copy_tree, rmtree
 from mkosi.types import PathString
@@ -274,15 +275,15 @@ def start_swtpm(config: Config) -> Iterator[Path]:
                 cmdline,
                 pass_fds=(sock.fileno(),),
                 sandbox=config.sandbox(mounts=[Mount(state, state)]),
-            ) as proc:
+            ) as (proc, innerpid):
                 allocate_scope(
                     config,
                     name=f"mkosi-swtpm-{config.machine_or_name()}",
-                    pid=proc.pid,
+                    pid=innerpid,
                     description=f"swtpm for {config.machine_or_name()}",
                 )
                 yield path
-                proc.terminate()
+                kill(proc, innerpid, signal.SIGTERM)
 
 
 def find_virtiofsd(*, tools: Path = Path("/")) -> Optional[Path]:
@@ -354,15 +355,15 @@ def start_virtiofsd(config: Config, directory: PathString, *, name: str, selinux
                 mounts=[Mount(directory, directory)],
                 options=["--uid", "0", "--gid", "0", "--cap-add", "all"],
             ),
-        ) as proc:
+        ) as (proc, innerpid):
             allocate_scope(
                 config,
                 name=f"mkosi-virtiofsd-{name}",
-                pid=proc.pid,
+                pid=innerpid,
                 description=f"virtiofsd for {directory}",
             )
             yield path
-            proc.terminate()
+            kill(proc, innerpid, signal.SIGTERM)
 
 
 @contextlib.contextmanager
@@ -442,15 +443,16 @@ def start_journal_remote(config: Config, sockfd: int) -> Iterator[None]:
         # If all logs go into a single file, disable compact mode to allow for journal files exceeding 4G.
         env={"SYSTEMD_JOURNAL_COMPACT": "0" if config.forward_journal.suffix == ".journal" else "1"},
         foreground=False,
-    ) as proc:
+    ) as (proc, innerpid):
         allocate_scope(
             config,
             name=f"mkosi-journal-remote-{config.machine_or_name()}",
-            pid=proc.pid,
+            pid=innerpid,
             description=f"mkosi systemd-journal-remote for {config.machine_or_name()}",
         )
         yield
-        proc.terminate()
+        kill(proc, innerpid, signal.SIGTERM)
+
 
 
 @contextlib.contextmanager
@@ -1097,7 +1099,7 @@ def run_qemu(args: Args, config: Config) -> None:
             log=False,
             foreground=True,
             sandbox=config.sandbox(network=True, devices=True, relaxed=True),
-        ) as qemu:
+        ) as (proc, innerpid):
             # We have to close these before we wait for qemu otherwise we'll deadlock as qemu will never exit.
             for fd in qemu_device_fds.values():
                 os.close(fd)
@@ -1106,12 +1108,12 @@ def run_qemu(args: Args, config: Config) -> None:
             allocate_scope(
                 config,
                 name=name,
-                pid=qemu.pid,
+                pid=innerpid,
                 description=f"mkosi Virtual Machine {name}",
             )
-            register_machine(config, qemu.pid, fname)
+            register_machine(config, innerpid, fname)
 
-            if qemu.wait() == 0 and (status := int(notifications.get("EXIT_STATUS", 0))):
+            if proc.wait() == 0 and (status := int(notifications.get("EXIT_STATUS", 0))):
                 raise subprocess.CalledProcessError(status, cmdline)
 
 
index af677791a42ffcd2608a388389bbfeafe2b57117..c9370595a8718cd8f971f03e31d31ac2f52ab080 100644 (file)
@@ -159,7 +159,8 @@ def run(
         preexec_fn=preexec_fn,
         success_exit_status=success_exit_status,
         sandbox=sandbox,
-    ) as process:
+        innerpid=False,
+    ) as (process, _):
         out, err = process.communicate(input)
 
     return CompletedProcess(cmdline, process.returncode, out, err)
@@ -182,7 +183,8 @@ def spawn(
     preexec_fn: Optional[Callable[[], None]] = None,
     success_exit_status: Sequence[int] = (0,),
     sandbox: AbstractContextManager[Sequence[PathString]] = contextlib.nullcontext([]),
-) -> Iterator[Popen]:
+    innerpid: bool = True,
+) -> Iterator[tuple[Popen, int]]:
     assert sorted(set(pass_fds)) == list(pass_fds)
 
     cmdline = [os.fspath(x) for x in cmdline]
@@ -274,6 +276,17 @@ def spawn(
             # command.
             prefix += ["sh", "-c", f"LISTEN_FDS={len(pass_fds)} LISTEN_PID=$$ exec $0 \"$@\""]
 
+        if prefix and innerpid:
+            r, w = os.pipe2(os.O_CLOEXEC)
+            # Make sure that the write end won't be overridden in preexec() when we're moving fds forward.
+            q = fcntl.fcntl(w, fcntl.F_DUPFD_CLOEXEC, SD_LISTEN_FDS_START + len(pass_fds) + 1)
+            os.close(w)
+            w = q
+            # dash doesn't support working with file descriptors higher than 9 so make sure we use bash.
+            prefix += ["bash", "-c", f"echo $$ >&{w} && exec {w}>&- && exec $0 \"$@\""]
+        else:
+            r, w = (None, None)
+
         try:
             with subprocess.Popen(
                 prefix + cmdline,
@@ -285,15 +298,24 @@ def spawn(
                 group=group,
                 # pass_fds only comes into effect after python has invoked the preexec function, so we make sure that
                 # pass_fds contains the file descriptors to keep open after we've done our transformation in preexec().
-                pass_fds=[SD_LISTEN_FDS_START + i for i in range(len(pass_fds))],
+                pass_fds=[SD_LISTEN_FDS_START + i for i in range(len(pass_fds))] + ([w] if w else []),
                 env=env,
                 cwd=cwd,
                 preexec_fn=preexec,
             ) as proc:
+                if w:
+                    os.close(w)
+                pid = proc.pid
                 try:
-                    yield proc
+                    if r:
+                        with open(r) as f:
+                            s = f.read()
+                            if s:
+                                pid = int(s)
+
+                    yield proc, pid
                 except BaseException:
-                    proc.terminate()
+                    kill(proc, pid, signal.SIGTERM)
                     raise
                 finally:
                     returncode = proc.wait()
@@ -342,6 +364,18 @@ def find_binary(*names: PathString, root: Path = Path("/")) -> Optional[Path]:
     return None
 
 
+def kill(process: Popen, innerpid: int, signal: int) -> None:
+    process.poll()
+    if process.returncode is not None:
+        return
+
+    try:
+        os.kill(innerpid, signal)
+    # Handle the race condition where the process might exit between us calling poll() and us calling os.kill().
+    except ProcessLookupError:
+        pass
+
+
 class AsyncioThread(threading.Thread):
     """
     The default threading.Thread() is not interruptable, so we make our own version by using the concurrency
index 6afbb13ad73f1c0a2571b391712d03a7b2784f17..4489c216671b877769e3e228d5d7737544f5db2c 100644 (file)
@@ -187,7 +187,11 @@ def become_root() -> None:
         # execute using flock so they don't execute before they can get a lock on the same temporary file, then we
         # unshare the user namespace and finally we unlock the temporary file, which allows the newuidmap and newgidmap
         # processes to execute. we then wait for the processes to finish before continuing.
-        with flock(lock) as fd, spawn(newuidmap) as uidmap, spawn(newgidmap) as gidmap:
+        with (
+            flock(lock) as fd,
+            spawn(newuidmap, innerpid=False) as (uidmap, _),
+            spawn(newgidmap, innerpid=False) as (gidmap, _)
+        ):
             unshare(CLONE_NEWUSER)
             fcntl.flock(fd, fcntl.LOCK_UN)
             uidmap.wait()