From: Aleš Mrázek Date: Fri, 3 Sep 2021 12:11:51 +0000 (+0200) Subject: utils: data parser and validator reimplementation X-Git-Tag: v6.0.0a1~128^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=07f157915a75d720dc00f6b3d7cfd89ded237a36;p=thirdparty%2Fknot-resolver.git utils: data parser and validator reimplementation --- diff --git a/manager/etc/knot-resolver/config.yml b/manager/etc/knot-resolver/config.yml index 89990e6aa..04fdb5569 100644 --- a/manager/etc/knot-resolver/config.yml +++ b/manager/etc/knot-resolver/config.yml @@ -1,14 +1,5 @@ network: interfaces: - listen: 127.0.0.1@5353 - edns_buffer_size: - downstream: 4K -options: - prediction: true -cache: - storage: . - size_max: 100M -logging: - level: 5 server: - instances: 1 + workers: 1 diff --git a/manager/knot_resolver_manager/client/__init__.py b/manager/knot_resolver_manager/client/__init__.py index c89a68640..63a96c553 100644 --- a/manager/knot_resolver_manager/client/__init__.py +++ b/manager/knot_resolver_manager/client/__init__.py @@ -24,7 +24,7 @@ class KnotManagerClient: print(response.text) def set_num_workers(self, n: int): - response = requests.post(self._create_url("/config/server/instances"), data=str(n)) + response = requests.post(self._create_url("/config/server/workers"), data=str(n)) print(response.text) def wait_for_initialization(self, timeout_sec: float = 5, time_step: float = 0.4): diff --git a/manager/knot_resolver_manager/datamodel/__init__.py b/manager/knot_resolver_manager/datamodel/__init__.py index 12061c0b7..f1d14c637 100644 --- a/manager/knot_resolver_manager/datamodel/__init__.py +++ b/manager/knot_resolver_manager/datamodel/__init__.py @@ -1,5 +1,3 @@ -from .config import KresConfig +from .config import KresConfig, KresConfigStrict -__all__ = [ - "KresConfig", -] +__all__ = ["KresConfig", "KresConfigStrict"] diff --git a/manager/knot_resolver_manager/datamodel/cache_config.py b/manager/knot_resolver_manager/datamodel/cache_config.py deleted file mode 100644 index 11620c8ed..000000000 --- a/manager/knot_resolver_manager/datamodel/cache_config.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Optional - -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.datamodel.types import SizeUnits -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - - -@dataclass -class CacheConfig(DataclassParserValidatorMixin): - storage: str = "/var/cache/knot-resolver" - size_max: Optional[str] = None - _size_max_bytes: int = 100 * SizeUnits.mebibyte - - def __post_init__(self): - if self.size_max: - self._size_max_bytes = SizeUnits.parse(self.size_max) - - def get_size_max(self) -> int: - return self._size_max_bytes - - def _validate(self): - pass diff --git a/manager/knot_resolver_manager/datamodel/config.py b/manager/knot_resolver_manager/datamodel/config.py index c606d8694..e73e2f5a8 100644 --- a/manager/knot_resolver_manager/datamodel/config.py +++ b/manager/knot_resolver_manager/datamodel/config.py @@ -1,18 +1,12 @@ import pkgutil -from typing import Optional, Text, Union +from typing import Optional, Text from jinja2 import Environment, Template -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - -from .cache_config import CacheConfig -from .dns64_config import Dns64Config -from .logging_config import LoggingConfig -from .lua_config import LuaConfig -from .network_config import NetworkConfig -from .options_config import OptionsConfig -from .server_config import ServerConfig +from knot_resolver_manager.datamodel.lua_config import Lua, LuaStrict +from knot_resolver_manager.datamodel.network_config import Network, NetworkStrict +from knot_resolver_manager.datamodel.server_config import Server, ServerStrict +from knot_resolver_manager.utils import DataParser, DataValidator def _import_lua_template() -> Template: @@ -26,25 +20,19 @@ def _import_lua_template() -> Template: _LUA_TEMPLATE = _import_lua_template() -@dataclass -class KresConfig(DataclassParserValidatorMixin): - # pylint: disable=too-many-instance-attributes - server: ServerConfig = ServerConfig() - network: NetworkConfig = NetworkConfig() - options: OptionsConfig = OptionsConfig() - cache: CacheConfig = CacheConfig() - # DNS64 is disabled by default - dns64: Union[bool, Dns64Config] = False - logging: LoggingConfig = LoggingConfig() - lua: Optional[LuaConfig] = None - - def __post_init__(self): - # if DNS64 is enabled with defaults - if self.dns64 is True: - self.dns64 = Dns64Config() - - def _validate(self): - pass +class KresConfig(DataParser): + server: Server = Server() + network: Network = Network() + lua: Optional[Lua] = None + + +class KresConfigStrict(DataValidator): + server: ServerStrict + network: NetworkStrict + lua: Optional[LuaStrict] def render_lua(self) -> Text: return _LUA_TEMPLATE.render(cfg=self) + + def _validate(self) -> None: + pass diff --git a/manager/knot_resolver_manager/datamodel/dns64_config.py b/manager/knot_resolver_manager/datamodel/dns64_config.py deleted file mode 100644 index f6eeae85b..000000000 --- a/manager/knot_resolver_manager/datamodel/dns64_config.py +++ /dev/null @@ -1,14 +0,0 @@ -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.exceptions import DataValidationException -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - -from .types import RE_IPV6_PREFIX_96 - - -@dataclass -class Dns64Config(DataclassParserValidatorMixin): - prefix: str = "64:ff9b::" - - def _validate(self): - if not bool(RE_IPV6_PREFIX_96.match(self.prefix)): - raise DataValidationException("'dns64.prefix' must be valid IPv6 /96 prefix") diff --git a/manager/knot_resolver_manager/datamodel/dnssec_config.py b/manager/knot_resolver_manager/datamodel/dnssec_config.py deleted file mode 100644 index 8b308f024..000000000 --- a/manager/knot_resolver_manager/datamodel/dnssec_config.py +++ /dev/null @@ -1,8 +0,0 @@ -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - - -@dataclass -class DnssecConfig(DataclassParserValidatorMixin): - def _validate(self): - pass diff --git a/manager/knot_resolver_manager/datamodel/hints_config.py b/manager/knot_resolver_manager/datamodel/hints_config.py deleted file mode 100644 index c60b6345e..000000000 --- a/manager/knot_resolver_manager/datamodel/hints_config.py +++ /dev/null @@ -1,8 +0,0 @@ -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - - -@dataclass -class StaticHintsConfig(DataclassParserValidatorMixin): - def _validate(self): - pass diff --git a/manager/knot_resolver_manager/datamodel/logging_config.py b/manager/knot_resolver_manager/datamodel/logging_config.py deleted file mode 100644 index 0a7418479..000000000 --- a/manager/knot_resolver_manager/datamodel/logging_config.py +++ /dev/null @@ -1,12 +0,0 @@ -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.exceptions import DataValidationException -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - - -@dataclass -class LoggingConfig(DataclassParserValidatorMixin): - level: int = 3 - - def _validate(self): - if not 0 <= self.level <= 7: - raise DataValidationException("logging 'level' must be in range 0..7") diff --git a/manager/knot_resolver_manager/datamodel/lua_config.py b/manager/knot_resolver_manager/datamodel/lua_config.py index 40ef9c9e9..0eeaeb342 100644 --- a/manager/knot_resolver_manager/datamodel/lua_config.py +++ b/manager/knot_resolver_manager/datamodel/lua_config.py @@ -1,20 +1,21 @@ -from typing import List, Optional +from typing import List, Optional, Union -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.exceptions import DataValidationException -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin +from knot_resolver_manager.utils import DataParser, DataValidator -@dataclass -class LuaConfig(DataclassParserValidatorMixin): - script_list: Optional[List[str]] = None - script: Optional[str] = None +class Lua(DataParser): + script: Optional[Union[List[str], str]] = None + script_file: Optional[str] = None - def __post_init__(self): - # Concatenate array to single string - if self.script_list is not None: - self.script = "\n".join(self.script_list) - def _validate(self): - if self.script is None: - raise DataValidationException("Lua script not specified") +class LuaStrict(DataValidator): + script: Optional[str] + script_file: Optional[str] + + def _script(self, lua: Lua) -> Optional[str]: + if isinstance(lua.script, List): + return "\n".join(lua.script) + return lua.script + + def _validate(self) -> None: + pass diff --git a/manager/knot_resolver_manager/datamodel/lua_template.j2 b/manager/knot_resolver_manager/datamodel/lua_template.j2 index b414df0cc..9e4c95813 100644 --- a/manager/knot_resolver_manager/datamodel/lua_template.j2 +++ b/manager/knot_resolver_manager/datamodel/lua_template.j2 @@ -1,41 +1,21 @@ -{% if cfg.server.hostname %} --- server.hostname -hostname('{{ cfg.server.hostname }}') -{% endif %} - -- network.interfaces {% for item in cfg.network.interfaces %} -net.listen('{{ item.get_address() }}', {{ item.get_port() if item.get_port() else 'nil' }}, { +net.listen('{{ item.address }}', {{ item.port }}, { kind = '{{ item.kind if item.kind != 'dot' else 'tls' }}', freebind = {{ 'true' if item.freebind else 'false'}} }) {% endfor %} --- network.edns-buffer-size -net.bufsize({{ cfg.network.edns_buffer_size.get_downstream() }}, {{ cfg.network.edns_buffer_size.get_upstream() }}) - --- modules -modules = { - 'hints > iterate', -- Load /etc/hosts and allow custom root hints", - 'stats', -- Track internal statistics", -{% if cfg.options.prediction %} - predict = { -- Prefetch expiring/frequent records" - window = {{ cfg.options.prediction.get_window() }}, - period = {{ cfg.options.prediction.period }} - }, -{% endif %} -{% if cfg.dns64 %} - dns64 = '{{ cfg.dns64.prefix }}', -- dns64 +{% if cfg.lua %} +-- lua section +{% if cfg.lua.script_file %} +{% import cfg.lua.script_file as script_file %} +-- lua.script-file +{{ script_file }} {% endif %} -} - --- cache -cache.open({{ cfg.cache.get_size_max() }}, 'lmdb://{{ cfg.cache.storage }}') - --- logging level -verbose({{ 'true' if cfg.logging.level > 3 else 'false'}}) {% if cfg.lua.script %} --- lua +-- lua.script {{ cfg.lua.script }} +{% endif %} {% endif %} \ No newline at end of file diff --git a/manager/knot_resolver_manager/datamodel/network_config.py b/manager/knot_resolver_manager/datamodel/network_config.py index 5e0f5f703..0a0f63434 100644 --- a/manager/knot_resolver_manager/datamodel/network_config.py +++ b/manager/knot_resolver_manager/datamodel/network_config.py @@ -1,75 +1,46 @@ -from typing import List, Optional, Union +from typing import List -from knot_resolver_manager.compat.dataclasses import dataclass, field -from knot_resolver_manager.datamodel.types import SizeUnits -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin +from knot_resolver_manager.utils import DataParser, DataValidator +from knot_resolver_manager.utils.types import LiteralEnum +KindEnum = LiteralEnum["dns", "xdp", "dot", "doh"] -@dataclass -class InterfacesConfig(DataclassParserValidatorMixin): + +class Interface(DataParser): listen: str - kind: str = "dns" + kind: KindEnum = "dns" freebind: bool = False - _address: Optional[str] = None - _port: Optional[int] = None - _kind_port_map = {"dns": 53, "xdp": 53, "dot": 853, "doh": 443} - - def __post_init__(self): - # split 'address@port' - if "@" in self.listen: - address, port = self.listen.split("@", maxsplit=1) - self._address = address - self._port = int(port) - else: - # if port number not specified - self._address = self.listen - # set port number based on 'kind' - self._port = self._kind_port_map.get(self.kind) - - def get_address(self) -> Optional[str]: - return self._address - def get_port(self) -> Optional[int]: - return self._port - - def _validate(self): - pass +class InterfaceStrict(DataValidator): + address: str + port: int + kind: str + freebind: bool -@dataclass -class EdnsBufferSizeConfig(DataclassParserValidatorMixin): - downstream: Optional[str] = None - upstream: Optional[str] = None - _downstream_bytes: int = 1232 - _upstream_bytes: int = 1232 + def _address(self, obj: Interface) -> str: + if "@" in obj.listen: + address = obj.listen.split("@", maxsplit=1)[0] + return address + return obj.listen - def __post_init__(self): - if self.downstream: - self._downstream_bytes = SizeUnits.parse(self.downstream) - if self.upstream: - self._upstream_bytes = SizeUnits.parse(self.upstream) + def _port(self, obj: Interface) -> int: + port_map = {"dns": 53, "xdp": 53, "dot": 853, "doh": 443} + if "@" in obj.listen: + port = obj.listen.split("@", maxsplit=1)[1] + return int(port) + return port_map.get(obj.kind, 0) - def _validate(self): + def _validate(self) -> None: pass - def get_downstream(self) -> int: - return self._downstream_bytes - - def get_upstream(self) -> int: - return self._upstream_bytes +class Network(DataParser): + interfaces: List[Interface] = [Interface({"listen": "127.0.0.1"}), Interface({"listen": "::1", "freebind": True})] -@dataclass -class NetworkConfig(DataclassParserValidatorMixin): - interfaces: List[InterfacesConfig] = field( - default_factory=lambda: [InterfacesConfig(listen="127.0.0.1"), InterfacesConfig(listen="::1", freebind=True)] - ) - edns_buffer_size: Union[str, EdnsBufferSizeConfig] = EdnsBufferSizeConfig() - def __post_init__(self): - if isinstance(self.edns_buffer_size, str): - bufsize = self.edns_buffer_size - self.edns_buffer_size = EdnsBufferSizeConfig(downstream=bufsize, upstream=bufsize) +class NetworkStrict(DataValidator): + interfaces: List[InterfaceStrict] - def _validate(self): + def _validate(self) -> None: pass diff --git a/manager/knot_resolver_manager/datamodel/options_config.py b/manager/knot_resolver_manager/datamodel/options_config.py deleted file mode 100644 index 866b5b07b..000000000 --- a/manager/knot_resolver_manager/datamodel/options_config.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Optional, Union - -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.datamodel.types import TimeUnits -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin - - -@dataclass -class PredictionConfig(DataclassParserValidatorMixin): - window: Optional[str] = None - _window_seconds: int = 15 * TimeUnits.minute - period: int = 24 - - def __post_init__(self): - if self.window: - self._window_seconds = TimeUnits.parse(self.window) - - def get_window(self) -> int: - return self._window_seconds - - def _validate(self): - pass - - -@dataclass -class OptionsConfig(DataclassParserValidatorMixin): - prediction: Union[bool, PredictionConfig] = False - - def __post_init__(self): - if self.prediction is True: - self.prediction = PredictionConfig() - - def _validate(self): - pass diff --git a/manager/knot_resolver_manager/datamodel/server_config.py b/manager/knot_resolver_manager/datamodel/server_config.py index d703cd3b7..fd70f74e2 100644 --- a/manager/knot_resolver_manager/datamodel/server_config.py +++ b/manager/knot_resolver_manager/datamodel/server_config.py @@ -1,12 +1,10 @@ import logging import os -from typing import Optional, Union +from typing import Union from typing_extensions import Literal -from knot_resolver_manager.compat.dataclasses import dataclass -from knot_resolver_manager.exceptions import DataValidationException -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin +from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator logger = logging.getLogger(__name__) @@ -29,33 +27,22 @@ def _cpu_count() -> int: return cpus -@dataclass -class ServerConfig(DataclassParserValidatorMixin): - hostname: Optional[str] = None - instances: Union[Literal["auto"], int, None] = None - _instances: int = 1 +class Server(DataParser): + workers: Union[Literal["auto"], int] = 1 use_cache_gc: bool = True - def __post_init__(self): - if isinstance(self.instances, int): - self._instances = self.instances - elif self.instances == "auto": - self._instances = _cpu_count() - - def get_instances(self) -> int: - # FIXME: this is a hack to make the partial updates working without a second data structure - # this will be unnecessary in near future - if isinstance(self.instances, int): - return self.instances - elif self.instances == "auto": - cpu_count = os.cpu_count() - if cpu_count is not None: - return cpu_count - else: - raise RuntimeError("cannot find number of system available CPUs") - else: - return 0 - - def _validate(self): - if not 0 < self._instances <= 256: - raise DataValidationException("number of kresd instances must be in range 1..256") + +class ServerStrict(DataValidator): + workers: int + use_cache_gc: bool + + def _workers(self, obj: Server) -> int: + if isinstance(obj.workers, int): + return obj.workers + elif obj.workers == "auto": + return _cpu_count() + raise DataValidationException(f"Unexpected value: {obj.workers}") + + def _validate(self) -> None: + if self.workers < 0: + raise DataValidationException("Number of workers must be non-negative") diff --git a/manager/knot_resolver_manager/datamodel/types.py b/manager/knot_resolver_manager/datamodel/types.py index 3ea96df13..954ccb5ec 100644 --- a/manager/knot_resolver_manager/datamodel/types.py +++ b/manager/knot_resolver_manager/datamodel/types.py @@ -1,41 +1,60 @@ import re +from typing import Any, Dict, Optional, Pattern, Union + +from knot_resolver_manager.utils import CustomValueType, DataValidationException + + +class Unit(CustomValueType): + _re: Pattern[str] + _units: Dict[Optional[str], int] + + def __init__(self, source_value: Any) -> None: + super().__init__(source_value) + self._value: int + self._value_orig: Union[str, int] + if isinstance(source_value, str) and type(self)._re.match(source_value): + self._value_orig = source_value + grouped = type(self)._re.search(source_value) + if grouped: + val, unit = grouped.groups() + if unit not in type(self)._units: + raise DataValidationException(f"Used unexpected unit '{unit}' for {type(self).__name__}...") + self._value = int(val) * type(self)._units[unit] + else: + raise DataValidationException(f"{type(self._value)} Failed to convert: {self}") + elif isinstance(source_value, int): + if source_value < 0: + raise DataValidationException(f"Input value '{source_value}' is not non-negative.") + self._value_orig = source_value + self._value = source_value + else: + raise DataValidationException( + f"Unexpected input type for Unit type - {type(source_value)}." + " Cause might be invalid format or invalid type." + ) + + def __int__(self) -> int: + return self._value + + def __str__(self) -> str: + """ + Used by Jinja2. Must return only a number. + """ + return str(self._value) + + def __eq__(self, o: object) -> bool: + """ + Two instances are equal when they represent the same size + regardless of their string representation. + """ + return isinstance(o, Unit) and o._value == self._value + + +class SizeUnit(Unit): + _re = re.compile(r"^([0-9]+)\s{0,1}([BKMG]){0,1}$") + _units = {None: 1, "B": 1, "K": 1024, "M": 1024 ** 2, "G": 1024 ** 3} -from knot_resolver_manager.exceptions import DataValidationException - -RE_IPV6_PREFIX_96 = re.compile(r"^([0-9A-Fa-f]{1,4}:){2}:$") - - -class TimeUnits: - second = 1 - minute = 60 - hour = 3600 - day = 24 * 3600 +class TimeUnit(Unit): _re = re.compile(r"^(\d+)\s{0,1}([smhd]){0,1}$") - _map = {"s": second, "m": minute, "h": hour, "d": day} - - @staticmethod - def parse(time_str: str) -> int: - searched = TimeUnits._re.search(time_str) - if searched: - value, unit = searched.groups() - return int(value) * TimeUnits._map.get(unit, 1) - raise DataValidationException(f"failed to parse: {time_str}") - - -class SizeUnits: - byte = 1 - kibibyte = 1024 - mebibyte = 1024 ** 2 - gibibyte = 1024 ** 3 - - _re = re.compile(r"^([0-9]+)\s{0,1}([BKMG]){0,1}$") - _map = {"B": byte, "K": kibibyte, "M": mebibyte, "G": gibibyte} - - @staticmethod - def parse(size_str: str) -> int: - searched = SizeUnits._re.search(size_str) - if searched: - value, unit = searched.groups() - return int(value) * SizeUnits._map.get(unit, 1) - raise DataValidationException(f"failed to parse: {size_str}") + _units = {None: 1, "s": 1, "m": 60, "h": 3600, "d": 24 * 3600} diff --git a/manager/knot_resolver_manager/exceptions.py b/manager/knot_resolver_manager/exceptions.py index 5b1789681..192cd7739 100644 --- a/manager/knot_resolver_manager/exceptions.py +++ b/manager/knot_resolver_manager/exceptions.py @@ -1,14 +1,2 @@ class SubprocessControllerException(Exception): pass - - -class ValidationException(Exception): - pass - - -class SchemaValidationException(ValidationException): - pass - - -class DataValidationException(ValidationException): - pass diff --git a/manager/knot_resolver_manager/kres_manager.py b/manager/knot_resolver_manager/kres_manager.py index ef661e3cf..5acddc959 100644 --- a/manager/knot_resolver_manager/kres_manager.py +++ b/manager/knot_resolver_manager/kres_manager.py @@ -9,16 +9,16 @@ import knot_resolver_manager.kresd_controller from knot_resolver_manager import kres_id from knot_resolver_manager.compat.asyncio import create_task from knot_resolver_manager.constants import KRESD_CONFIG_FILE, WATCHDOG_INTERVAL -from knot_resolver_manager.exceptions import ValidationException from knot_resolver_manager.kresd_controller.interface import ( Subprocess, SubprocessController, SubprocessStatus, SubprocessType, ) +from knot_resolver_manager.utils import DataValidationException from knot_resolver_manager.utils.async_utils import writefile -from .datamodel import KresConfig +from .datamodel import KresConfig, KresConfigStrict logger = logging.getLogger(__name__) @@ -85,6 +85,7 @@ class KresManager: self._manager_lock = asyncio.Lock() self._controller: SubprocessController self._last_used_config: Optional[KresConfig] = None + self._last_used_config_strict: Optional[KresConfigStrict] = None self._watchdog_task: Optional["Future[None]"] = None async def load_system_state(self): @@ -140,28 +141,32 @@ class KresManager: await self._gc.stop() self._gc = None - async def _write_config(self, config: KresConfig): - lua_config = config.render_lua() + async def _write_config(self, config_strict: KresConfigStrict): + lua_config = config_strict.render_lua() await writefile(KRESD_CONFIG_FILE, lua_config) async def apply_config(self, config: KresConfig): async with self._manager_lock: + logger.debug("Validating configuration...") + config_strict = KresConfigStrict(config) + logger.debug("Writing new config to file...") - await self._write_config(config) + await self._write_config(config_strict) logger.debug("Testing the new config with a canary process") try: await self._spawn_new_worker() except SubprocessError: logger.error("kresd with the new config failed to start, rejecting config") - last = self.get_last_used_config() + last = self.get_last_used_config_strict() if last is not None: await self._write_config(last) - raise ValidationException("Canary kresd instance failed. Config is invalid.") + raise DataValidationException("Canary kresd instance failed. Config is invalid.") logger.debug("Canary process test passed, Applying new config to all workers") self._last_used_config = config - await self._ensure_number_of_children(config.server.get_instances()) + self._last_used_config_strict = config_strict + await self._ensure_number_of_children(config_strict.server.workers) await self._rolling_restart() if self._is_gc_running() != config.server.use_cache_gc: @@ -183,6 +188,9 @@ class KresManager: def get_last_used_config(self) -> Optional[KresConfig]: return self._last_used_config + def get_last_used_config_strict(self) -> Optional[KresConfigStrict]: + return self._last_used_config_strict + async def _instability_handler(self) -> None: logger.error("Instability callback invoked. No idea how to react, performing suicide. See you later!") sys.exit(1) diff --git a/manager/knot_resolver_manager/server.py b/manager/knot_resolver_manager/server.py index 5f6699867..8603fe67e 100644 --- a/manager/knot_resolver_manager/server.py +++ b/manager/knot_resolver_manager/server.py @@ -11,11 +11,10 @@ from aiohttp.web import middleware from aiohttp.web_response import json_response from knot_resolver_manager.constants import MANAGER_CONFIG_FILE -from knot_resolver_manager.exceptions import ValidationException from knot_resolver_manager.kresd_controller import get_controller_by_name from knot_resolver_manager.kresd_controller.interface import SubprocessController +from knot_resolver_manager.utils import DataValidationException, Format from knot_resolver_manager.utils.async_utils import readfile -from knot_resolver_manager.utils.dataclasses_parservalidator import Format from .datamodel import KresConfig from .kres_manager import KresManager @@ -83,7 +82,7 @@ async def error_handler(request: web.Request, handler: Any): try: return await handler(request) - except ValidationException as e: + except DataValidationException as e: logger.error("Failed to parse given data in API request", exc_info=True) return web.Response(text=f"Data validation failed: {e}", status=HTTPStatus.BAD_REQUEST) diff --git a/manager/knot_resolver_manager/utils/__init__.py b/manager/knot_resolver_manager/utils/__init__.py index c9efbff74..1149f96bc 100644 --- a/manager/knot_resolver_manager/utils/__init__.py +++ b/manager/knot_resolver_manager/utils/__init__.py @@ -1,6 +1,8 @@ from typing import Any, Callable, Iterable, Optional, Type, TypeVar -from .dataclasses_parservalidator import DataclassParserValidatorMixin +from .custom_types import CustomValueType +from .data_parser_validator import DataParser, DataValidator, Format +from .exceptions import DataParsingException, DataValidationException from .overload import Overloaded T = TypeVar("T") @@ -55,6 +57,11 @@ def contains_element_matching(cond: Callable[[T], bool], arr: Iterable[T]) -> bo __all__ = [ "ignore_exceptions_optional", "ignore_exceptions", - "DataclassParserValidatorMixin", + "Format", + "CustomValueType", + "DataParser", + "DataValidator", + "DataParsingException", + "DataValidationException", "Overloaded", ] diff --git a/manager/knot_resolver_manager/utils/custom_types.py b/manager/knot_resolver_manager/utils/custom_types.py new file mode 100644 index 000000000..f8836bdb8 --- /dev/null +++ b/manager/knot_resolver_manager/utils/custom_types.py @@ -0,0 +1,31 @@ +from typing import Any + + +class CustomValueType: + """ + Subclasses of this class can be used as type annotations in 'DataParser'. When a value + is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by + calling the constructor of the appropriate type on the field value. The only limitation + is that the value MUST NOT be `None`. + + Example: + ``` + class A(DataParser): + field: MyCustomValueType + + A.from_json('{"field": "value"}') == A(field=MyCustomValueType("value")) + ``` + + There is no validation done on the wrapped value. The only condition is that + it can't be `None`. If you want to perform any validation during creation, + raise a `DataValidationException` in case of errors. + """ + + def __init__(self, source_value: Any) -> None: + pass + + def __int__(self) -> int: + raise NotImplementedError("CustomValueType return 'int()' value is not implemented.") + + def __str__(self) -> str: + raise NotImplementedError("CustomValueType return 'str()' value is not implemented.") diff --git a/manager/knot_resolver_manager/utils/dataclasses_parservalidator.py b/manager/knot_resolver_manager/utils/data_parser_validator.py similarity index 50% rename from manager/knot_resolver_manager/utils/dataclasses_parservalidator.py rename to manager/knot_resolver_manager/utils/data_parser_validator.py index bb02918ca..72213eff8 100644 --- a/manager/knot_resolver_manager/utils/dataclasses_parservalidator.py +++ b/manager/knot_resolver_manager/utils/data_parser_validator.py @@ -1,4 +1,5 @@ import copy +import inspect import json import re from enum import Enum, auto @@ -8,7 +9,8 @@ import yaml from yaml.constructor import ConstructorError from yaml.nodes import MappingNode -from knot_resolver_manager.exceptions import SchemaValidationException +from knot_resolver_manager.utils.custom_types import CustomValueType +from knot_resolver_manager.utils.exceptions import DataParsingException from knot_resolver_manager.utils.types import ( get_attr_type, get_generic_type_argument, @@ -21,13 +23,45 @@ from knot_resolver_manager.utils.types import ( is_union, ) -from ..compat.dataclasses import is_dataclass +def is_internal_field(field_name: str) -> bool: + return field_name.startswith("_") + + +def is_obj_type(obj: Any, types: Union[type, Tuple[Any, ...], Tuple[type, ...]]) -> bool: + # To check specific type we are using 'type()' instead of 'isinstance()' + # because for example 'bool' is instance of 'int', 'isinstance(False, int)' returns True. + # pylint: disable=unidiomatic-typecheck + if isinstance(types, Tuple): + return type(obj) in types + return type(obj) == types + + +def _to_primitive(obj: Any) -> Any: + """ + Convert our custom values into primitive variants for dumping. + """ + + # CustomValueType instances + if isinstance(obj, CustomValueType): + return str(obj) + + # nested DataParser class instances + elif isinstance(obj, DataParser): + return obj.to_dict() + + # otherwise just return, what we were given + else: + return obj + + +def _validated_object_type(cls: Type[Any], obj: Any, default: Any = ..., use_default: bool = False) -> Any: + """ + Given an expected type `cls` and a value object `obj`, validate the type of `obj` and return it + """ -def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> Any: # Disabling these checks, because I think it's much more readable as a single function - # and it's not that large at this point. If it got larger, then we should definitely split - # it + # and it's not that large at this point. If it got larger, then we should definitely split it # pylint: disable=too-many-branches,too-many-locals,too-many-statements # default values @@ -39,57 +73,57 @@ def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> A if obj is None: return None else: - raise SchemaValidationException(f"Expected None, found {obj}") + raise DataParsingException(f"Expected None, found '{obj}'.") # Union[*variants] (handles Optional[T] due to the way the typing system works) elif is_union(cls): variants = get_generic_type_arguments(cls) for v in variants: try: - return _from_dictlike_obj(v, obj, ..., False) - except SchemaValidationException: + return _validated_object_type(v, obj) + except DataParsingException: pass - raise SchemaValidationException(f"Union {cls} could not be parsed - parsing of all variants failed") + raise DataParsingException(f"Union {cls} could not be parsed - parsing of all variants failed.") # after this, there is no place for a None object elif obj is None: - raise SchemaValidationException(f"Unexpected None value for type {cls}") + raise DataParsingException(f"Unexpected None value for type {cls}") # int elif cls == int: # we don't want to make an int out of anything else than other int - if isinstance(obj, int): + # except for CustomValueType class instances + if is_obj_type(obj, int) or isinstance(obj, CustomValueType): return int(obj) - else: - raise SchemaValidationException(f"Expected int, found {type(obj)}") + raise DataParsingException(f"Expected int, found {type(obj)}") # str elif cls == str: # we are willing to cast any primitive value to string, but no compound values are allowed - if isinstance(obj, (str, float, int)): + if is_obj_type(obj, (str, float, int)) or isinstance(obj, CustomValueType): return str(obj) - elif isinstance(obj, bool): - raise SchemaValidationException( + elif is_obj_type(obj, bool): + raise DataParsingException( "Expected str, found bool. Be careful, that YAML parsers consider even" ' "no" and "yes" as a bool. Search for the Norway Problem for more' " details. And please use quotes explicitly." ) else: - raise SchemaValidationException( + raise DataParsingException( f"Expected str (or number that would be cast to string), but found type {type(obj)}" ) # bool elif cls == bool: - if isinstance(obj, bool): + if is_obj_type(obj, bool): return obj else: - raise SchemaValidationException(f"Expected bool, found {type(obj)}") + raise DataParsingException(f"Expected bool, found {type(obj)}") # float elif cls == float: raise NotImplementedError( - "Floating point values are not supported in the parser validator." + "Floating point values are not supported in the parser." " Please implement them and be careful with type coercions" ) @@ -99,51 +133,56 @@ def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> A if obj == expected: return obj else: - raise SchemaValidationException(f"Literal {cls} is not matched with the value {obj}") + raise DataParsingException(f"Literal {cls} is not matched with the value {obj}") # Dict[K,V] elif is_dict(cls): key_type, val_type = get_generic_type_arguments(cls) try: return { - _from_dictlike_obj(key_type, key, ..., False): _from_dictlike_obj(val_type, val, ..., False) - for key, val in obj.items() + _validated_object_type(key_type, key): _validated_object_type(val_type, val) for key, val in obj.items() } except AttributeError as e: - raise SchemaValidationException( + raise DataParsingException( f"Expected dict-like object, but failed to access its .items() method. Value was {obj}", e ) # List[T] elif is_list(cls): inner_type = get_generic_type_argument(cls) - return [_from_dictlike_obj(inner_type, val, ..., False) for val in obj] + return [_validated_object_type(inner_type, val) for val in obj] # Tuple[A,B,C,D,...] elif is_tuple(cls): types = get_generic_type_arguments(cls) - return tuple(_from_dictlike_obj(typ, val, ..., False) for typ, val in zip(types, obj)) - - # nested dataclass - elif is_dataclass(cls): - anot = cls.__dict__.get("__annotations__", {}) - kwargs = {} - for name, python_type in anot.items(): - # skip internal fields - if name.startswith("_"): - continue - - value = obj[name] if name in obj else None - use_default = hasattr(cls, name) - default = getattr(cls, name, ...) - kwargs[name] = _from_dictlike_obj(python_type, value, default, use_default) - return cls(**kwargs) + return tuple(_validated_object_type(typ, val) for typ, val in zip(types, obj)) + + # CustomValueType subclasses + elif inspect.isclass(cls) and issubclass(cls, CustomValueType): + # no validation performed, the implementation does it in the constuctor + return cls(obj) + + # nested DataParser subclasses + elif inspect.isclass(cls) and issubclass(cls, DataParser): + # we should return DataParser, we expect to be given a dict, + # because we can construct a DataParser from it + if isinstance(obj, dict): + return cls(obj) # type: ignore + raise DataParsingException(f"Expected '{dict}' object, found '{type(obj)}'") + + # nested DataValidator subclasses + elif inspect.isclass(cls) and issubclass(cls, DataValidator): + # we should return DataValidator, we expect to be given a DataParser, + # because we can construct a DataValidator from it + if isinstance(obj, DataParser): + return cls(obj) + raise DataParsingException(f"Expected instance of '{DataParser}' class, found '{type(obj)}'") # default error handler else: - raise SchemaValidationException( + raise DataParsingException( f"Type {cls} cannot be parsed. This is a implementation error. " - "Please fix your types in the dataclass or improve the parser/validator." + "Please fix your types in the class or improve the parser/validator." ) @@ -153,7 +192,7 @@ def json_raise_duplicates(pairs: List[Tuple[Any, Any]]) -> Optional[Any]: dict_out: Dict[Any, Any] = {} for key, val in pairs: if key in dict_out: - raise SchemaValidationException(f"duplicate key detected: {key}") + raise DataParsingException(f"Duplicate attribute key detected: {key}") dict_out[key] = val return dict_out @@ -180,18 +219,12 @@ class RaiseDuplicatesLoader(yaml.SafeLoader): # check for duplicate keys if key in mapping: - raise SchemaValidationException(f"duplicate key detected: {key_node.start_mark}") + raise DataParsingException(f"duplicate key detected: {key_node.start_mark}") value = self.construct_object(value_node, deep=deep) # type: ignore mapping[key] = value return mapping -_T = TypeVar("_T", bound="DataclassParserValidatorMixin") - - -_SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$") - - class Format(Enum): YAML = auto() JSON = auto() @@ -206,6 +239,14 @@ class Format(Enum): else: raise NotImplementedError(f"Parsing of format '{self}' is not implemented") + def dict_dump(self, data: Dict[str, Any]) -> str: + if self is Format.YAML: + return yaml.safe_dump(data) # type: ignore + elif self is Format.JSON: + return json.dumps(data) + else: + raise NotImplementedError(f"Exporting to '{self}' format is not implemented") + @staticmethod def from_mime_type(mime_type: str) -> "Format": formats = { @@ -214,41 +255,47 @@ class Format(Enum): "text/vnd.yaml": Format.YAML, } if mime_type not in formats: - raise SchemaValidationException("Unsupported MIME type") + raise DataParsingException("Unsupported MIME type") return formats[mime_type] -class DataclassParserValidatorMixin: - def __init__(self, *args: Any, **kwargs: Any): - """ - This constructor is useless except for typechecking. It makes sure that the dataclasses can be created with - any arguments whatsoever. - """ +_T = TypeVar("_T", bound="DataParser") - def validate(self) -> None: - for field_name in dir(self): - # skip internal fields - if field_name.startswith("_"): + +_SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$") + + +class DataParser: + def __init__(self, obj: Optional[Dict[str, Any]] = None): + cls = self.__class__ + annot = cls.__dict__.get("__annotations__", {}) + + used_keys: List[str] = [] + for name, python_type in annot.items(): + if is_internal_field(name): continue - field = getattr(self, field_name) - if is_dataclass(field): - if not isinstance(field, DataclassParserValidatorMixin): - raise SchemaValidationException( - f"Nested dataclass in the field {field_name} does not include the ParserValidatorMixin" - ) - field.validate() + val = None + dash_name = name.replace("_", "-") + if obj and dash_name in obj: + val = obj[dash_name] + used_keys.append(dash_name) - self._validate() + use_default = hasattr(cls, name) + default = getattr(cls, name, ...) + value = _validated_object_type(python_type, val, default, use_default) + setattr(self, name, value) - def _validate(self) -> None: - raise NotImplementedError(f"Validation function is not implemented in class {type(self).__name__}") + # check for unused keys + if obj: + for key in obj: + if key not in used_keys: + raise DataParsingException(f"Unknown attribute key '{key}'.") @classmethod def parse_from(cls: Type[_T], fmt: Format, text: str): - data = fmt.parse_to_dict(text) - config: _T = _from_dictlike_obj(cls, data, ..., False) - config.validate() + data_dict = fmt.parse_to_dict(text) + config: _T = cls(data_dict) return config @classmethod @@ -259,13 +306,36 @@ class DataclassParserValidatorMixin: def from_json(cls: Type[_T], text: str) -> _T: return cls.parse_from(Format.JSON, text) + def to_dict(self) -> Dict[str, Any]: + cls = self.__class__ + anot = cls.__dict__.get("__annotations__", {}) + dict_obj: Dict[str, Any] = {} + for name in anot: + if is_internal_field(name): + continue + + value = getattr(self, name) + dash_name = str(name).replace("_", "-") + dict_obj[dash_name] = _to_primitive(value) + return dict_obj + + def dump(self, fmt: Format) -> str: + dict_data = self.to_dict() + return fmt.dict_dump(dict_data) + + def dump_to_yaml(self) -> str: + return self.dump(Format.YAML) + + def dump_to_json(self) -> str: + return self.dump(Format.JSON) + def copy_with_changed_subtree(self: _T, fmt: Format, path: str, text: str) -> _T: cls = self.__class__ # prepare and validate the path object path = path[:-1] if path.endswith("/") else path if re.match(_SUBTREE_MUTATION_PATH_PATTERN, path) is None: - raise SchemaValidationException("Provided object path for mutation is invalid.") + raise DataParsingException("Provided object path for mutation is invalid.") path = path[1:] if path.startswith("/") else path # now, the path variable should contain '/' separated field names @@ -278,29 +348,63 @@ class DataclassParserValidatorMixin: to_mutate = copy.deepcopy(self) obj = to_mutate parent = None - for segment in path.split("/"): + + for dash_segment in path.split("/"): + segment = dash_segment.replace("-", "_") + if segment == "": - raise SchemaValidationException(f"Unexpectedly empty segment in path '{path}'") - elif segment.startswith("_"): - raise SchemaValidationException( - "No, changing internal fields (starting with _) is not allowed. Nice try." - ) + raise DataParsingException(f"Unexpectedly empty segment in path '{path}'") + elif is_internal_field(segment): + raise DataParsingException("No, changing internal fields (starting with _) is not allowed. Nice try.") elif hasattr(obj, segment): parent = obj obj = getattr(parent, segment) else: - raise SchemaValidationException( - f"Path segment '{segment}' does not match any field on the provided parent object" + raise DataParsingException( + f"Path segment '{dash_segment}' does not match any field on the provided parent object" ) assert parent is not None # assign the subtree - last_name = path.split("/")[-1] + last_name = path.split("/")[-1].replace("-", "_") data = fmt.parse_to_dict(text) tp = get_attr_type(parent, last_name) - parsed_value = _from_dictlike_obj(tp, data, ..., False) + parsed_value = _validated_object_type(tp, data) setattr(parent, last_name, parsed_value) - to_mutate.validate() - return to_mutate + + +class DataValidator: + def __init__(self, obj: DataParser): + cls = self.__class__ + anot = cls.__dict__.get("__annotations__", {}) + + for attr_name, attr_type in anot.items(): + if is_internal_field(attr_name): + continue + + # use transformation function if available + if hasattr(self, f"_{attr_name}"): + value = getattr(self, f"_{attr_name}")(obj) + elif hasattr(obj, attr_name): + value = getattr(obj, attr_name) + else: + raise DataParsingException(f"DataParser object {obj} is missing '{attr_name}' attribute.") + + setattr(self, attr_name, _validated_object_type(attr_type, value)) + + self._validate() + + def validate(self) -> None: + for field_name in dir(self): + if is_internal_field(field_name): + continue + + field = getattr(self, field_name) + if isinstance(field, DataValidator): + field.validate() + self._validate() + + def _validate(self) -> None: + raise NotImplementedError(f"Validation function is not implemented in class {type(self).__name__}") diff --git a/manager/knot_resolver_manager/utils/exceptions.py b/manager/knot_resolver_manager/utils/exceptions.py new file mode 100644 index 000000000..f6e81509e --- /dev/null +++ b/manager/knot_resolver_manager/utils/exceptions.py @@ -0,0 +1,6 @@ +class DataParsingException(Exception): + pass + + +class DataValidationException(Exception): + pass diff --git a/manager/knot_resolver_manager/utils/types.py b/manager/knot_resolver_manager/utils/types.py index fb2246708..ca19ac241 100644 --- a/manager/knot_resolver_manager/utils/types.py +++ b/manager/knot_resolver_manager/utils/types.py @@ -30,12 +30,15 @@ def is_union(tp: Any) -> bool: def is_literal(tp: Any) -> bool: - return getattr(tp, "__origin__", None) == Literal + return isinstance(tp, type(Literal)) def get_generic_type_arguments(tp: Any) -> List[Any]: default: List[Any] = [] - return getattr(tp, "__args__", default) + if is_literal(tp): + return getattr(tp, "__values__") + else: + return getattr(tp, "__args__", default) def get_generic_type_argument(tp: Any) -> Any: diff --git a/manager/tests/datamodel/test_datamodel_types.py b/manager/tests/datamodel/test_datamodel_types.py new file mode 100644 index 000000000..e5321f4f5 --- /dev/null +++ b/manager/tests/datamodel/test_datamodel_types.py @@ -0,0 +1,66 @@ +from pytest import raises + +from knot_resolver_manager.datamodel.types import SizeUnit, TimeUnit +from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator + + +def test_size_unit(): + assert ( + SizeUnit(5368709120) + == SizeUnit("5368709120") + == SizeUnit("5368709120B") + == SizeUnit("5242880K") + == SizeUnit("5120M") + == SizeUnit("5G") + ) + + with raises(DataValidationException): + SizeUnit("-5368709120") + with raises(DataValidationException): + SizeUnit(-5368709120) + with raises(DataValidationException): + SizeUnit("5120MM") + + +def test_time_unit(): + assert TimeUnit("1d") == TimeUnit("24h") == TimeUnit("1440m") == TimeUnit("86400s") == TimeUnit(86400) + + with raises(DataValidationException): + TimeUnit("-1") + with raises(DataValidationException): + TimeUnit(-24) + with raises(DataValidationException): + TimeUnit("1440mm") + + +def test_parsing_units(): + class TestClass(DataParser): + size: SizeUnit + time: TimeUnit + + class TestClassStrict(DataValidator): + size: int + time: int + + def _validate(self) -> None: + pass + + yaml = """ +size: 3K +time: 10m +""" + + obj = TestClass.from_yaml(yaml) + assert obj.size == SizeUnit(3 * 1024) + assert obj.time == TimeUnit(10 * 60) + + strict = TestClassStrict(obj) + assert strict.size == 3 * 1024 + assert strict.time == 10 * 60 + + y = obj.dump_to_yaml() + j = obj.dump_to_json() + a = TestClass.from_yaml(y) + b = TestClass.from_json(j) + assert a.size == b.size == obj.size + assert a.time == b.time == obj.time diff --git a/manager/tests/test_datamodel.py b/manager/tests/test_datamodel.py deleted file mode 100644 index a1af7dc08..000000000 --- a/manager/tests/test_datamodel.py +++ /dev/null @@ -1,36 +0,0 @@ -from knot_resolver_manager.datamodel import KresConfig - - -def test_simple(): - json = """ - { - "server": { - "instances": 1 - }, - "lua": { - "script_list": [ - "-- SPDX-License-Identifier: CC0-1.0", - "-- vim:syntax=lua:set ts=4 sw=4:", - "-- Refer to manual: https://knot-resolver.readthedocs.org/en/stable/", - "-- Network interface configuration","net.listen('127.0.0.1', 53, { kind = 'dns' })", - "net.listen('127.0.0.1', 853, { kind = 'tls' })", - "--net.listen('127.0.0.1', 443, { kind = 'doh2' })", - "net.listen('::1', 53, { kind = 'dns', freebind = true })", - "net.listen('::1', 853, { kind = 'tls', freebind = true })", - "--net.listen('::1', 443, { kind = 'doh2' })", - "-- Load useful modules","modules = {", - "'hints > iterate', -- Load /etc/hosts and allow custom root hints", - "'stats', -- Track internal statistics", - "'predict', -- Prefetch expiring/frequent records", - "}", - "-- Cache size", - "cache.size = 100 * MB" - ] - } - } - """ - - config = KresConfig.from_json(json) - - assert config.server.instances == 1 - assert config.lua.script is not None diff --git a/manager/tests/utils/test_data_parser_validator.py b/manager/tests/utils/test_data_parser_validator.py new file mode 100644 index 000000000..5431f3db7 --- /dev/null +++ b/manager/tests/utils/test_data_parser_validator.py @@ -0,0 +1,292 @@ +from typing import Dict, List, Optional, Tuple, Union + +from pytest import raises +from typing_extensions import Literal + +from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator, Format +from knot_resolver_manager.utils.exceptions import DataParsingException + + +def test_primitive(): + class TestClass(DataParser): + i: int + s: str + b: bool + + class TestClassStrict(DataValidator): + i: int + s: str + b: bool + + def _validate(self) -> None: + pass + + yaml = """ +i: 5 +s: test +b: false +""" + + obj = TestClass.from_yaml(yaml) + assert obj.i == 5 + assert obj.s == "test" + assert obj.b == False + + strict = TestClassStrict(obj) + assert strict.i == 5 + assert strict.s == "test" + assert strict.b == False + + y = obj.dump_to_yaml() + j = obj.dump_to_json() + a = TestClass.from_yaml(y) + b = TestClass.from_json(j) + assert a.i == b.i == obj.i + assert a.s == b.s == obj.s + assert a.b == b.b == obj.b + + +def test_parsing_primitive_exceptions(): + class TestStr(DataParser): + s: str + + # int and float are allowed inputs for string + with raises(DataParsingException): + TestStr.from_yaml("s: false") # bool + + class TestInt(DataParser): + i: int + + with raises(DataParsingException): + TestInt.from_yaml("i: false") # bool + with raises(DataParsingException): + TestInt.from_yaml('i: "5"') # str + with raises(DataParsingException): + TestInt.from_yaml("i: 5.5") # float + + class TestBool(DataParser): + b: bool + + with raises(DataParsingException): + TestBool.from_yaml("b: 5") # int + with raises(DataParsingException): + TestBool.from_yaml('b: "5"') # str + with raises(DataParsingException): + TestBool.from_yaml("b: 5.5") # float + + +def test_nested(): + class Lower(DataParser): + i: int + + class Upper(DataParser): + l: Lower + + class LowerStrict(DataValidator): + i: int + + def _validate(self) -> None: + pass + + class UpperStrict(DataValidator): + l: LowerStrict + + def _validate(self) -> None: + pass + + yaml = """ +l: + i: 5 +""" + + obj = Upper.from_yaml(yaml) + assert obj.l.i == 5 + + strict = UpperStrict(obj) + assert strict.l.i == 5 + + y = obj.dump_to_yaml() + j = obj.dump_to_json() + a = Upper.from_yaml(y) + b = Upper.from_json(j) + assert a.l.i == b.l.i == obj.l.i + + +def test_simple_compount_types(): + class TestClass(DataParser): + l: List[int] + d: Dict[str, str] + t: Tuple[str, int] + o: Optional[int] + + class TestClassStrict(DataValidator): + l: List[int] + d: Dict[str, str] + t: Tuple[str, int] + o: Optional[int] + + def _validate(self) -> None: + pass + + yaml = """ +l: + - 1 + - 2 + - 3 + - 4 + - 5 +d: + something: else + w: all +t: + - test + - 5 +""" + + obj = TestClass.from_yaml(yaml) + assert obj.l == [1, 2, 3, 4, 5] + assert obj.d == {"something": "else", "w": "all"} + assert obj.t == ("test", 5) + assert obj.o is None + + strict = TestClassStrict(obj) + assert strict.l == [1, 2, 3, 4, 5] + assert strict.d == {"something": "else", "w": "all"} + assert strict.t == ("test", 5) + assert strict.o is None + + y = obj.dump_to_yaml() + j = obj.dump_to_json() + a = TestClass.from_yaml(y) + b = TestClass.from_json(j) + assert a.l == b.l == obj.l + assert a.d == b.d == obj.d + assert a.t == b.t == obj.t + assert a.o == b.o == obj.o + + +def test_nested_compound_types(): + class TestClass(DataParser): + o: Optional[Dict[str, str]] + + class TestClassStrict(DataValidator): + o: Optional[Dict[str, str]] + + def _validate(self) -> None: + pass + + yaml = """ +o: + key: val +""" + + obj = TestClass.from_yaml(yaml) + assert obj.o == {"key": "val"} + + strict = TestClassStrict(obj) + assert strict.o == {"key": "val"} + + y = obj.dump_to_yaml() + j = obj.dump_to_json() + a = TestClass.from_yaml(y) + b = TestClass.from_json(j) + assert a.o == b.o == obj.o + + +def test_nested_compount_types2(): + class TestClass(DataParser): + i: int + o: Optional[Dict[str, str]] + + class TestClassStrict(DataValidator): + i: int + o: Optional[Dict[str, str]] + + def _validate(self) -> None: + pass + + yaml = "i: 5" + + obj = TestClass.from_yaml(yaml) + assert obj.i == 5 + assert obj.o is None + + strict = TestClassStrict(obj) + assert strict.i == 5 + assert strict.o is None + + y = obj.dump_to_yaml() + j = obj.dump_to_json() + a = TestClass.from_yaml(y) + b = TestClass.from_json(j) + assert a.i == b.i == obj.i + assert a.o == b.o == obj.o + + +def test_partial_mutations(): + class Inner(DataParser): + size: int = 5 + + class ConfData(DataParser): + workers: Union[Literal["auto"], int] = 1 + lua_config: Optional[str] = None + inner: Inner = Inner() + + class InnerStrict(DataValidator): + size: int + + def _validate(self) -> None: + pass + + class ConfDataStrict(DataValidator): + workers: int + lua_config: Optional[str] + inner: InnerStrict + + def _workers(self, data: ConfData) -> int: + if data.workers == "auto": + return 8 + else: + return data.workers + + def _validate(self) -> None: + if self.workers < 0: + raise DataValidationException("Number of workers must be non-negative") + + yaml = """ + workers: auto + lua-config: something + """ + + conf = ConfData.from_yaml(yaml) + + x = ConfDataStrict(conf) + assert x.lua_config == "something" + assert x.inner.size == 5 + assert x.workers == 8 + + y = conf.dump_to_yaml() + j = conf.dump_to_json() + a = ConfData.from_yaml(y) + b = ConfData.from_json(j) + assert a.workers == b.workers == conf.workers + assert a.lua_config == b.lua_config == conf.lua_config + assert a.inner.size == b.inner.size == conf.inner.size + + # replacement of 'lua-config' attribute + x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/lua-config", '"new_value"')) + assert x.lua_config == "new_value" + assert x.inner.size == 5 + assert x.workers == 8 + + # replacement of the whole tree + x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/", '{"inner": {"size": 55}}')) + assert x.lua_config is None + assert x.workers == 1 + assert x.inner.size == 55 + + # replacement of 'inner' subtree + x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/inner", '{"size": 33}')) + assert x.lua_config == "something" + assert x.workers == 8 + assert x.inner.size == 33 diff --git a/manager/tests/utils/test_dataclasses_parservalidator.py b/manager/tests/utils/test_dataclasses_parservalidator.py deleted file mode 100644 index 277133087..000000000 --- a/manager/tests/utils/test_dataclasses_parservalidator.py +++ /dev/null @@ -1,170 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple - -from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin, Format - - -def test_parsing_primitive(): - @dataclass - class TestClass(DataclassParserValidatorMixin): - i: int - s: str - - def _validate(self): - pass - - yaml = """i: 5 -s: "test" -""" - - obj = TestClass.from_yaml(yaml) - - assert obj.i == 5 - assert obj.s == "test" - - -def test_parsing_nested(): - @dataclass - class Lower(DataclassParserValidatorMixin): - i: int - - def _validate(self): - pass - - @dataclass - class Upper(DataclassParserValidatorMixin): - l: Lower - - def _validate(self): - pass - - yaml = """l: - i: 5""" - - obj = Upper.from_yaml(yaml) - assert obj.l.i == 5 - - -def test_simple_compount_types(): - @dataclass - class TestClass(DataclassParserValidatorMixin): - l: List[int] - d: Dict[str, str] - t: Tuple[str, int] - o: Optional[int] - - def _validate(self): - pass - - yaml = """l: - - 1 - - 2 - - 3 - - 4 - - 5 -d: - something: else - w: all -t: - - test - - 5""" - - obj = TestClass.from_yaml(yaml) - - assert obj.l == [1, 2, 3, 4, 5] - assert obj.d == {"something": "else", "w": "all"} - assert obj.t == ("test", 5) - assert obj.o is None - - -def test_nested_compount_types(): - @dataclass - class TestClass(DataclassParserValidatorMixin): - o: Optional[Dict[str, str]] - - def _validate(self): - pass - - yaml = """o: - key: val""" - - obj = TestClass.from_yaml(yaml) - - assert obj.o == {"key": "val"} - - -def test_nested_compount_types2(): - @dataclass - class TestClass(DataclassParserValidatorMixin): - i: int - o: Optional[Dict[str, str]] - - def _validate(self): - pass - - yaml = "i: 5" - - obj = TestClass.from_yaml(yaml) - - assert obj.o is None - - -def test_real_failing_dummy_confdata(): - @dataclass - class ConfData(DataclassParserValidatorMixin): - num_workers: int = 1 - lua_config: Optional[str] = None - - def _validate(self): - if self.num_workers < 0: - raise Exception("Number of workers must be non-negative") - - # prepare the payload - lua_config = "dummy" - config = f""" -num_workers: 4 -lua_config: | - { lua_config }""" - - data = ConfData.from_yaml(config) - - assert type(data.num_workers) == int - assert data.num_workers == 4 - assert type(data.lua_config) == str - assert data.lua_config == "dummy" - - -def test_partial_mutations(): - @dataclass - class Inner(DataclassParserValidatorMixin): - number: int - - def _validate(self): - pass - - @dataclass - class ConfData(DataclassParserValidatorMixin): - num_workers: int = 1 - lua_config: Optional[str] = None - inner: Inner = Inner(5) - - def _validate(self): - if self.num_workers < 0: - raise Exception("Number of workers must be non-negative") - - data = ConfData(5, "something", Inner(10)) - - x = data.copy_with_changed_subtree(Format.JSON, "/lua_config", '"new_value"') - assert x.lua_config == "new_value" - assert x.num_workers == 5 - assert x.inner.number == 10 - - x = data.copy_with_changed_subtree(Format.JSON, "/inner", '{"number": 55}') - assert x.lua_config == "something" - assert x.num_workers == 5 - assert x.inner.number == 55 - - x = data.copy_with_changed_subtree(Format.JSON, "/", '{"inner": {"number": 55}}') - assert x.lua_config is None - assert x.num_workers == 1 - assert x.inner.number == 55 diff --git a/manager/tests/utils/test_types.py b/manager/tests/utils/test_types.py index 824d4ea5b..580154a25 100644 --- a/manager/tests/utils/test_types.py +++ b/manager/tests/utils/test_types.py @@ -2,7 +2,7 @@ from typing import List, Union from typing_extensions import Literal -from knot_resolver_manager.utils.types import LiteralEnum, is_list +from knot_resolver_manager.utils.types import LiteralEnum, is_list, is_literal def test_is_list(): @@ -10,6 +10,11 @@ def test_is_list(): assert is_list(List[int]) +def test_is_literal(): + assert is_literal(Literal[5]) + assert is_literal(Literal["test"]) + + def test_literal_enum(): assert LiteralEnum[5, "test"] == Union[Literal[5], Literal["test"]] assert LiteralEnum["str", 5] == Union[Literal["str"], Literal[5]]