From f1b6430ecf947897cbe802abb0598d6423e4841a Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=B6rg=20Behrmann?= Date: Mon, 14 Oct 2024 09:58:05 +0200 Subject: [PATCH] ukify: Add UkifyConfig Using a dataclass instead of an argparse namespace to pass around the parsed options allows to track the types properly. --- src/ukify/ukify.py | 122 ++++++++++++++++++++++++++++++++------------- 1 file changed, 86 insertions(+), 36 deletions(-) diff --git a/src/ukify/ukify.py b/src/ukify/ukify.py index accc977ac23..ebb4985cfcb 100755 --- a/src/ukify/ukify.py +++ b/src/ukify/ukify.py @@ -26,6 +26,7 @@ import contextlib import dataclasses import datetime import fnmatch +import inspect import itertools import json import os @@ -48,6 +49,7 @@ from typing import ( IO, Any, Callable, + Literal, Optional, TypeVar, Union, @@ -230,6 +232,50 @@ def maybe_decompress(filename: Union[str, Path]) -> bytes: raise NotImplementedError(f'unknown file format (starts with {start!r})') +@dataclasses.dataclass +class UkifyConfig: + all: bool + cmdline: Union[str, Path, None] + devicetree: Path + efi_arch: str + initrd: list[Path] + join_profiles: list[Path] + json: Union[Literal['pretty'], Literal['short'], Literal['off']] + linux: Optional[Path] + measure: bool + microcode: Path + os_release: Union[str, Path, None] + output: Optional[str] + pcr_banks: list[str] + pcr_private_keys: list[str] + pcr_public_keys: list[Path] + pcrpkey: Optional[Path] + phase_path_groups: Optional[list[str]] + profile: Union[str, Path, None] + sb_cert: Path + sb_cert_name: Optional[str] + sb_cert_validity: int + sb_certdir: Path + sb_key: Optional[Path] + sbat: Optional[list[str]] + sections: list['Section'] + sections_by_name: dict[str, 'Section'] + sign_kernel: bool + signing_engine: Optional[str] + signtool: Optional[type['SignTool']] + splash: Optional[Path] + stub: Path + summary: bool + tools: list[Path] + uname: Optional[str] + verb: str + files: list[str] = dataclasses.field(default_factory=list) + + @classmethod + def from_namespace(cls, ns: argparse.Namespace) -> 'UkifyConfig': + return cls(**{k: v for k, v in vars(ns).items() if k in inspect.signature(cls).parameters}) + + class Uname: # This class is here purely as a namespace for the functions @@ -243,7 +289,7 @@ class Uname: TEXT_PATTERN = rb'Linux version (?P\d\.\S+) \(' @classmethod - def scrape_x86(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str: + def scrape_x86(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> str: # Based on https://gitlab.archlinux.org/archlinux/mkinitcpio/mkinitcpio/-/blob/master/functions#L136 # and https://docs.kernel.org/arch/x86/boot.html#the-real-mode-kernel-header with open(filename, 'rb') as f: @@ -263,7 +309,7 @@ class Uname: return m.group('version') @classmethod - def scrape_elf(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str: + def scrape_elf(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> str: readelf = find_tool('readelf', opts=opts) cmd = [ @@ -285,7 +331,7 @@ class Uname: return text.rstrip('\0') @classmethod - def scrape_generic(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str: + def scrape_generic(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> str: # import libarchive # libarchive-c fails with # ArchiveError: Unrecognized archive format (errno=84, retcode=-30, archive_p=94705420454656) @@ -299,7 +345,7 @@ class Uname: return m.group('version').decode() @classmethod - def scrape(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> Optional[str]: + def scrape(cls, filename: Path, opts: Optional[UkifyConfig] = None) -> Optional[str]: for func in (cls.scrape_x86, cls.scrape_elf, cls.scrape_generic): try: version = func(filename, opts=opts) @@ -406,17 +452,17 @@ class UKI: class SignTool: @staticmethod - def sign(input_f: str, output_f: str, opts: argparse.Namespace) -> None: + def sign(input_f: str, output_f: str, opts: UkifyConfig) -> None: raise NotImplementedError() @staticmethod - def verify(opts: argparse.Namespace) -> bool: + def verify(opts: UkifyConfig) -> bool: raise NotImplementedError() class PeSign(SignTool): @staticmethod - def sign(input_f: str, output_f: str, opts: argparse.Namespace) -> None: + def sign(input_f: str, output_f: str, opts: UkifyConfig) -> None: assert opts.sb_certdir is not None assert opts.sb_cert_name is not None @@ -435,7 +481,7 @@ class PeSign(SignTool): subprocess.check_call(cmd) @staticmethod - def verify(opts: argparse.Namespace) -> bool: + def verify(opts: UkifyConfig) -> bool: assert opts.linux is not None tool = find_tool('pesign', opts=opts) @@ -449,7 +495,7 @@ class PeSign(SignTool): class SbSign(SignTool): @staticmethod - def sign(input_f: str, output_f: str, opts: argparse.Namespace) -> None: + def sign(input_f: str, output_f: str, opts: UkifyConfig) -> None: assert opts.sb_key is not None assert opts.sb_cert is not None @@ -467,7 +513,7 @@ class SbSign(SignTool): subprocess.check_call(cmd) @staticmethod - def verify(opts: argparse.Namespace) -> bool: + def verify(opts: UkifyConfig) -> bool: assert opts.linux is not None tool = find_tool('sbverify', opts=opts) @@ -507,7 +553,7 @@ def parse_phase_paths(s: str) -> list[str]: return paths -def check_splash(filename: Optional[str]) -> None: +def check_splash(filename: Optional[Path]) -> None: if filename is None: return @@ -521,7 +567,7 @@ def check_splash(filename: Optional[str]) -> None: print(f'Splash image {filename} is {img.width}×{img.height} pixels') -def check_inputs(opts: argparse.Namespace) -> None: +def check_inputs(opts: UkifyConfig) -> None: for name, value in vars(opts).items(): if name in {'output', 'tools'}: continue @@ -537,9 +583,9 @@ def check_inputs(opts: argparse.Namespace) -> None: check_splash(opts.splash) -def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None: +def check_cert_and_keys_nonexistent(opts: UkifyConfig) -> None: # Raise if any of the keys and certs are found on disk - paths = itertools.chain( + paths: Iterator[Union[str, Path, None]] = itertools.chain( (opts.sb_key, opts.sb_cert), *((priv_key, pub_key) for priv_key, pub_key, _ in key_path_groups(opts)), ) @@ -551,14 +597,14 @@ def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None: def find_tool( name: str, fallback: Optional[str] = None, - opts: Optional[argparse.Namespace] = None, + opts: Optional[UkifyConfig] = None, msg: str = 'Tool {name} not installed!', ) -> Union[str, Path]: if opts and opts.tools: for d in opts.tools: tool = d / name if tool.exists(): - return cast(Path, tool) + return tool if shutil.which(name) is not None: return name @@ -579,18 +625,19 @@ def combine_signatures(pcrsigs: list[dict[str, str]]) -> str: return json.dumps(combined) -def key_path_groups(opts: argparse.Namespace) -> Iterator[tuple[str, Optional[Path], Optional[str]]]: +def key_path_groups(opts: UkifyConfig) -> Iterator[tuple[str, Optional[Path], Optional[str]]]: if not opts.pcr_private_keys: return n_priv = len(opts.pcr_private_keys) - pub_keys = opts.pcr_public_keys or [None] * n_priv - pp_groups = opts.phase_path_groups or [None] * n_priv + pub_keys = opts.pcr_public_keys or [] + pp_groups = opts.phase_path_groups or [] - yield from zip( + yield from itertools.zip_longest( opts.pcr_private_keys, - pub_keys, - pp_groups, + pub_keys[:n_priv], + pp_groups[:n_priv], + fillvalue=None, ) @@ -598,7 +645,7 @@ def pe_strip_section_name(name: bytes) -> str: return name.rstrip(b'\x00').decode() -def call_systemd_measure(uki: UKI, opts: argparse.Namespace, profile_start: int = 0) -> None: +def call_systemd_measure(uki: UKI, opts: UkifyConfig, profile_start: int = 0) -> None: measure_tool = find_tool( 'systemd-measure', '/usr/lib/systemd/systemd-measure', @@ -884,7 +931,9 @@ uki-addon,1,UKI Addon,addon,1,https://www.freedesktop.org/software/systemd/man/l """ -def make_uki(opts: argparse.Namespace) -> None: +def make_uki(opts: UkifyConfig) -> None: + assert opts.output is not None + # kernel payload signing sign_args_present = opts.sb_key or opts.sb_cert_name @@ -892,16 +941,17 @@ def make_uki(opts: argparse.Namespace) -> None: linux = opts.linux if sign_args_present: + assert opts.linux is not None assert opts.signtool is not None - if not sign_kernel and opts.linux is not None: + if not sign_kernel: # figure out if we should sign the kernel sign_kernel = opts.signtool.verify(opts) if sign_kernel: linux_signed = tempfile.NamedTemporaryFile(prefix='linux-signed') linux = Path(linux_signed.name) - opts.signtool.sign(opts.linux, linux, opts=opts) + opts.signtool.sign(os.fspath(opts.linux), os.fspath(linux), opts=opts) if opts.uname is None and opts.linux is not None: print('Kernel version not specified, starting autodetection 😖.') @@ -910,7 +960,7 @@ def make_uki(opts: argparse.Namespace) -> None: uki = UKI(opts.stub) initrd = join_initrds(opts.initrd) - pcrpkey = opts.pcrpkey + pcrpkey: Union[bytes, Path, None] = opts.pcrpkey if pcrpkey is None: if opts.pcr_public_keys and len(opts.pcr_public_keys) == 1: pcrpkey = opts.pcr_public_keys[0] @@ -1032,15 +1082,15 @@ def make_uki(opts: argparse.Namespace) -> None: if names.count('.profile') > 1: raise ValueError(f'Profile PE binary {profile} contains multiple .profile sections') - for section in pe.sections: - n = pe_strip_section_name(section.Name) + for pesection in pe.sections: + n = pe_strip_section_name(pesection.Name) if n not in to_import: continue - print(f"Copying section '{n}' from '{profile}': {section.Misc_VirtualSize} bytes") + print(f"Copying section '{n}' from '{profile}': {pesection.Misc_VirtualSize} bytes") uki.add_section( - Section.create(n, section.get_data(length=section.Misc_VirtualSize), measure=True) + Section.create(n, pesection.get_data(length=pesection.Misc_VirtualSize), measure=True) ) call_systemd_measure(uki, opts=opts, profile_start=prev_len) @@ -1059,7 +1109,7 @@ def make_uki(opts: argparse.Namespace) -> None: if sign_args_present: assert opts.signtool is not None - opts.signtool.sign(unsigned_output, opts.output, opts) + opts.signtool.sign(os.fspath(unsigned_output), os.fspath(opts.output), opts) # We end up with no executable bits, let's reapply them os.umask(umask := os.umask(0)) @@ -1161,7 +1211,7 @@ def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes, bytes]: return priv_key_pem, pub_key_pem -def generate_keys(opts: argparse.Namespace) -> None: +def generate_keys(opts: UkifyConfig) -> None: work = False # This will generate keys and certificates and write them to the paths that @@ -1200,7 +1250,7 @@ def generate_keys(opts: argparse.Namespace) -> None: def inspect_section( - opts: argparse.Namespace, + opts: UkifyConfig, section: pefile.SectionStructure, ) -> tuple[str, Optional[dict[str, Union[int, str]]]]: name = pe_strip_section_name(section.Name) @@ -1243,7 +1293,7 @@ def inspect_section( return name, struct -def inspect_sections(opts: argparse.Namespace) -> None: +def inspect_sections(opts: UkifyConfig) -> None: indent = 4 if opts.json == 'pretty' else None for file in opts.files: @@ -1935,7 +1985,7 @@ def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace: def main() -> None: - opts = parse_args() + opts = UkifyConfig.from_namespace(parse_args()) if opts.summary: # TODO: replace pprint() with some fancy formatting. pprint.pprint(vars(opts)) -- 2.47.3