import sys
import tempfile
import textwrap
-from collections.abc import Sequence
+from collections.abc import Iterable, Iterator, Sequence
from hashlib import sha256
from pathlib import Path
+from types import ModuleType
from typing import (
IO,
Any,
Callable,
Optional,
+ TypeVar,
Union,
+ cast,
)
import pefile # type: ignore
reset = '\033[0m' if sys.stderr.isatty() else ''
-def guess_efi_arch():
+def guess_efi_arch() -> str:
arch = os.uname().machine
for glob, mapping in EFI_ARCH_MAP.items():
print(text)
-def shell_join(cmd):
+def shell_join(cmd: list[Union[str, Path]]) -> str:
# TODO: drop in favour of shlex.join once shlex.join supports Path.
return ' '.join(shlex.quote(str(x)) for x in cmd)
-def round_up(x, blocksize=4096):
+def round_up(x: int, blocksize: int = 4096) -> int:
return (x + blocksize - 1) // blocksize * blocksize
-def try_import(modname, name=None):
+def try_import(modname: str, name: Optional[str] = None) -> ModuleType:
try:
return __import__(modname)
except ImportError as e:
raise ValueError(f'Kernel is compressed with {name or modname}, but module unavailable') from e
-def get_zboot_kernel(f):
+def get_zboot_kernel(f: IO[bytes]) -> bytes:
"""Decompress zboot efistub kernel if compressed. Return contents."""
# See linux/drivers/firmware/efi/libstub/Makefile.zboot
# and linux/drivers/firmware/efi/libstub/zboot-header.S
f.seek(start)
if comp_type.startswith(b'gzip'):
gzip = try_import('gzip')
- return gzip.open(f).read(size)
+ return cast(bytes, gzip.open(f).read(size))
elif comp_type.startswith(b'lz4'):
lz4 = try_import('lz4.frame', 'lz4')
- return lz4.frame.decompress(f.read(size))
+ return cast(bytes, lz4.frame.decompress(f.read(size)))
elif comp_type.startswith(b'lzma'):
lzma = try_import('lzma')
- return lzma.open(f).read(size)
+ return cast(bytes, lzma.open(f).read(size))
elif comp_type.startswith(b'lzo'):
raise NotImplementedError('lzo decompression not implemented')
elif comp_type.startswith(b'xzkern'):
raise NotImplementedError('xzkern decompression not implemented')
elif comp_type.startswith(b'zstd22'):
zstd = try_import('zstd')
- return zstd.uncompress(f.read(size))
- else:
- raise NotImplementedError(f'unknown compressed type: {comp_type}')
+ return cast(bytes, zstd.uncompress(f.read(size)))
+
+ raise NotImplementedError(f'unknown compressed type: {comp_type!r}')
-def maybe_decompress(filename):
+def maybe_decompress(filename: Union[str, Path]) -> bytes:
"""Decompress file if compressed. Return contents."""
f = open(filename, 'rb')
start = f.read(4)
if start.startswith(b'\x1f\x8b'):
gzip = try_import('gzip')
- return gzip.open(f).read()
+ return cast(bytes, gzip.open(f).read())
if start.startswith(b'\x28\xb5\x2f\xfd'):
zstd = try_import('zstd')
- return zstd.uncompress(f.read())
+ return cast(bytes, zstd.uncompress(f.read()))
if start.startswith(b'\x02\x21\x4c\x18'):
lz4 = try_import('lz4.frame', 'lz4')
- return lz4.frame.decompress(f.read())
+ return cast(bytes, lz4.frame.decompress(f.read()))
if start.startswith(b'\x04\x22\x4d\x18'):
print('Newer lz4 stream format detected! This may not boot!')
lz4 = try_import('lz4.frame', 'lz4')
- return lz4.frame.decompress(f.read())
+ return cast(bytes, lz4.frame.decompress(f.read()))
if start.startswith(b'\x89LZO'):
# python3-lzo is not packaged for Fedora
if start.startswith(b'BZh'):
bz2 = try_import('bz2', 'bzip2')
- return bz2.open(f).read()
+ return cast(bytes, bz2.open(f).read())
if start.startswith(b'\x5d\x00\x00'):
lzma = try_import('lzma')
- return lzma.open(f).read()
+ return cast(bytes, lzma.open(f).read())
- raise NotImplementedError(f'unknown file format (starts with {start})')
+ raise NotImplementedError(f'unknown file format (starts with {start!r})')
class Uname:
TEXT_PATTERN = rb'Linux version (?P<version>\d\.\S+) \('
@classmethod
- def scrape_x86(cls, filename, opts=None):
+ def scrape_x86(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
# Based on https://gitlab.archlinux.org/archlinux/mkinitcpio/mkinitcpio/-/blob/master/functions#L136
# and https://docs.kernel.org/arch/x86/boot.html#the-real-mode-kernel-header
with open(filename, 'rb') as f:
f.seek(0x200 + offset)
text = f.read(128)
text = text.split(b'\0', maxsplit=1)[0]
- text = text.decode()
+ decoded = text.decode()
- if not (m := re.match(cls.VERSION_PATTERN, text)):
+ if not (m := re.match(cls.VERSION_PATTERN, decoded)):
raise ValueError(f'Cannot parse version-host-release uname string: {text!r}')
return m.group('version')
@classmethod
- def scrape_elf(cls, filename, opts=None):
+ def scrape_elf(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
readelf = find_tool('readelf', opts=opts)
+ if not readelf:
+ raise ValueError('FIXME')
cmd = [
readelf,
return text.rstrip('\0')
@classmethod
- def scrape_generic(cls, filename, opts=None):
+ def scrape_generic(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> str:
# import libarchive
# libarchive-c fails with
# ArchiveError: Unrecognized archive format (errno=84, retcode=-30, archive_p=94705420454656)
return m.group('version').decode()
@classmethod
- def scrape(cls, filename, opts=None):
+ def scrape(cls, filename: Path, opts: Optional[argparse.Namespace] = None) -> Optional[str]:
for func in (cls.scrape_x86, cls.scrape_elf, cls.scrape_generic):
try:
version = func(filename, opts=opts)
class Section:
name: str
content: Optional[Path]
- tmpfile: Optional[IO] = None
+ tmpfile: Optional[IO[Any]] = None
measure: bool = False
output_mode: Optional[str] = None
virtual_size: Optional[int] = None
@classmethod
- def create(cls, name, contents, **kwargs):
+ def create(cls, name: str, contents: Union[str, bytes, Path, None], **kwargs: Any) -> 'Section':
if isinstance(contents, (str, bytes)):
mode = 'wt' if isinstance(contents, str) else 'wb'
tmp = tempfile.NamedTemporaryFile(mode=mode, prefix=f'tmp{name}')
return cls(name, contents, tmpfile=tmp, **kwargs)
@classmethod
- def parse_input(cls, s):
+ def parse_input(cls, s: str) -> 'Section':
try:
name, contents, *rest = s.split(':')
except ValueError as e:
raise ValueError(f'Cannot parse section spec (extraneous parameters): {s!r}')
if contents.startswith('@'):
- contents = Path(contents[1:])
+ sec = cls.create(name, Path(contents[1:]))
+ else:
+ sec = cls.create(name, contents)
- sec = cls.create(name, contents)
sec.check_name()
return sec
@classmethod
- def parse_output(cls, s):
+ def parse_output(cls, s: str) -> 'Section':
if not (m := re.match(r'([a-zA-Z0-9_.]+):(text|binary)(?:@(.+))?', s)):
raise ValueError(f'Cannot parse section spec: {s!r}')
return cls.create(name, out, output_mode=ttype)
- def check_name(self):
+ def check_name(self) -> None:
# PE section names with more than 8 characters are legal, but our stub does
# not support them.
if not self.name.isascii() or not self.name.isprintable():
executable: list[Union[Path, str]]
sections: list[Section] = dataclasses.field(default_factory=list, init=False)
- def add_section(self, section):
+ def add_section(self, section: Section) -> None:
start = 0
# Start search at last .profile section, if there is one
self.sections += [section]
-def parse_banks(s):
+def parse_banks(s: str) -> list[str]:
banks = re.split(r',|\s+', s)
# TODO: do some sanity checking here
return banks
)
-def parse_phase_paths(s):
+def parse_phase_paths(s: str) -> list[str]:
# Split on commas or whitespace here. Commas might be hard to parse visually.
paths = re.split(r',|\s+', s)
return paths
-def check_splash(filename):
+def check_splash(filename: Optional[str]) -> None:
if filename is None:
return
print(f'Splash image {filename} is {img.width}×{img.height} pixels')
-def check_inputs(opts):
+def check_inputs(opts: argparse.Namespace) -> None:
for name, value in vars(opts).items():
if name in {'output', 'tools'}:
continue
check_splash(opts.splash)
-def check_cert_and_keys_nonexistent(opts):
+def check_cert_and_keys_nonexistent(opts: argparse.Namespace) -> None:
# Raise if any of the keys and certs are found on disk
paths = itertools.chain(
(opts.sb_key, opts.sb_cert),
raise ValueError(f'{path} is present')
-def find_tool(name, fallback=None, opts=None):
+def find_tool(
+ name: str,
+ fallback: Optional[str] = None,
+ opts: Optional[argparse.Namespace] = None,
+) -> Union[str, Path, None]:
if opts and opts.tools:
for d in opts.tools:
tool = d / name
if tool.exists():
- return tool
+ return cast(Path, tool)
if shutil.which(name) is not None:
return name
return fallback
-def combine_signatures(pcrsigs):
- combined = collections.defaultdict(list)
+def combine_signatures(pcrsigs: list[dict[str, str]]) -> str:
+ combined: collections.defaultdict[str, list[str]] = collections.defaultdict(list)
for pcrsig in pcrsigs:
for bank, sigs in pcrsig.items():
for sig in sigs:
return json.dumps(combined)
-def key_path_groups(opts):
+def key_path_groups(opts: argparse.Namespace) -> Iterator:
if not opts.pcr_private_keys:
return
)
-def pe_strip_section_name(name):
+def pe_strip_section_name(name: bytes) -> str:
return name.rstrip(b'\x00').decode()
-def call_systemd_measure(uki, opts, profile_start=0):
+def call_systemd_measure(uki: UKI, opts: argparse.Namespace, profile_start: int = 0) -> None:
measure_tool = find_tool(
'systemd-measure',
'/usr/lib/systemd/systemd-measure',
opts=opts,
)
+ if not measure_tool:
+ raise ValueError('FIXME')
banks = opts.pcr_banks or ()
extra += [f'--public-key={pub_key}']
extra += [f'--phase={phase_path}' for phase_path in group or ()]
- print('+', shell_join(cmd + extra))
- pcrsig = subprocess.check_output(cmd + extra, text=True)
+ print('+', shell_join(cmd + extra)) # type: ignore
+ pcrsig = subprocess.check_output(cmd + extra, text=True) # type: ignore
pcrsig = json.loads(pcrsig)
pcrsigs += [pcrsig]
uki.add_section(Section.create('.pcrsig', combined))
-def join_initrds(initrds):
+def join_initrds(initrds: list[Path]) -> Union[Path, bytes, None]:
if not initrds:
return None
if len(initrds) == 1:
return b''.join(seq)
-def pairwise(iterable):
+T = TypeVar('T')
+
+
+def pairwise(iterable: Iterable[T]) -> Iterator[tuple[T, Optional[T]]]:
a, b = itertools.tee(iterable)
next(b, None)
return zip(a, b)
pass
-def pe_add_sections(uki: UKI, output: str):
+def pe_add_sections(uki: UKI, output: str) -> None:
pe = pefile.PE(uki.executable, fast_load=True)
# Old stubs do not have the symbol/string table stripped, even though image files should not have one.
pe.write(output)
-def merge_sbat(input_pe: [Path], input_text: [str]) -> str:
+def merge_sbat(input_pe: list[Path], input_text: list[str]) -> str:
sbat = []
for f in input_pe:
)
-def signer_sign(cmd):
+def signer_sign(cmd: list[Union[str, Path]]) -> None:
print('+', shell_join(cmd))
subprocess.check_call(cmd)
-def find_sbsign(opts=None):
+def find_sbsign(opts: Optional[argparse.Namespace] = None) -> Union[str, Path, None]:
return find_tool('sbsign', opts=opts)
-def sbsign_sign(sbsign_tool, input_f, output_f, opts=None):
+def sbsign_sign(
+ sbsign_tool: Union[str, Path],
+ input_f: str,
+ output_f: str,
+ opts: argparse.Namespace,
+) -> None:
sign_invocation = [
sbsign_tool,
'--key', opts.sb_key,
signer_sign(sign_invocation)
-def find_pesign(opts=None):
+def find_pesign(opts: Optional[argparse.Namespace] = None) -> Union[str, Path, None]:
return find_tool('pesign', opts=opts)
-def pesign_sign(pesign_tool, input_f, output_f, opts=None):
+def pesign_sign(
+ pesign_tool: Union[str, Path],
+ input_f: str,
+ output_f: str,
+ opts: argparse.Namespace,
+) -> None:
sign_invocation = [
pesign_tool,
'-s',
}
-def verify(tool, opts):
+def verify(tool: dict[str, str], opts: argparse.Namespace) -> bool:
verify_tool = find_tool(tool['name'], opts=opts)
cmd = [
verify_tool,
return tool['output'] in info
-def make_uki(opts):
+def make_uki(opts: argparse.Namespace) -> None:
# kernel payload signing
sign_tool = None
sign_args_present = opts.sb_key or opts.sb_cert_name
sign_kernel = opts.sign_kernel
- sign = None
+ sign: Optional[Callable[[Union[str, Path], str, str, argparse.Namespace], None]] = None
linux = opts.linux
if sign_args_present:
sign_kernel = verify(verify_tool, opts)
if sign_kernel:
+ assert sign is not None
+ assert sign_tool is not None
linux_signed = tempfile.NamedTemporaryFile(prefix='linux-signed')
linux = Path(linux_signed.name)
sign(sign_tool, opts.linux, linux, opts=opts)
# UKI signing
if sign_args_present:
- assert sign
- sign(sign_tool, unsigned_output, opts.output, opts=opts)
+ assert sign is not None
+ assert sign_tool is not None
+ sign(sign_tool, unsigned_output, opts.output, opts)
# We end up with no executable bits, let's reapply them
os.umask(umask := os.umask(0))
@contextlib.contextmanager
-def temporary_umask(mask: int):
+def temporary_umask(mask: int) -> Iterator[None]:
# Drop <mask> bits from umask
old = os.umask(0)
os.umask(old | mask)
common_name: str,
valid_days: int,
keylength: int = 2048,
-) -> tuple[bytes]:
+) -> tuple[bytes, bytes]:
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
return key_pem, cert_pem
-def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes]:
+def generate_priv_pub_key_pair(keylength: int = 2048) -> tuple[bytes, bytes]:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
return priv_key_pem, pub_key_pem
-def generate_keys(opts):
+def generate_keys(opts: argparse.Namespace) -> None:
work = False
# This will generate keys and certificates and write them to the paths that
)
-def inspect_section(opts, section):
+def inspect_section(
+ opts: argparse.Namespace,
+ section: pefile.SectionStructure,
+) -> tuple[str, Optional[dict[str, Union[int, str]]]]:
name = pe_strip_section_name(section.Name)
# find the config for this section in opts and whether to show it
return name, struct
-def inspect_sections(opts):
+def inspect_sections(opts: argparse.Namespace) -> None:
indent = 4 if opts.json == 'pretty' else None
for file in opts.files:
return self.dest
return self._names()[0].lstrip('-').replace('-', '_')
- def add_to(self, parser: argparse.ArgumentParser):
+ def add_to(self, parser: argparse.ArgumentParser) -> None:
kwargs = {
key: val
for key in dataclasses.asdict(self)
args = self._names()
parser.add_argument(*args, **kwargs)
- def apply_config(self, namespace, section, group, key, value) -> None:
+ def apply_config(
+ self,
+ namespace: argparse.Namespace,
+ section: str,
+ group: Optional[str],
+ key: str,
+ value: Any,
+ ) -> None:
assert f'{section}/{key}' == self.config_key
dest = self.argparse_dest()
CONFIGFILE_ITEMS = {item.config_key: item for item in CONFIG_ITEMS if item.config_key}
-def apply_config(namespace, filename=None):
+def apply_config(namespace: argparse.Namespace, filename: Union[str, Path, None] = None) -> None:
if filename is None:
if namespace.config:
# Config set by the user, use that.
strict=False,
)
# Do not make keys lowercase
- cp.optionxform = lambda option: option
+ cp.optionxform = lambda option: option # type: ignore
# The API is not great.
read = cp.read(filename)
print(f'Unknown config setting [{section_name}] {key}=')
-def config_example():
- prev_section = None
+def config_example() -> Iterator[str]:
+ prev_section: Optional[str] = None
for item in CONFIG_ITEMS:
section, key, value = item.config_example()
if section:
parser.exit()
-def create_parser():
+def create_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description='Build and sign Unified Kernel Images',
usage='\n '
item.add_to(p)
# Suppress printing of usage synopsis on errors
- p.error = lambda message: p.exit(2, f'{p.prog}: error: {message}\n')
+ p.error = lambda message: p.exit(2, f'{p.prog}: error: {message}\n') # type: ignore
# Make --help paged
p.add_argument(
return p
-def finalize_options(opts):
+def finalize_options(opts: argparse.Namespace) -> None:
# Figure out which syntax is being used, one of:
# ukify verb --arg --arg --arg
# ukify linux initrd…
sys.exit()
-def parse_args(args=None):
+def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
opts = create_parser().parse_args(args)
apply_config(opts)
finalize_options(opts)
return opts
-def main():
+def main() -> None:
opts = parse_args()
if opts.verb == 'build':
check_inputs(opts)