]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Make sandbox_cmd() return a context manager
authorDaan De Meyer <daan.j.demeyer@gmail.com>
Sat, 13 Apr 2024 08:50:58 +0000 (10:50 +0200)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Sat, 13 Apr 2024 17:22:00 +0000 (19:22 +0200)
This allows us to get rid of the shell hack to create and clean up
a subdirectory of /var/tmp. To avoid having to change every callsite
to use with(), we pass in a context manager directly into run() and
spawn().

Because we don't return a list anymore from sandbox_cmd(), we add an
extra "extra" argument to allow appending extra commands to the sandbox.

mkosi/__init__.py
mkosi/config.py
mkosi/context.py
mkosi/installer/apt.py
mkosi/installer/dnf.py
mkosi/installer/pacman.py
mkosi/installer/zypper.py
mkosi/run.py
mkosi/sandbox.py

index edae1bc5581a78d74e203ef4f54123d944870674..0c6cb4c3806c87f296b6d6952b076a165b79b0b9 100644 (file)
@@ -20,6 +20,7 @@ import tempfile
 import textwrap
 import uuid
 from collections.abc import Iterator, Mapping, Sequence
+from contextlib import AbstractContextManager
 from pathlib import Path
 from typing import Optional, Union, cast
 
@@ -408,7 +409,7 @@ def mkosi_as_caller() -> tuple[str, ...]:
 def finalize_host_scripts(
     context: Context,
     helpers: Mapping[str, Sequence[PathString]] = {},
-) -> contextlib.AbstractContextManager[Path]:
+) -> AbstractContextManager[Path]:
     scripts: dict[str, Sequence[PathString]] = {}
     for binary in ("useradd", "groupadd"):
         if find_binary(binary, root=context.config.tools()):
@@ -588,7 +589,8 @@ def run_prepare_scripts(context: Context, build: bool) -> None:
                         ],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
                         scripts=hd,
-                    ) + (chroot if script.suffix == ".chroot" else []),
+                        extra=chroot if script.suffix == ".chroot" else [],
+                    )
                 )
 
 
@@ -670,7 +672,8 @@ def run_build_scripts(context: Context) -> None:
                         ],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
                         scripts=hd,
-                    ) + (chroot if script.suffix == ".chroot" else []),
+                        extra=chroot if script.suffix == ".chroot" else [],
+                    ),
                 )
 
     if context.want_local_repo() and context.config.output_format != OutputFormat.none:
@@ -736,7 +739,8 @@ def run_postinst_scripts(context: Context) -> None:
                         ],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
                         scripts=hd,
-                    ) + (chroot if script.suffix == ".chroot" else []),
+                        extra=chroot if script.suffix == ".chroot" else [],
+                    ),
                 )
 
 
@@ -796,7 +800,8 @@ def run_finalize_scripts(context: Context) -> None:
                         ],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
                         scripts=hd,
-                    ) + (chroot if script.suffix == ".chroot" else []),
+                        extra=chroot if script.suffix == ".chroot" else [],
+                    ),
                 )
 
 
@@ -1458,7 +1463,8 @@ def grub_bios_setup(context: Context, partitions: Sequence[Partition]) -> None:
                     Mount(context.staging, context.staging),
                     Mount(mountinfo.name, mountinfo.name),
                 ],
-            ) + ["sh", "-c", f"mount --bind {mountinfo.name} /proc/$$/mountinfo && exec $0 \"$@\""],
+                extra=["sh", "-c", f"mount --bind {mountinfo.name} /proc/$$/mountinfo && exec $0 \"$@\""],
+            ),
         )
 
 
@@ -2494,7 +2500,8 @@ def calculate_signature(context: Context) -> None:
             sandbox=context.sandbox(
                 mounts=mounts,
                 options=options,
-            ) + ["setpriv", f"--reuid={INVOKING_USER.uid}", f"--regid={INVOKING_USER.gid}", "--clear-groups"]
+                extra=["setpriv", f"--reuid={INVOKING_USER.uid}", f"--regid={INVOKING_USER.gid}", "--clear-groups"],
+            )
         )
 
 
@@ -3523,7 +3530,7 @@ def copy_repository_metadata(context: Context) -> None:
                 with umask(~0o755):
                     dst.mkdir(parents=True, exist_ok=True)
 
-                def sandbox(*, mounts: Sequence[Mount] = ()) -> list[PathString]:
+                def sandbox(*, mounts: Sequence[Mount] = ()) -> AbstractContextManager[list[PathString]]:
                     return context.sandbox(mounts=[*mounts, *exclude])
 
                 copy_tree(
index 06009c41886f9f0a3564a9cfccccecca0fb48ade..e91bdd32d9cc8de94eb13d4170da04cc8819e88e 100644 (file)
@@ -27,6 +27,7 @@ import textwrap
 import typing
 import uuid
 from collections.abc import Collection, Iterable, Iterator, Sequence
+from contextlib import AbstractContextManager
 from pathlib import Path
 from typing import Any, Callable, Optional, TypeVar, Union, cast
 
@@ -1685,7 +1686,8 @@ class Config:
         scripts: Optional[Path] = None,
         mounts: Sequence[Mount] = (),
         options: Sequence[PathString] = (),
-    ) -> list[PathString]:
+        extra: Sequence[PathString] = (),
+    ) -> AbstractContextManager[list[PathString]]:
         mounts = [
             *[Mount(d, d, ro=True) for d in self.extra_search_paths if not relaxed and not self.tools_tree],
             *([Mount(p, "/proxy.cacert", ro=True)] if (p := self.proxy_peer_certificate) else []),
@@ -1702,6 +1704,7 @@ class Config:
             tools=tools or self.tools(),
             mounts=mounts,
             options=options,
+            extra=extra,
         )
 
 
index 81ce10dd592865f3ddab88a72671ee54140d8431..4c3a4fb8c2941a0a199006e4c9c8fd65c56eeecd 100644 (file)
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-2.1+
 
 from collections.abc import Sequence
+from contextlib import AbstractContextManager
 from pathlib import Path
 from typing import Optional
 
@@ -75,7 +76,16 @@ class Context:
         scripts: Optional[Path] = None,
         mounts: Sequence[Mount] = (),
         options: Sequence[PathString] = (),
-    ) -> list[PathString]:
+        extra: Sequence[PathString] = (),
+    ) -> AbstractContextManager[list[PathString]]:
+        if (self.pkgmngr / "usr").exists():
+            extra = [
+                "sh",
+                "-c",
+                f"mount -t overlay -o lowerdir={self.pkgmngr / 'usr'}:/usr overlayfs /usr && exec $0 \"$@\"",
+                *extra,
+            ]
+
         return self.config.sandbox(
             network=network,
             devices=devices,
@@ -95,13 +105,7 @@ class Context:
                 "--cap-add", "ALL",
                 *options,
             ],
-        ) + (
-            [
-                "sh",
-                "-c",
-                f"mount -t overlay -o lowerdir={self.pkgmngr / 'usr'}:/usr overlayfs /usr && exec $0 \"$@\"",
-            ] if (self.pkgmngr / "usr").exists() else []
+            extra=extra,
         )
-
     def want_local_repo(self) -> bool:
         return any(self.packages.iterdir())
index 2b80173ec81821838f7e06ab0fafd2cd61b27725..bbea6351553509897cc3b95d4b911c89b96d42a4 100644 (file)
@@ -202,7 +202,8 @@ class Apt(PackageManager):
                         network=True,
                         mounts=[Mount(context.root, "/buildroot"), *cls.mounts(context), *sources, *mounts],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
-                    ) + (apivfs_cmd() if apivfs else [])
+                        extra=apivfs_cmd() if apivfs else []
+                    )
                 ),
                 env=context.config.environment | cls.finalize_environment(context),
                 stdout=stdout,
index b49c371e4087290e4021ebd8fc088ad627511a61..74a2123f9daad3882b3e1a45f1a47e5c2549c1be 100644 (file)
@@ -179,7 +179,8 @@ class Dnf(PackageManager):
                             network=True,
                             mounts=[Mount(context.root, "/buildroot"), *cls.mounts(context), *sources],
                             options=["--dir", "/work/src", "--chdir", "/work/src"],
-                        ) + (apivfs_cmd() if apivfs else [])
+                            extra=apivfs_cmd() if apivfs else [],
+                        )
                     ),
                     env=context.config.environment | cls.finalize_environment(context),
                     stdout=stdout,
index cbff8daabd8d0e1468a30ce8d9502ed5a436a826..9b68ce3851f499d053c2383de9de6f3bbeab7bb9 100644 (file)
@@ -162,7 +162,8 @@ class Pacman(PackageManager):
                         network=True,
                         mounts=[Mount(context.root, "/buildroot"), *cls.mounts(context), *sources],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
-                    ) + (apivfs_cmd() if apivfs else [])
+                        extra=apivfs_cmd() if apivfs else [],
+                    )
                 ),
                 env=context.config.environment | cls.finalize_environment(context),
                 stdout=stdout,
index a8fdf06d5bc6f1cdc6fb148655a7f4df46463594..f37c721ac14b9d5c093b0ee53d9e7da0f41e6863 100644 (file)
@@ -134,7 +134,8 @@ class Zypper(PackageManager):
                         network=True,
                         mounts=[Mount(context.root, "/buildroot"), *cls.mounts(context), *sources],
                         options=["--dir", "/work/src", "--chdir", "/work/src"],
-                    ) + (apivfs_cmd() if apivfs else [])
+                        extra=apivfs_cmd() if apivfs else [],
+                    )
                 ),
                 env=context.config.environment | cls.finalize_environment(context),
                 stdout=stdout,
index 8a7bd7e10626a0008927fb63e6680983bd0a41a8..af677791a42ffcd2608a388389bbfeafe2b57117 100644 (file)
@@ -15,6 +15,7 @@ import subprocess
 import sys
 import threading
 from collections.abc import Awaitable, Collection, Iterator, Mapping, Sequence
+from contextlib import AbstractContextManager
 from pathlib import Path
 from types import TracebackType
 from typing import Any, Callable, NoReturn, Optional
@@ -137,7 +138,7 @@ def run(
     foreground: bool = True,
     preexec_fn: Optional[Callable[[], None]] = None,
     success_exit_status: Sequence[int] = (0,),
-    sandbox: Sequence[PathString] = (),
+    sandbox: AbstractContextManager[Sequence[PathString]] = contextlib.nullcontext([]),
 ) -> CompletedProcess:
     if input is not None:
         assert stdin is None  # stdin and input cannot be specified together
@@ -180,11 +181,10 @@ def spawn(
     foreground: bool = False,
     preexec_fn: Optional[Callable[[], None]] = None,
     success_exit_status: Sequence[int] = (0,),
-    sandbox: Sequence[PathString] = (),
+    sandbox: AbstractContextManager[Sequence[PathString]] = contextlib.nullcontext([]),
 ) -> Iterator[Popen]:
     assert sorted(set(pass_fds)) == list(pass_fds)
 
-    sandbox = [os.fspath(x) for x in sandbox]
     cmdline = [os.fspath(x) for x in cmdline]
 
     if ARG_DEBUG.get():
@@ -247,72 +247,78 @@ def spawn(
             # expect it to pick.
             assert nfd == SD_LISTEN_FDS_START + i
 
-    # First, check if the sandbox works at all before executing the command.
-    if sandbox and (rc := subprocess.run(sandbox + ["true"]).returncode) != 0:
-        log_process_failure(sandbox, cmdline, rc)
-        raise subprocess.CalledProcessError(rc, sandbox + cmdline)
+    with sandbox as sbx:
+        prefix = [os.fspath(x) for x in sbx]
 
-    if subprocess.run(sandbox + ["sh", "-c", f"command -v {cmdline[0]}"], stdout=subprocess.DEVNULL).returncode != 0:
-        die(f"{cmdline[0]} not found.", hint=f"Is {cmdline[0]} installed on the host system?")
+        # First, check if the sandbox works at all before executing the command.
+        if prefix and (rc := subprocess.run(prefix + ["true"]).returncode) != 0:
+            log_process_failure(prefix, cmdline, rc)
+            raise subprocess.CalledProcessError(rc, prefix + cmdline)
 
-    if (
-        foreground and
-        sandbox and
-        subprocess.run(sandbox + ["sh", "-c", "command -v setpgid"], stdout=subprocess.DEVNULL).returncode == 0
-    ):
-        sandbox += ["setpgid", "--foreground", "--"]
+        if subprocess.run(
+            prefix + ["sh", "-c", f"command -v {cmdline[0]}"],
+            stdout=subprocess.DEVNULL,
+        ).returncode != 0:
+            die(f"{cmdline[0]} not found.", hint=f"Is {cmdline[0]} installed on the host system?")
 
-    if pass_fds:
-        # We don't know the PID before we start the process and we can't modify the environment in preexec_fn so we
-        # have to spawn a temporary shell to set the necessary environment variables before spawning the actual
-        # command.
-        sandbox += ["sh", "-c", f"LISTEN_FDS={len(pass_fds)} LISTEN_PID=$$ exec $0 \"$@\""]
+        if (
+            foreground and
+            prefix and
+            subprocess.run(prefix + ["sh", "-c", "command -v setpgid"], stdout=subprocess.DEVNULL).returncode == 0
+        ):
+            prefix += ["setpgid", "--foreground", "--"]
 
-    try:
-        with subprocess.Popen(
-            sandbox + cmdline,
-            stdin=stdin,
-            stdout=stdout,
-            stderr=stderr,
-            text=True,
-            user=user,
-            group=group,
-            # pass_fds only comes into effect after python has invoked the preexec function, so we make sure that
-            # pass_fds contains the file descriptors to keep open after we've done our transformation in preexec().
-            pass_fds=[SD_LISTEN_FDS_START + i for i in range(len(pass_fds))],
-            env=env,
-            cwd=cwd,
-            preexec_fn=preexec,
-        ) as proc:
-            try:
-                yield proc
-            except BaseException:
-                proc.terminate()
-                raise
-            finally:
-                returncode = proc.wait()
-
-            if check and returncode not in success_exit_status:
-                if log:
-                    log_process_failure(sandbox, cmdline, returncode)
-                if ARG_DEBUG_SHELL.get():
-                    subprocess.run(
-                        [*sandbox, "bash"],
-                        check=False,
-                        stdin=sys.stdin,
-                        text=True,
-                        user=user,
-                        group=group,
-                        env=env,
-                        cwd=cwd,
-                        preexec_fn=preexec,
-                    )
-                raise subprocess.CalledProcessError(returncode, cmdline)
-    except FileNotFoundError as e:
-        die(f"{e.filename} not found.")
-    finally:
-        if foreground:
-            make_foreground_process(new_process_group=False)
+        if pass_fds:
+            # We don't know the PID before we start the process and we can't modify the environment in preexec_fn so we
+            # have to spawn a temporary shell to set the necessary environment variables before spawning the actual
+            # command.
+            prefix += ["sh", "-c", f"LISTEN_FDS={len(pass_fds)} LISTEN_PID=$$ exec $0 \"$@\""]
+
+        try:
+            with subprocess.Popen(
+                prefix + cmdline,
+                stdin=stdin,
+                stdout=stdout,
+                stderr=stderr,
+                text=True,
+                user=user,
+                group=group,
+                # pass_fds only comes into effect after python has invoked the preexec function, so we make sure that
+                # pass_fds contains the file descriptors to keep open after we've done our transformation in preexec().
+                pass_fds=[SD_LISTEN_FDS_START + i for i in range(len(pass_fds))],
+                env=env,
+                cwd=cwd,
+                preexec_fn=preexec,
+            ) as proc:
+                try:
+                    yield proc
+                except BaseException:
+                    proc.terminate()
+                    raise
+                finally:
+                    returncode = proc.wait()
+
+                if check and returncode not in success_exit_status:
+                    if log:
+                        log_process_failure(prefix, cmdline, returncode)
+                    if ARG_DEBUG_SHELL.get():
+                        subprocess.run(
+                            [*prefix, "bash"],
+                            check=False,
+                            stdin=sys.stdin,
+                            text=True,
+                            user=user,
+                            group=group,
+                            env=env,
+                            cwd=cwd,
+                            preexec_fn=preexec,
+                        )
+                    raise subprocess.CalledProcessError(returncode, cmdline)
+        except FileNotFoundError as e:
+            die(f"{e.filename} not found.")
+        finally:
+            if foreground:
+                make_foreground_process(new_process_group=False)
 
 
 def find_binary(*names: PathString, root: Path = Path("/")) -> Optional[Path]:
index 35bd1f42ca2f3f0c0f5b3c24e7cefd82efafe7e9..68fdc8f0070abae17c20ea33e2b981454e435c38 100644 (file)
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: LGPL-2.1+
+import contextlib
 import enum
 import logging
 import os
+import shutil
 import uuid
-from collections.abc import Sequence
+from collections.abc import Iterator, Sequence
+from contextlib import AbstractContextManager
 from pathlib import Path
 from typing import NamedTuple, Optional, Protocol
 
@@ -40,11 +43,11 @@ class Mount(NamedTuple):
 
 
 class SandboxProtocol(Protocol):
-    def __call__(self, *, mounts: Sequence[Mount] = ()) -> list[PathString]: ...
+    def __call__(self, *, mounts: Sequence[Mount] = ()) -> AbstractContextManager[list[PathString]]: ...
 
 
-def nosandbox(*, mounts: Sequence[Mount] = ()) -> list[PathString]:
-    return []
+def nosandbox(*, mounts: Sequence[Mount] = ()) -> AbstractContextManager[list[PathString]]:
+    return contextlib.nullcontext([])
 
 
 # https://github.com/torvalds/linux/blob/master/include/uapi/linux/capability.h
@@ -118,6 +121,7 @@ def finalize_mounts(mounts: Sequence[Mount]) -> list[PathString]:
     return flatten(m.options() for m in mounts)
 
 
+@contextlib.contextmanager
 def sandbox_cmd(
     *,
     network: bool = False,
@@ -127,16 +131,14 @@ def sandbox_cmd(
     relaxed: bool = False,
     mounts: Sequence[Mount] = (),
     options: Sequence[PathString] = (),
-) -> list[PathString]:
+    extra: Sequence[PathString] = (),
+) -> Iterator[list[PathString]]:
     cmdline: list[PathString] = []
     mounts = list(mounts)
 
     if not relaxed:
-        # We want to use an empty subdirectory in the host's temporary directory as the sandbox's /var/tmp. To make
-        # sure it only gets created when we run the sandboxed command and cleaned up when the sandboxed command exits,
-        # we create it using shell.
+        # We want to use an empty subdirectory in the host's temporary directory as the sandbox's /var/tmp.
         vartmp = Path(os.getenv("TMPDIR", "/var/tmp")) / f"mkosi-var-tmp-{uuid.uuid4().hex[:16]}"
-        cmdline += ["sh", "-c", f"trap 'rm -rf {vartmp}' EXIT && mkdir --mode 1777 {vartmp} && $0 \"$@\""]
     else:
         vartmp = None
 
@@ -238,9 +240,16 @@ def sandbox_cmd(
         ops += ["chmod 755 /etc"]
     ops += ["exec $0 \"$@\""]
 
-    cmdline += ["sh", "-c", " && ".join(ops)]
+    cmdline += ["sh", "-c", " && ".join(ops), *extra]
 
-    return cmdline
+    if vartmp:
+        vartmp.mkdir(mode=0o1777)
+
+    try:
+        yield cmdline
+    finally:
+        if vartmp:
+            shutil.rmtree(vartmp)
 
 
 def apivfs_cmd() -> list[PathString]: