]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Use become_root_cmd() in copy_ephemeral()
authorDaan De Meyer <daan.j.demeyer@gmail.com>
Fri, 3 May 2024 22:44:57 +0000 (00:44 +0200)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Sat, 4 May 2024 10:47:57 +0000 (12:47 +0200)
mkosi/qemu.py

index d89dd3e0e2df3c72255437f73157572292292adb..7d4ae560c33f0c3d63b3d4505a8dd3ab3b57460f 100644 (file)
@@ -6,6 +6,7 @@ import contextlib
 import enum
 import errno
 import fcntl
+import functools
 import hashlib
 import json
 import logging
@@ -41,11 +42,11 @@ from mkosi.config import (
 from mkosi.log import ARG_DEBUG, die
 from mkosi.mounts import finalize_source_mounts
 from mkosi.partition import finalize_root, find_partitions
-from mkosi.run import SD_LISTEN_FDS_START, AsyncioThread, find_binary, fork_and_wait, kill, run, spawn
+from mkosi.run import SD_LISTEN_FDS_START, AsyncioThread, find_binary, kill, run, spawn
 from mkosi.sandbox import Mount
 from mkosi.tree import copy_tree, rmtree
 from mkosi.types import PathString
-from mkosi.user import INVOKING_USER, become_root, become_root_cmd
+from mkosi.user import INVOKING_USER, become_root_cmd
 from mkosi.util import StrEnum, flock, flock_or_die, try_or
 from mkosi.versioncomp import GenericVersion
 
@@ -511,28 +512,25 @@ def copy_ephemeral(config: Config, src: Path) -> Iterator[Path]:
     tmp = src.parent / f"{src.name}-{uuid.uuid4().hex[:16]}"
 
     try:
-        def copy() -> None:
-            if config.output_format == OutputFormat.directory:
-                become_root()
-
+        with flock(src):
             copy_tree(
                 src, tmp,
                 preserve=config.output_format == OutputFormat.directory,
                 use_subvolumes=config.use_subvolumes,
-                sandbox=config.sandbox,
+                sandbox=functools.partial(
+                    config.sandbox,
+                    setup=become_root_cmd() if config.output_format == OutputFormat.directory else [],
+                ),
             )
-
-        with flock(src):
-            fork_and_wait(copy)
         yield tmp
     finally:
-        def rm() -> None:
-            if config.output_format == OutputFormat.directory:
-                become_root()
-
-            rmtree(tmp, sandbox=config.sandbox)
-
-        fork_and_wait(rm)
+        rmtree(
+            tmp,
+            sandbox=functools.partial(
+                config.sandbox,
+                setup=become_root_cmd() if config.output_format == OutputFormat.directory else [],
+            ),
+        )
 
 
 def qemu_version(config: Config) -> GenericVersion: