import contextlib
import errno
import functools
+import inspect
+import linecache
import logging
+import multiprocessing
import os
import queue
import shlex
from collections.abc import Awaitable, Collection, Iterator, Mapping, Sequence
from contextlib import AbstractContextManager
from pathlib import Path
-from types import TracebackType
-from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, Protocol, TypeVar
+from types import FrameType, TracebackType
+from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, Protocol, TypeVar, cast
+import mkosi
import mkosi.sandbox
from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, die
from mkosi.sandbox import acquire_privileges, joinpath, umask
return (exctype, exc, tb)
+# Make sure the line cache is primed as the mkosi source files might become inaccessible if we sandbox
+# ourselves.
+for path in Path(mkosi.__file__).parent.rglob("*.py"):
+ linecache.getlines(os.fspath(path))
+
+
+def excepthook(frames: Sequence[FrameType]) -> None:
+ exctype, exc, tb = ensure_exc_info()
+
+ for frame in frames:
+ tb = TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
+
+ exc.__traceback__ = tb
+
+ # sys.excepthook() triggers linecache.checkcache() which evicts cache entries for source files that are
+ # inaccessible. Temporarily disable it to preserve our primed linecache entries for which we might not
+ # be able to access the files anymore (because we're sandboxed).
+ checkcache = linecache.checkcache
+ linecache.checkcache = lambda filename=None: None # type: ignore[assignment]
+ try:
+ sys.excepthook(exctype, exc, tb)
+ finally:
+ linecache.checkcache = checkcache
+
+
@contextlib.contextmanager
-def uncaught_exception_handler(exit: Callable[[int], NoReturn] = sys.exit) -> Iterator[None]:
+def uncaught_exception_handler(
+ *,
+ exit: Callable[[int], NoReturn] = sys.exit,
+ fork: bool = False,
+ proceed: bool = False,
+) -> Iterator[None]:
+
+ frames = []
+ if fork:
+ # If we're using uncaught_exception_handler() in a forked process, we want to show the full
+ # stacktrace including the parent's frames as we can access the parent's frames in the fork but we
+ # can't pickle the fork's frames and send them to the parent. To do that we gather the frames here
+ # and then stitch them into the exception traceback in excepthook().
+
+ # Start from the caller's frame.
+ frame = inspect.currentframe()
+ while frame is not None:
+ frames.append(frame)
+ frame = frame.f_back
+
+ # Skip uncaught_exception_handler() itself and the wrapping context manager.
+ frames = frames[2:]
+
rc = 0
try:
yield
rc = e.code if isinstance(e.code, int) else 1
if ARG_DEBUG.get():
- sys.excepthook(*ensure_exc_info())
+ excepthook(frames)
except KeyboardInterrupt:
rc = 1
if ARG_DEBUG.get():
- sys.excepthook(*ensure_exc_info())
+ excepthook(frames)
else:
logging.error("Interrupted")
except subprocess.CalledProcessError as e:
and str(e.cmd[0]) not in ("self", "ssh", "systemd-nspawn")
and "qemu-system" not in str(e.cmd[0])
):
- sys.excepthook(*ensure_exc_info())
+ excepthook(frames)
except BaseException:
- sys.excepthook(*ensure_exc_info())
+ excepthook(frames)
rc = 1
finally:
- sys.stdout.flush()
- sys.stderr.flush()
- exit(rc)
-
-
-def fork_and_wait(target: Callable[..., None], *args: Any, **kwargs: Any) -> None:
- pid = os.fork()
- if pid == 0:
- with uncaught_exception_handler(exit=os._exit):
- target(*args, **kwargs)
-
- try:
- _, status = os.waitpid(pid, 0)
- except KeyboardInterrupt:
- os.kill(pid, signal.SIGINT)
- _, status = os.waitpid(pid, 0)
- except BaseException:
- os.kill(pid, signal.SIGTERM)
- _, status = os.waitpid(pid, 0)
-
- rc = os.waitstatus_to_exitcode(status)
-
- if rc != 0:
- raise subprocess.CalledProcessError(rc, ["self"])
+ if not proceed or rc != 0:
+ sys.stdout.flush()
+ sys.stderr.flush()
+ exit(rc)
def log_process_failure(sandbox: Sequence[str], cmdline: Sequence[str], returncode: int) -> None:
return contextlib.nullcontext([])
+def fork_and_wait(
+ target: Callable[..., T],
+ *args: Any,
+ sandbox: AbstractContextManager[Sequence[PathString]] = nosandbox(),
+ **kwargs: Any,
+) -> T:
+ parent, child = multiprocessing.Pipe(duplex=False)
+
+ with sandbox as sbx:
+ pid = os.fork()
+ if pid == 0:
+ with uncaught_exception_handler(exit=os._exit, fork=True):
+ parent.close()
+
+ if sbx:
+ mkosi.sandbox.main([os.fspath(s) for s in sbx])
+
+ child.send(target(*args, **kwargs))
+
+ child.close()
+
+ try:
+ _, status = os.waitpid(pid, 0)
+ except KeyboardInterrupt:
+ os.kill(pid, signal.SIGINT)
+ _, status = os.waitpid(pid, 0)
+ except BaseException:
+ os.kill(pid, signal.SIGTERM)
+ _, status = os.waitpid(pid, 0)
+
+ rc = os.waitstatus_to_exitcode(status)
+
+ if rc != 0:
+ parent.close()
+ raise subprocess.CalledProcessError(rc, ["self"])
+
+ result = parent.recv()
+ parent.close()
+
+ return cast(T, result)
+
+
def run(
cmdline: Sequence[PathString],
check: bool = True,
sandbox: list[str],
preexec: Callable[[], None] | None,
) -> None:
- if preexec:
- preexec()
+ with uncaught_exception_handler(exit=os._exit, fork=True, proceed=True):
+ if preexec:
+ preexec()
- if not sandbox:
- return
+ if not sandbox:
+ return
- # if we get here we should have neither a prefix nor a setup command to execute and so we can execute the
- # command directly.
+ # if we get here we should have neither a prefix nor a setup command to execute and so we can
+ # execute the command directly.
- # mkosi.sandbox.main() updates os.environ but the environment passed to Popen() is not yet in
- # effect by the time the preexec function is called. To get around that, we update the
- # environment ourselves here.
- os.environ.clear()
- os.environ.update(env)
- try:
+ # mkosi.sandbox.main() updates os.environ but the environment passed to Popen() is not yet in
+ # effect by the time the preexec function is called. To get around that, we update the
+ # environment ourselves here.
+ os.environ.clear()
+ os.environ.update(env)
mkosi.sandbox.main(sandbox)
- except Exception:
- sys.excepthook(*ensure_exc_info())
- os._exit(1)
-
- # Python does its own executable lookup in $PATH before executing the preexec function, and
- # hence before we have set up the sandbox which influences the lookup results. To get around
- # that, let's call execvp() ourselves inside the preexec() function, and not give Python the
- # chance to do it itself. This ensures we can do the proper executable lookup after setting
- # up the sandbox. If we can't find the executable, do nothing, and let Python do its own
- # search logic so it can return a proper error, which we cannot do from the preexec function.
- # Note that by doing this we also skip Python closing all open file descriptors except the
- # ones specified by the user in pass_fds, but since Python opens all file descriptors with
- # O_CLOEXEC anyway, we'll assume we're good and don't need to close open file descriptors
- # explicitly.
- if s := shutil.which(cmd[0]):
- os.execvp(s, cmd)
+
+ # Python does its own executable lookup in $PATH before executing the preexec function, and
+ # hence before we have set up the sandbox which influences the lookup results. To get around
+ # that, let's call execvp() ourselves inside the preexec() function, and not give Python the
+ # chance to do it itself. This ensures we can do the proper executable lookup after setting
+ # up the sandbox. If we can't find the executable, do nothing, and let Python do its own
+ # search logic so it can return a proper error, which we cannot do from the preexec function.
+ # Note that by doing this we also skip Python closing all open file descriptors except the
+ # ones specified by the user in pass_fds, but since Python opens all file descriptors with
+ # O_CLOEXEC anyway, we'll assume we're good and don't need to close open file descriptors
+ # explicitly.
+ if s := shutil.which(cmd[0]):
+ os.execvp(s, cmd)
@contextlib.contextmanager
*globs: str,
sandbox: AbstractContextManager[Sequence[PathString]] = nosandbox(),
) -> list[Path]:
- return [
- Path(s)
- for s in run(
- [
- "bash",
- "-c",
- rf"shopt -s nullglob && printf '%s\n' {' '.join(globs)} | xargs -r readlink -f",
- ],
- sandbox=sandbox,
- stdout=subprocess.PIPE,
- )
- .stdout.strip()
- .splitlines()
- ]
+
+ def child() -> list[Path]:
+ paths = flatten(Path("/").glob(glob) for glob in globs)
+ return [p.resolve() for p in paths]
+
+ return fork_and_wait(child, sandbox=sandbox)
def exists_in_sandbox(
path: PathString,
sandbox: AbstractContextManager[Sequence[PathString]] = nosandbox(),
) -> bool:
- return (
- run(
- ["bash", "-c", rf"test -e {path}"],
- sandbox=sandbox,
- check=False,
- ).returncode
- == 0
- )
+ return fork_and_wait(lambda: Path(path).exists(), sandbox=sandbox)
--- /dev/null
+# SPDX-License-Identifier: LGPL-2.1-or-later
+
+import contextlib
+import os
+import subprocess
+from pathlib import Path
+
+import pytest
+
+from mkosi.run import fork_and_wait
+
+
+def test_fork_and_wait_returns_value() -> None:
+ result = fork_and_wait(lambda: 42)
+ assert result == 42
+
+
+def test_fork_and_wait_returns_none() -> None:
+ result = fork_and_wait(lambda: None)
+ assert result is None
+
+
+def test_fork_and_wait_returns_string() -> None:
+ result = fork_and_wait(lambda: "hello world")
+ assert result == "hello world"
+
+
+def test_fork_and_wait_returns_complex_type() -> None:
+ result = fork_and_wait(lambda: {"key": [1, 2, 3], "nested": {"a": True}})
+ assert result == {"key": [1, 2, 3], "nested": {"a": True}}
+
+
+def test_fork_and_wait_passes_args() -> None:
+ def add(a: int, b: int) -> int:
+ return a + b
+
+ result = fork_and_wait(add, 3, 4)
+ assert result == 7
+
+
+def test_fork_and_wait_passes_kwargs() -> None:
+ def greet(name: str, greeting: str = "Hello") -> str:
+ return f"{greeting}, {name}!"
+
+ result = fork_and_wait(greet, "world", greeting="Hi")
+ assert result == "Hi, world!"
+
+
+def test_fork_and_wait_child_failure() -> None:
+ def fail() -> None:
+ raise RuntimeError("boom")
+
+ with pytest.raises(subprocess.CalledProcessError):
+ fork_and_wait(fail)
+
+
+def test_fork_and_wait_sandbox(tmp_path: Path) -> None:
+ (tmp_path / "abc").mkdir()
+
+ def exists() -> bool:
+ return Path("/abc").exists()
+
+ result = fork_and_wait(exists, sandbox=contextlib.nullcontext(["--bind", os.fspath(tmp_path), "/"]))
+ assert result