]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
ukify: Add UkifyConfig
authorJörg Behrmann <behrmann@physik.fu-berlin.de>
Mon, 14 Oct 2024 07:58:05 +0000 (09:58 +0200)
committerJörg Behrmann <behrmann@physik.fu-berlin.de>
Mon, 14 Oct 2024 07:59:25 +0000 (09:59 +0200)
Using a dataclass instead of an argparse namespace to pass around the parsed
options allows to track the types properly.

src/ukify/ukify.py

index accc977ac239ba2bbd172b5f3d941b6b986f3358..ebb4985cfcb4408eba55053965646b92e99db852 100755 (executable)
@@ -26,6 +26,7 @@ import contextlib
 import dataclasses
 import datetime
 import fnmatch
+import inspect
 import itertools
 import json
 import os
@@ -48,6 +49,7 @@ from typing import (
     IO,
     Any,
     Callable,
+    Literal,
     Optional,
     TypeVar,
     Union,
@@ -230,6 +232,50 @@ def maybe_decompress(filename: Union[str, Path]) -> bytes:
     raise NotImplementedError(f'unknown file format (starts with {start!r})')
 
 
+@dataclasses.dataclass
+class UkifyConfig:
+    all: bool
+    cmdline: Union[str, Path, None]
+    devicetree: Path
+    efi_arch: str
+    initrd: list[Path]
+    join_profiles: list[Path]
+    json: Union[Literal['pretty'], Literal['short'], Literal['off']]
+    linux: Optional[Path]
+    measure: bool
+    microcode: Path
+    os_release: Union[str, Path, None]
+    output: Optional[str]
+    pcr_banks: list[str]
+    pcr_private_keys: list[str]
+    pcr_public_keys: list[Path]
+    pcrpkey: Optional[Path]
+    phase_path_groups: Optional[list[str]]
+    profile: Union[str, Path, None]
+    sb_cert: Path
+    sb_cert_name: Optional[str]
+    sb_cert_validity: int
+    sb_certdir: Path
+    sb_key: Optional[Path]
+    sbat: Optional[list[str]]
+    sections: list['Section']
+    sections_by_name: dict[str, 'Section']
+    sign_kernel: bool
+    signing_engine: Optional[str]
+    signtool: Optional[type['SignTool']]
+    splash: Optional[Path]
+    stub: Path
+    summary: bool
+    tools: list[Path]
+    uname: Optional[str]
+    verb: str
+    files: list[str] = dataclasses.field(default_factory=list)
+
+    @classmethod
+    def from_namespace(cls, ns: argparse.Namespace) -> 'UkifyConfig':
+        return cls(**{k: v for k, v in vars(ns).items() if k in inspect.signature(cls).parameters})
+
+
 class Uname:
     # This class is here purely as a namespace for the functions
 
@@ -243,7 +289,7 @@ class Uname:
     TEXT_PATTERN = rb'Linux version (?P<version>\d\.\S+) \('
 
     @classmethod
-    def scrape_x86(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
+    def scrape_x86(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> str:
         # Based on https://gitlab.archlinux.org/archlinux/mkinitcpio/mkinitcpio/-/blob/master/functions#L136
         # and https://docs.kernel.org/arch/x86/boot.html#the-real-mode-kernel-header
         with open(filename, 'rb') as f:
@@ -263,7 +309,7 @@ class Uname:
         return m.group('version')
 
     @classmethod
-    def scrape_elf(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
+    def scrape_elf(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> str:
         readelf = find_tool('readelf', opts=opts)
 
         cmd = [
@@ -285,7 +331,7 @@ class Uname:
         return text.rstrip('\0')
 
     @classmethod
-    def scrape_generic(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
+    def scrape_generic(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> str:
         # import libarchive
         # libarchive-c fails with
         # ArchiveError: Unrecognized archive format (errno=84, retcode=-30, archive_p=94705420454656)
@@ -299,7 +345,7 @@ class Uname:
         return m.group('version').decode()
 
     @classmethod
-    def scrape(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> Optional[str]:
+    def scrape(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> Optional[str]:
         for func in (cls.scrape_x86, cls.scrape_elf, cls.scrape_generic):
             try:
                 version = func(filename, opts=opts)
@@ -406,17 +452,17 @@ class UKI:
 
 class SignTool:
     @staticmethod
-    def sign(input_f: str, output_f: str, opts: argparse.Namespace) -> None:
+    def sign(input_f: str, output_f: str, opts: UkifyConfig) -> None:
         raise NotImplementedError()
 
     @staticmethod
-    def verify(opts: argparse.Namespace) -> bool:
+    def verify(opts: UkifyConfig) -> bool:
         raise NotImplementedError()
 
 
 class PeSign(SignTool):
     @staticmethod
-    def sign(input_f: str, output_f: str, opts: argparse.Namespace) -> None:
+    def sign(input_f: str, output_f: str, opts: UkifyConfig) -> None:
         assert opts.sb_certdir is not None
         assert opts.sb_cert_name is not None
 
@@ -435,7 +481,7 @@ class PeSign(SignTool):
         subprocess.check_call(cmd)
 
     @staticmethod
-    def verify(opts: argparse.Namespace) -> bool:
+    def verify(opts: UkifyConfig) -> bool:
         assert opts.linux is not None
 
         tool = find_tool('pesign', opts=opts)
@@ -449,7 +495,7 @@ class PeSign(SignTool):
 
 class SbSign(SignTool):
     @staticmethod
-    def sign(input_f: str, output_f: str, opts: argparse.Namespace) -> None:
+    def sign(input_f: str, output_f: str, opts: UkifyConfig) -> None:
         assert opts.sb_key is not None
         assert opts.sb_cert is not None
 
@@ -467,7 +513,7 @@ class SbSign(SignTool):
         subprocess.check_call(cmd)
 
     @staticmethod
-    def verify(opts: argparse.Namespace) -> bool:
+    def verify(opts: UkifyConfig) -> bool:
         assert opts.linux is not None
 
         tool = find_tool('sbverify', opts=opts)
@@ -507,7 +553,7 @@ def parse_phase_paths(s: str) -> list[str]:
     return paths
 
 
-def check_splash(filename: Optional[str]) -> None:
+def check_splash(filename: Optional[Path]) -> None:
     if filename is None:
         return
 
@@ -521,7 +567,7 @@ def check_splash(filename: Optional[str]) -> None:
     print(f'Splash image {filename} is {img.width}×{img.height} pixels')
 
 
-def check_inputs(opts: argparse.Namespace) -> None:
+def check_inputs(opts: UkifyConfig) -> None:
     for name, value in vars(opts).items():
         if name in {'output', 'tools'}:
             continue
@@ -537,9 +583,9 @@ def check_inputs(opts: argparse.Namespace) -> None:
     check_splash(opts.splash)
 
 
-def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None:
+def check_cert_and_keys_nonexistent(opts: UkifyConfig) -> None:
     # Raise if any of the keys and certs are found on disk
-    paths = itertools.chain(
+    paths: Iterator[Union[str, Path, None]] = itertools.chain(
         (opts.sb_key, opts.sb_cert),
         *((priv_key, pub_key) for priv_key, pub_key, _ in key_path_groups(opts)),
     )
@@ -551,14 +597,14 @@ def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None:
 def find_tool(
     name: str,
     fallback: Optional[str] = None,
-    opts: Optional[argparse.Namespace] = None,
+    opts: Optional[UkifyConfig] = None,
     msg: str = 'Tool {name} not installed!',
 ) -> Union[str, Path]:
     if opts and opts.tools:
         for d in opts.tools:
             tool = d / name
             if tool.exists():
-                return cast(Path, tool)
+                return tool
 
     if shutil.which(name) is not None:
         return name
@@ -579,18 +625,19 @@ def combine_signatures(pcrsigs: list[dict[str, str]]) -> str:
     return json.dumps(combined)
 
 
-def key_path_groups(opts: argparse.Namespace) -> Iterator[tuple[str, Optional[Path], Optional[str]]]:
+def key_path_groups(opts: UkifyConfig) -> Iterator[tuple[str, Optional[Path], Optional[str]]]:
     if not opts.pcr_private_keys:
         return
 
     n_priv = len(opts.pcr_private_keys)
-    pub_keys = opts.pcr_public_keys or [None] * n_priv
-    pp_groups = opts.phase_path_groups or [None] * n_priv
+    pub_keys = opts.pcr_public_keys or []
+    pp_groups = opts.phase_path_groups or []
 
-    yield from zip(
+    yield from itertools.zip_longest(
         opts.pcr_private_keys,
-        pub_keys,
-        pp_groups,
+        pub_keys[:n_priv],
+        pp_groups[:n_priv],
+        fillvalue=None,
     )
 
 
@@ -598,7 +645,7 @@ def pe_strip_section_name(name: bytes) -> str:
     return name.rstrip(b'\x00').decode()
 
 
-def call_systemd_measure(uki: UKI, opts: argparse.Namespace, profile_start: int = 0) -> None:
+def call_systemd_measure(uki: UKI, opts: UkifyConfig, profile_start: int = 0) -> None:
     measure_tool = find_tool(
         'systemd-measure',
         '/usr/lib/systemd/systemd-measure',
@@ -884,7 +931,9 @@ uki-addon,1,UKI Addon,addon,1,https://www.freedesktop.org/software/systemd/man/l
 """
 
 
-def make_uki(opts: argparse.Namespace) -> None:
+def make_uki(opts: UkifyConfig) -> None:
+    assert opts.output is not None
+
     # kernel payload signing
 
     sign_args_present = opts.sb_key or opts.sb_cert_name
@@ -892,16 +941,17 @@ def make_uki(opts: argparse.Namespace) -> None:
     linux = opts.linux
 
     if sign_args_present:
+        assert opts.linux is not None
         assert opts.signtool is not None
 
-        if not sign_kernel and opts.linux is not None:
+        if not sign_kernel:
             # figure out if we should sign the kernel
             sign_kernel = opts.signtool.verify(opts)
 
         if sign_kernel:
             linux_signed = tempfile.NamedTemporaryFile(prefix='linux-signed')
             linux = Path(linux_signed.name)
-            opts.signtool.sign(opts.linux, linux, opts=opts)
+            opts.signtool.sign(os.fspath(opts.linux), os.fspath(linux), opts=opts)
 
     if opts.uname is None and opts.linux is not None:
         print('Kernel version not specified, starting autodetection 😖.')
@@ -910,7 +960,7 @@ def make_uki(opts: argparse.Namespace) -> None:
     uki = UKI(opts.stub)
     initrd = join_initrds(opts.initrd)
 
-    pcrpkey = opts.pcrpkey
+    pcrpkey: Union[bytes, Path, None] = opts.pcrpkey
     if pcrpkey is None:
         if opts.pcr_public_keys and len(opts.pcr_public_keys) == 1:
             pcrpkey = opts.pcr_public_keys[0]
@@ -1032,15 +1082,15 @@ def make_uki(opts: argparse.Namespace) -> None:
         if names.count('.profile') > 1:
             raise ValueError(f'Profile PE binary {profile} contains multiple .profile sections')
 
-        for section in pe.sections:
-            n = pe_strip_section_name(section.Name)
+        for pesection in pe.sections:
+            n = pe_strip_section_name(pesection.Name)
 
             if n not in to_import:
                 continue
 
-            print(f"Copying section '{n}' from '{profile}': {section.Misc_VirtualSize} bytes")
+            print(f"Copying section '{n}' from '{profile}': {pesection.Misc_VirtualSize} bytes")
             uki.add_section(
-                Section.create(n, section.get_data(length=section.Misc_VirtualSize), measure=True)
+                Section.create(n, pesection.get_data(length=pesection.Misc_VirtualSize), measure=True)
             )
 
         call_systemd_measure(uki, opts=opts, profile_start=prev_len)
@@ -1059,7 +1109,7 @@ def make_uki(opts: argparse.Namespace) -> None:
 
     if sign_args_present:
         assert opts.signtool is not None
-        opts.signtool.sign(unsigned_output, opts.output, opts)
+        opts.signtool.sign(os.fspath(unsigned_output), os.fspath(opts.output), opts)
 
         # We end up with no executable bits, let's reapply them
         os.umask(umask := os.umask(0))
@@ -1161,7 +1211,7 @@ def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes, bytes]:
     return priv_key_pem, pub_key_pem
 
 
-def generate_keys(opts: argparse.Namespace) -> None:
+def generate_keys(opts: UkifyConfig) -> None:
     work = False
 
     # This will generate keys and certificates and write them to the paths that
@@ -1200,7 +1250,7 @@ def generate_keys(opts: argparse.Namespace) -> None:
 
 
 def inspect_section(
-    opts: argparse.Namespace,
+    opts: UkifyConfig,
     section: pefile.SectionStructure,
 ) -> tuple[str, Optional[dict[str, Union[int, str]]]]:
     name = pe_strip_section_name(section.Name)
@@ -1243,7 +1293,7 @@ def inspect_section(
     return name, struct
 
 
-def inspect_sections(opts: argparse.Namespace) -> None:
+def inspect_sections(opts: UkifyConfig) -> None:
     indent = 4 if opts.json == 'pretty' else None
 
     for file in opts.files:
@@ -1935,7 +1985,7 @@ def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
 
 
 def main() -> None:
-    opts = parse_args()
+    opts = UkifyConfig.from_namespace(parse_args())
     if opts.summary:
         # TODO: replace pprint() with some fancy formatting.
         pprint.pprint(vars(opts))