]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Make ConfigSetting generic
authorSeptatrix <24257556+Septatrix@users.noreply.github.com>
Wed, 20 Nov 2024 20:48:16 +0000 (21:48 +0100)
committerDaan De Meyer <daan.j.demeyer@gmail.com>
Fri, 22 Nov 2024 13:02:06 +0000 (14:02 +0100)
mkosi/config.py

index 0cb6cb336676f8f0aabf2fd64f2a305805d3fe3a..6ed20c4eaded5e8655d83fb89de88090ec7bb5ed 100644 (file)
@@ -28,7 +28,7 @@ import uuid
 from collections.abc import Collection, Iterable, Iterator, Sequence
 from contextlib import AbstractContextManager
 from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union, cast
+from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast
 
 from mkosi.distributions import Distribution, detect_distribution
 from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, Style, die
@@ -48,11 +48,11 @@ from mkosi.util import (
 from mkosi.versioncomp import GenericVersion
 
 T = TypeVar("T")
+SE = TypeVar("SE", bound=StrEnum)
 
-ConfigParseCallback = Callable[[Optional[str], Optional[Any]], Any]
-ConfigMatchCallback = Callable[[str, Any], bool]
-ConfigDefaultCallback = Callable[[argparse.Namespace], Any]
-
+ConfigParseCallback = Callable[[Optional[str], Optional[T]], Optional[T]]
+ConfigMatchCallback = Callable[[str, T], bool]
+ConfigDefaultCallback = Callable[[argparse.Namespace], T]
 
 BUILTIN_CONFIGS = ("mkosi-tools", "mkosi-initrd", "mkosi-vm")
 
@@ -676,7 +676,7 @@ def config_match_build_sources(match: str, value: list[ConfigTree]) -> bool:
     return Path(match.lstrip("/")) in [tree.target for tree in value if tree.target]
 
 
-def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback:
+def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback[list[T]]:
     def config_match_list(match: str, value: list[T]) -> bool:
         return parse(match) in value
 
@@ -687,7 +687,7 @@ def config_parse_string(value: Optional[str], old: Optional[str]) -> Optional[st
     return value or None
 
 
-def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback:
+def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback[str]:
     def config_match_string(match: str, value: str) -> bool:
         if allow_globs:
             return fnmatch.fnmatchcase(value, match)
@@ -906,8 +906,8 @@ def config_default_proxy_url(namespace: argparse.Namespace) -> Optional[str]:
     return None
 
 
-def make_enum_parser(type: type[StrEnum]) -> Callable[[str], StrEnum]:
-    def parse_enum(value: str) -> StrEnum:
+def make_enum_parser(type: type[SE]) -> Callable[[str], SE]:
+    def parse_enum(value: str) -> SE:
         try:
             return type(value)
         except ValueError:
@@ -916,17 +916,15 @@ def make_enum_parser(type: type[StrEnum]) -> Callable[[str], StrEnum]:
     return parse_enum
 
 
-def config_make_enum_parser(type: type[StrEnum]) -> ConfigParseCallback:
-    def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
+def config_make_enum_parser(type: type[SE]) -> ConfigParseCallback[SE]:
+    def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
         return make_enum_parser(type)(value) if value else None
 
     return config_parse_enum
 
 
-def config_make_enum_parser_with_boolean(
-    type: type[StrEnum], *, yes: StrEnum, no: StrEnum
-) -> ConfigParseCallback:
-    def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
+def config_make_enum_parser_with_boolean(type: type[SE], *, yes: SE, no: SE) -> ConfigParseCallback[SE]:
+    def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
         if not value:
             return None
 
@@ -938,8 +936,8 @@ def config_make_enum_parser_with_boolean(
     return config_parse_enum
 
 
-def config_make_enum_matcher(type: type[StrEnum]) -> ConfigMatchCallback:
-    def config_match_enum(match: str, value: StrEnum) -> bool:
+def config_make_enum_matcher(type: type[SE]) -> ConfigMatchCallback[SE]:
+    def config_match_enum(match: str, value: SE) -> bool:
         return make_enum_parser(type)(match) == value
 
     return config_match_enum
@@ -948,11 +946,11 @@ def config_make_enum_matcher(type: type[StrEnum]) -> ConfigMatchCallback:
 def config_make_list_parser(
     *,
     delimiter: Optional[str] = None,
-    parse: Callable[[str], Any] = str,
+    parse: Callable[[str], T] = str,  # type: ignore # see mypy#3737
     unescape: bool = False,
     reset: bool = True,
-) -> ConfigParseCallback:
-    def config_parse_list(value: Optional[str], old: Optional[list[Any]]) -> Optional[list[Any]]:
+) -> ConfigParseCallback[list[T]]:
+    def config_parse_list(value: Optional[str], old: Optional[list[T]]) -> Optional[list[T]]:
         new = old.copy() if old else []
 
         if value is None:
@@ -1010,12 +1008,12 @@ def config_match_version(match: str, value: str) -> bool:
 def config_make_dict_parser(
     *,
     delimiter: Optional[str] = None,
-    parse: Callable[[str], tuple[str, Any]],
+    parse: Callable[[str], tuple[str, str]],
     unescape: bool = False,
     allow_paths: bool = False,
     reset: bool = True,
-) -> ConfigParseCallback:
-    def config_parse_dict(value: Optional[str], old: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
+) -> ConfigParseCallback[dict[str, str]]:
+    def config_parse_dict(value: Optional[str], old: Optional[dict[str, str]]) -> Optional[dict[str, str]]:
         new = old.copy() if old else {}
 
         if value is None:
@@ -1104,7 +1102,7 @@ def config_make_path_parser(
     expandvars: bool = True,
     secret: bool = False,
     constants: Sequence[str] = (),
-) -> ConfigParseCallback:
+) -> ConfigParseCallback[Path]:
     def config_parse_path(value: Optional[str], old: Optional[Path]) -> Optional[Path]:
         if not value:
             return None
@@ -1127,7 +1125,7 @@ def is_valid_filename(s: str) -> bool:
     return not (s == "." or s == ".." or "/" in s)
 
 
-def config_make_filename_parser(hint: str) -> ConfigParseCallback:
+def config_make_filename_parser(hint: str) -> ConfigParseCallback[str]:
     def config_parse_filename(value: Optional[str], old: Optional[str]) -> Optional[str]:
         if not value:
             return None
@@ -1389,7 +1387,8 @@ def config_parse_artifact_output_list(
     if boolean_value is not None:
         return ArtifactOutput.compat_yes() if boolean_value else ArtifactOutput.compat_no()
 
-    list_value = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))(value, old)
+    list_parser = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))
+    list_value = list_parser(value, old)
     return cast(list[ArtifactOutput], list_value)
 
 
@@ -1403,14 +1402,14 @@ class SettingScope(StrEnum):
 
 
 @dataclasses.dataclass(frozen=True)
-class ConfigSetting:
+class ConfigSetting(Generic[T]):
     dest: str
     section: str
-    parse: ConfigParseCallback = config_parse_string
-    match: Optional[ConfigMatchCallback] = None
+    parse: ConfigParseCallback[T] = config_parse_string  # type: ignore # see mypy#3737
+    match: Optional[ConfigMatchCallback[T]] = None
     name: str = ""
-    default: Any = None
-    default_factory: Optional[ConfigDefaultCallback] = None
+    default: Optional[T] = None
+    default_factory: Optional[ConfigDefaultCallback[T]] = None
     default_factory_depends: tuple[str, ...] = tuple()
     paths: tuple[str, ...] = ()
     recursive_paths: tuple[str, ...] = ()
@@ -1422,7 +1421,7 @@ class ConfigSetting:
     # settings for argparse
     short: Optional[str] = None
     long: str = ""
-    choices: Optional[Any] = None
+    choices: Optional[list[str]] = None
     metavar: Optional[str] = None
     nargs: Optional[str] = None
     const: Optional[Any] = None
@@ -1529,7 +1528,7 @@ class PagerHelpAction(argparse._HelpAction):
         parser.exit()
 
 
-def dict_with_capitalised_keys_factory(pairs: Any) -> dict[str, Any]:
+def dict_with_capitalised_keys_factory(pairs: list[tuple[str, T]]) -> dict[str, T]:
     def key_transformer(k: str) -> str:
         if (s := SETTINGS_LOOKUP_BY_DEST.get(k)) is not None:
             return s.name
@@ -1634,11 +1633,14 @@ class UKIProfile:
     cmdline: list[str]
 
 
-def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]) -> Callable[[str], Any]:
+def make_simple_config_parser(
+    settings: Sequence[ConfigSetting[object]],
+    valtype: type[T],
+) -> Callable[[str], T]:
     lookup_by_name = {s.name: s for s in settings}
     lookup_by_dest = {s.dest: s for s in settings}
 
-    def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:
+    def finalize_value(config: argparse.Namespace, setting: ConfigSetting[object]) -> None:
         if hasattr(config, setting.dest):
             return
 
@@ -1654,7 +1656,7 @@ def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]
 
         setattr(config, setting.dest, default)
 
-    def parse_simple_config(value: str) -> Any:
+    def parse_simple_config(value: str) -> T:
         path = parse_path(value)
         config = argparse.Namespace()
 
@@ -1681,7 +1683,9 @@ def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]
         for setting in settings:
             finalize_value(config, setting)
 
-        return type(**{k: v for k, v in vars(config).items() if k in inspect.signature(type).parameters})
+        return valtype(
+            **{k: v for k, v in vars(config).items() if k in inspect.signature(valtype).parameters}
+        )
 
     return parse_simple_config
 
@@ -2160,7 +2164,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
         yield section, "", ""
 
 
-PE_ADDON_SETTINGS = (
+PE_ADDON_SETTINGS: list[ConfigSetting[Any]] = [
     ConfigSetting(
         dest="output",
         section="PEAddon",
@@ -2172,10 +2176,10 @@ PE_ADDON_SETTINGS = (
         section="PEAddon",
         parse=config_make_list_parser(delimiter=" "),
     ),
-)
+]
 
 
-UKI_PROFILE_SETTINGS = (
+UKI_PROFILE_SETTINGS: list[ConfigSetting[Any]] = [
     ConfigSetting(
         dest="profile",
         section="UKIProfile",
@@ -2186,10 +2190,10 @@ UKI_PROFILE_SETTINGS = (
         section="UKIProfile",
         parse=config_make_list_parser(delimiter=" "),
     ),
-)
+]
 
 
-SETTINGS = (
+SETTINGS: list[ConfigSetting[Any]] = [
     # Include section
     ConfigSetting(
         dest="include",
@@ -3665,7 +3669,7 @@ SETTINGS = (
         # arguments.
         help=argparse.SUPPRESS,
     ),
-)
+]
 SETTINGS_LOOKUP_BY_NAME = {name: s for s in SETTINGS for name in [s.name, *s.compat_names]}
 SETTINGS_LOOKUP_BY_DEST = {s.dest: s for s in SETTINGS}
 SETTINGS_LOOKUP_BY_SPECIFIER = {s.specifier: s for s in SETTINGS if s.specifier}
@@ -4065,19 +4069,30 @@ class ParseContext:
             with chdir(path if path.is_dir() else Path.cwd()):
                 self.parse_config_one(path if path.is_file() else Path("."))
 
-    def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
+    def finalize_value(self, setting: ConfigSetting[T]) -> Optional[T]:
         # If a value was specified on the CLI, it always takes priority. If the setting is a collection of
         # values, we merge the value from the CLI with the value from the configuration, making sure that the
         # value from the CLI always takes priority.
-        if (v := getattr(self.cli, setting.dest, None)) is not None:
-            if getattr(self.cli, f"{setting.dest}_was_none", False):
+        if (v := cast(Optional[T], getattr(self.cli, setting.dest, None))) is not None:
+            cfg_value = getattr(self.config, setting.dest, None)
+            # We either have no corresponding value in the config files
+            # or the values was assigned the empty string on the CLI
+            # and should thus be treated as a reset and override of the value from the config file.
+            if cfg_value is None or getattr(self.cli, f"{setting.dest}_was_none", False):
                 return v
-            elif isinstance(v, list):
-                return (getattr(self.config, setting.dest, None) or []) + v
+
+            # The instance asserts are pushed down to help mypy/pylance narrow the types.
+            # Mypy still cannot properly infer that the merged collections conform to T
+            # so we ignore the return-value error for it.
+            if isinstance(v, list):
+                assert isinstance(cfg_value, type(v))
+                return cfg_value + v  # type: ignore[return-value]
             elif isinstance(v, dict):
-                return (getattr(self.config, setting.dest, None) or {}) | v
+                assert isinstance(cfg_value, type(v))
+                return cfg_value | v  # type: ignore[return-value]
             elif isinstance(v, set):
-                return (getattr(self.config, setting.dest, None) or set()) | v
+                assert isinstance(cfg_value, type(v))
+                return cfg_value | v  # type: ignore[return-value]
             else:
                 return v
 
@@ -4088,7 +4103,7 @@ class ParseContext:
         if (
             not hasattr(self.cli, setting.dest)
             and hasattr(self.config, setting.dest)
-            and (v := getattr(self.config, setting.dest)) is not None
+            and (v := cast(Optional[T], getattr(self.config, setting.dest))) is not None
         ):
             return v
 
@@ -4195,7 +4210,7 @@ class ParseContext:
         return match_triggered is not False
 
     def parse_config_one(self, path: Path, parse_profiles: bool = False, parse_local: bool = False) -> bool:
-        s: Optional[ConfigSetting]  # Make mypy happy
+        s: Optional[ConfigSetting[object]]  # Hint to mypy that we might assign None
         extras = path.is_dir()
 
         if path.is_dir():