]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Unify loading the current user info
authorGeorges Discry <georges@discry.be>
Thu, 20 Apr 2023 20:09:54 +0000 (22:09 +0200)
committerGeorges Discry <georges@discry.be>
Fri, 21 Apr 2023 12:31:36 +0000 (14:31 +0200)
There are a few places where mkosi wants to know who is the current user
invoking the script with sudo or pkexec, instead of root. Some
combinations of `$SUDO_UID`, `$PKEXEC_UID` and `$SUDO_USER` are used,
but without consistency. Furthermore, when pkexec is used, several
places depend on `bin/mkosi` setting the correct `$SUDO_`* environment
variables.

Those places now use the `mkosi.backend.current_user()` function.
Furthermore, that function returns an `InvokingUser` wrapping the
complete passwd entry from the `pwd` module instead of just the uid/gid.

mkosi/__init__.py
mkosi/backend.py
mkosi/run.py

index d3d066638d503685f16f1d6394e62ee6139453ed..36f3635bb3f6179a93dd378b96bc6c2bc80bf168 100644 (file)
@@ -33,7 +33,7 @@ from mkosi.backend import (
     MkosiState,
     OutputFormat,
     Verb,
-    current_user_uid_gid,
+    current_user,
     flatten,
     format_rlimit,
     is_dnf_distribution,
@@ -1981,7 +1981,7 @@ def run_shell(config: MkosiConfig) -> None:
         cmdline += ["--"]
         cmdline += config.cmdline
 
-    uid, _ = current_user_uid_gid()
+    uid = current_user().uid
 
     if config.output_format == OutputFormat.directory:
         acl_toggle_remove(config, config.output, uid, allow=False)
@@ -2381,9 +2381,7 @@ def prepend_to_environ_path(paths: Sequence[Path]) -> Iterator[None]:
 
 
 def expand_specifier(s: str) -> str:
-    user = os.getenv("SUDO_USER") or os.getenv("USER")
-    assert user is not None
-    return s.replace("%u", user)
+    return s.replace("%u", current_user().name)
 
 
 def needs_build(config: Union[argparse.Namespace, MkosiConfig]) -> bool:
index 5a70728e4a02e2cb598165af96048316b8115020..fef57f0c3698767a7969bcd7182bc0442f7f5b03 100644 (file)
@@ -471,10 +471,49 @@ def flatten(lists: Iterable[Iterable[T]]) -> list[T]:
     return list(itertools.chain.from_iterable(lists))
 
 
-def current_user_uid_gid() -> tuple[int, int]:
-    uid = int(os.getenv("SUDO_UID") or os.getenv("PKEXEC_UID") or os.getuid())
-    gid = pwd.getpwuid(uid).pw_gid
-    return uid, gid
+@dataclasses.dataclass
+class InvokingUser:
+    _pw: Optional[pwd.struct_passwd] = None
+
+    @classmethod
+    def for_uid(cls, uid: int) -> "InvokingUser":
+        return cls(pwd.getpwuid(uid))
+
+    @property
+    def uid(self) -> int:
+        if self._pw is not None:
+            return self._pw.pw_uid
+        return os.getuid()
+
+    @property
+    def gid(self) -> int:
+        if self._pw is not None:
+            return self._pw.pw_gid
+        return os.getgid()
+
+    @property
+    def name(self) -> str:
+        if self._pw is not None:
+            return self._pw.pw_name
+        return os.getlogin()
+
+    @property
+    def home(self) -> Path:
+        if self._pw is not None:
+            return Path(self._pw.pw_dir)
+        return Path.home()
+
+    def is_running_user(self) -> bool:
+        if self._pw is not None:
+            return self._pw.pw_uid == os.getuid()
+        return True
+
+
+def current_user() -> InvokingUser:
+    uid = os.getenv("SUDO_UID") or os.getenv("PKEXEC_UID")
+    if uid:
+        return InvokingUser.for_uid(int(uid))
+    return InvokingUser()
 
 
 @contextlib.contextmanager
index 90e935d54c5e0a0a54c119bc4e24a4c10173f57c..50871ffd045f19fc54e959ea0f7a7be7e7017196 100644 (file)
@@ -13,7 +13,7 @@ from pathlib import Path
 from types import TracebackType
 from typing import Any, Callable, Mapping, Optional, Sequence, Type, TypeVar
 
-from mkosi.backend import MkosiState, current_user_uid_gid
+from mkosi.backend import MkosiState, current_user
 from mkosi.log import ARG_DEBUG, MkosiPrinter, die
 from mkosi.types import _FILE, CompletedProcess, PathString, Popen
 
@@ -67,7 +67,9 @@ def become_root() -> tuple[int, int]:
     The function returns the UID-GID pair of the invoking user in the namespace (65436, 65436).
     """
     if os.getuid() == 0:
-        return current_user_uid_gid()
+        user = current_user()
+        return user.uid, user.gid
+
     subuid = read_subrange(Path("/etc/subuid"))
     subgid = read_subrange(Path("/etc/subgid"))