]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
utils: data parser and validator reimplementation
authorAleš Mrázek <ales.mrazek@nic.cz>
Fri, 3 Sep 2021 12:11:51 +0000 (14:11 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:52 +0000 (16:17 +0200)
28 files changed:
manager/etc/knot-resolver/config.yml
manager/knot_resolver_manager/client/__init__.py
manager/knot_resolver_manager/datamodel/__init__.py
manager/knot_resolver_manager/datamodel/cache_config.py [deleted file]
manager/knot_resolver_manager/datamodel/config.py
manager/knot_resolver_manager/datamodel/dns64_config.py [deleted file]
manager/knot_resolver_manager/datamodel/dnssec_config.py [deleted file]
manager/knot_resolver_manager/datamodel/hints_config.py [deleted file]
manager/knot_resolver_manager/datamodel/logging_config.py [deleted file]
manager/knot_resolver_manager/datamodel/lua_config.py
manager/knot_resolver_manager/datamodel/lua_template.j2
manager/knot_resolver_manager/datamodel/network_config.py
manager/knot_resolver_manager/datamodel/options_config.py [deleted file]
manager/knot_resolver_manager/datamodel/server_config.py
manager/knot_resolver_manager/datamodel/types.py
manager/knot_resolver_manager/exceptions.py
manager/knot_resolver_manager/kres_manager.py
manager/knot_resolver_manager/server.py
manager/knot_resolver_manager/utils/__init__.py
manager/knot_resolver_manager/utils/custom_types.py [new file with mode: 0644]
manager/knot_resolver_manager/utils/data_parser_validator.py [moved from manager/knot_resolver_manager/utils/dataclasses_parservalidator.py with 50% similarity]
manager/knot_resolver_manager/utils/exceptions.py [new file with mode: 0644]
manager/knot_resolver_manager/utils/types.py
manager/tests/datamodel/test_datamodel_types.py [new file with mode: 0644]
manager/tests/test_datamodel.py [deleted file]
manager/tests/utils/test_data_parser_validator.py [new file with mode: 0644]
manager/tests/utils/test_dataclasses_parservalidator.py [deleted file]
manager/tests/utils/test_types.py

index 89990e6aada913cbd4a2d414643261d29dbd245f..04fdb5569e65eda0128d751bd86886888ccb9471 100644 (file)
@@ -1,14 +1,5 @@
 network:
     interfaces:
       - listen: 127.0.0.1@5353
-    edns_buffer_size:
-        downstream: 4K
-options:
-    prediction: true
-cache:
-    storage: .
-    size_max: 100M
-logging:
-    level: 5
 server:
-    instances: 1
+    workers: 1
index c89a68640906f0dce80ac309a14b07ba7b2cb88f..63a96c55375ae5e4371fb04b8b3df3e7903bdd0d 100644 (file)
@@ -24,7 +24,7 @@ class KnotManagerClient:
         print(response.text)
 
     def set_num_workers(self, n: int):
-        response = requests.post(self._create_url("/config/server/instances"), data=str(n))
+        response = requests.post(self._create_url("/config/server/workers"), data=str(n))
         print(response.text)
 
     def wait_for_initialization(self, timeout_sec: float = 5, time_step: float = 0.4):
index 12061c0b74e9da788a8272764947e89a2058cd6e..f1d14c637a422b4fb75189ec3f7d0cb24ac00c8a 100644 (file)
@@ -1,5 +1,3 @@
-from .config import KresConfig
+from .config import KresConfig, KresConfigStrict
 
-__all__ = [
-    "KresConfig",
-]
+__all__ = ["KresConfig", "KresConfigStrict"]
diff --git a/manager/knot_resolver_manager/datamodel/cache_config.py b/manager/knot_resolver_manager/datamodel/cache_config.py
deleted file mode 100644 (file)
index 11620c8..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-from typing import Optional
-
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.datamodel.types import SizeUnits
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class CacheConfig(DataclassParserValidatorMixin):
-    storage: str = "/var/cache/knot-resolver"
-    size_max: Optional[str] = None
-    _size_max_bytes: int = 100 * SizeUnits.mebibyte
-
-    def __post_init__(self):
-        if self.size_max:
-            self._size_max_bytes = SizeUnits.parse(self.size_max)
-
-    def get_size_max(self) -> int:
-        return self._size_max_bytes
-
-    def _validate(self):
-        pass
index c606d86948289da7ae058339cebe3caa65677de0..e73e2f5a8f973b92b0a1fd0d2eb59d5469a46791 100644 (file)
@@ -1,18 +1,12 @@
 import pkgutil
-from typing import Optional, Text, Union
+from typing import Optional, Text
 
 from jinja2 import Environment, Template
 
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-from .cache_config import CacheConfig
-from .dns64_config import Dns64Config
-from .logging_config import LoggingConfig
-from .lua_config import LuaConfig
-from .network_config import NetworkConfig
-from .options_config import OptionsConfig
-from .server_config import ServerConfig
+from knot_resolver_manager.datamodel.lua_config import Lua, LuaStrict
+from knot_resolver_manager.datamodel.network_config import Network, NetworkStrict
+from knot_resolver_manager.datamodel.server_config import Server, ServerStrict
+from knot_resolver_manager.utils import DataParser, DataValidator
 
 
 def _import_lua_template() -> Template:
@@ -26,25 +20,19 @@ def _import_lua_template() -> Template:
 _LUA_TEMPLATE = _import_lua_template()
 
 
-@dataclass
-class KresConfig(DataclassParserValidatorMixin):
-    # pylint: disable=too-many-instance-attributes
-    server: ServerConfig = ServerConfig()
-    network: NetworkConfig = NetworkConfig()
-    options: OptionsConfig = OptionsConfig()
-    cache: CacheConfig = CacheConfig()
-    # DNS64 is disabled by default
-    dns64: Union[bool, Dns64Config] = False
-    logging: LoggingConfig = LoggingConfig()
-    lua: Optional[LuaConfig] = None
-
-    def __post_init__(self):
-        # if DNS64 is enabled with defaults
-        if self.dns64 is True:
-            self.dns64 = Dns64Config()
-
-    def _validate(self):
-        pass
+class KresConfig(DataParser):
+    server: Server = Server()
+    network: Network = Network()
+    lua: Optional[Lua] = None
+
+
+class KresConfigStrict(DataValidator):
+    server: ServerStrict
+    network: NetworkStrict
+    lua: Optional[LuaStrict]
 
     def render_lua(self) -> Text:
         return _LUA_TEMPLATE.render(cfg=self)
+
+    def _validate(self) -> None:
+        pass
diff --git a/manager/knot_resolver_manager/datamodel/dns64_config.py b/manager/knot_resolver_manager/datamodel/dns64_config.py
deleted file mode 100644 (file)
index f6eeae8..0000000
+++ /dev/null
@@ -1,14 +0,0 @@
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-from .types import RE_IPV6_PREFIX_96
-
-
-@dataclass
-class Dns64Config(DataclassParserValidatorMixin):
-    prefix: str = "64:ff9b::"
-
-    def _validate(self):
-        if not bool(RE_IPV6_PREFIX_96.match(self.prefix)):
-            raise DataValidationException("'dns64.prefix' must be valid IPv6 /96 prefix")
diff --git a/manager/knot_resolver_manager/datamodel/dnssec_config.py b/manager/knot_resolver_manager/datamodel/dnssec_config.py
deleted file mode 100644 (file)
index 8b308f0..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class DnssecConfig(DataclassParserValidatorMixin):
-    def _validate(self):
-        pass
diff --git a/manager/knot_resolver_manager/datamodel/hints_config.py b/manager/knot_resolver_manager/datamodel/hints_config.py
deleted file mode 100644 (file)
index c60b634..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class StaticHintsConfig(DataclassParserValidatorMixin):
-    def _validate(self):
-        pass
diff --git a/manager/knot_resolver_manager/datamodel/logging_config.py b/manager/knot_resolver_manager/datamodel/logging_config.py
deleted file mode 100644 (file)
index 0a74184..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class LoggingConfig(DataclassParserValidatorMixin):
-    level: int = 3
-
-    def _validate(self):
-        if not 0 <= self.level <= 7:
-            raise DataValidationException("logging 'level' must be in range 0..7")
index 40ef9c9e9e5970cedbb1fc35c9ba882733fdca3c..0eeaeb3424b2689ed30d82b474d63dd26ab62829 100644 (file)
@@ -1,20 +1,21 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
+from knot_resolver_manager.utils import DataParser, DataValidator
 
 
-@dataclass
-class LuaConfig(DataclassParserValidatorMixin):
-    script_list: Optional[List[str]] = None
-    script: Optional[str] = None
+class Lua(DataParser):
+    script: Optional[Union[List[str], str]] = None
+    script_file: Optional[str] = None
 
-    def __post_init__(self):
-        # Concatenate array to single string
-        if self.script_list is not None:
-            self.script = "\n".join(self.script_list)
 
-    def _validate(self):
-        if self.script is None:
-            raise DataValidationException("Lua script not specified")
+class LuaStrict(DataValidator):
+    script: Optional[str]
+    script_file: Optional[str]
+
+    def _script(self, lua: Lua) -> Optional[str]:
+        if isinstance(lua.script, List):
+            return "\n".join(lua.script)
+        return lua.script
+
+    def _validate(self) -> None:
+        pass
index b414df0cca65d26132b5585280a81249e1a8f726..9e4c958136ccdf90340ab8782e622d6453d61dfa 100644 (file)
@@ -1,41 +1,21 @@
-{% if cfg.server.hostname %}
--- server.hostname
-hostname('{{ cfg.server.hostname }}')
-{% endif %}
-
 -- network.interfaces
 {% for item in cfg.network.interfaces %}
-net.listen('{{ item.get_address() }}', {{ item.get_port() if item.get_port() else 'nil' }}, {
+net.listen('{{ item.address }}', {{ item.port }}, {
     kind = '{{ item.kind if item.kind != 'dot' else 'tls' }}',
     freebind = {{ 'true' if item.freebind else 'false'}}
 })
 {% endfor %}
 
--- network.edns-buffer-size
-net.bufsize({{ cfg.network.edns_buffer_size.get_downstream() }}, {{ cfg.network.edns_buffer_size.get_upstream() }})
-
--- modules
-modules = {
-    'hints > iterate',   -- Load /etc/hosts and allow custom root hints",
-    'stats',             -- Track internal statistics",
-{% if cfg.options.prediction %}
-    predict = {          -- Prefetch expiring/frequent records"
-        window = {{ cfg.options.prediction.get_window() }},
-        period = {{ cfg.options.prediction.period }}
-    },
-{% endif %}
-{% if cfg.dns64 %}
-    dns64 = '{{ cfg.dns64.prefix }}', -- dns64
+{% if cfg.lua %}
+-- lua section
+{% if cfg.lua.script_file %}
+{% import cfg.lua.script_file as script_file %}
+-- lua.script-file
+{{ script_file }}
 {% endif %}
-}
-
--- cache
-cache.open({{ cfg.cache.get_size_max() }}, 'lmdb://{{ cfg.cache.storage }}')
-
--- logging level
-verbose({{ 'true' if cfg.logging.level > 3 else 'false'}})
 
 {% if cfg.lua.script %}
--- lua
+-- lua.script
 {{ cfg.lua.script }}
+{% endif %}
 {% endif %}
\ No newline at end of file
index 5e0f5f70318c18a4a6a55c82086159ca302cb30a..0a0f634341c2e95f8a07c7df1b43dd43acd14095 100644 (file)
@@ -1,75 +1,46 @@
-from typing import List, Optional, Union
+from typing import List
 
-from knot_resolver_manager.compat.dataclasses import dataclass, field
-from knot_resolver_manager.datamodel.types import SizeUnits
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
+from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils.types import LiteralEnum
 
+KindEnum = LiteralEnum["dns", "xdp", "dot", "doh"]
 
-@dataclass
-class InterfacesConfig(DataclassParserValidatorMixin):
+
+class Interface(DataParser):
     listen: str
-    kind: str = "dns"
+    kind: KindEnum = "dns"
     freebind: bool = False
-    _address: Optional[str] = None
-    _port: Optional[int] = None
-    _kind_port_map = {"dns": 53, "xdp": 53, "dot": 853, "doh": 443}
-
-    def __post_init__(self):
-        # split 'address@port'
-        if "@" in self.listen:
-            address, port = self.listen.split("@", maxsplit=1)
-            self._address = address
-            self._port = int(port)
-        else:
-            # if port number not specified
-            self._address = self.listen
-            # set port number based on 'kind'
-            self._port = self._kind_port_map.get(self.kind)
-
-    def get_address(self) -> Optional[str]:
-        return self._address
 
-    def get_port(self) -> Optional[int]:
-        return self._port
-
-    def _validate(self):
-        pass
 
+class InterfaceStrict(DataValidator):
+    address: str
+    port: int
+    kind: str
+    freebind: bool
 
-@dataclass
-class EdnsBufferSizeConfig(DataclassParserValidatorMixin):
-    downstream: Optional[str] = None
-    upstream: Optional[str] = None
-    _downstream_bytes: int = 1232
-    _upstream_bytes: int = 1232
+    def _address(self, obj: Interface) -> str:
+        if "@" in obj.listen:
+            address = obj.listen.split("@", maxsplit=1)[0]
+            return address
+        return obj.listen
 
-    def __post_init__(self):
-        if self.downstream:
-            self._downstream_bytes = SizeUnits.parse(self.downstream)
-        if self.upstream:
-            self._upstream_bytes = SizeUnits.parse(self.upstream)
+    def _port(self, obj: Interface) -> int:
+        port_map = {"dns": 53, "xdp": 53, "dot": 853, "doh": 443}
+        if "@" in obj.listen:
+            port = obj.listen.split("@", maxsplit=1)[1]
+            return int(port)
+        return port_map.get(obj.kind, 0)
 
-    def _validate(self):
+    def _validate(self) -> None:
         pass
 
-    def get_downstream(self) -> int:
-        return self._downstream_bytes
-
-    def get_upstream(self) -> int:
-        return self._upstream_bytes
 
+class Network(DataParser):
+    interfaces: List[Interface] = [Interface({"listen": "127.0.0.1"}), Interface({"listen": "::1", "freebind": True})]
 
-@dataclass
-class NetworkConfig(DataclassParserValidatorMixin):
-    interfaces: List[InterfacesConfig] = field(
-        default_factory=lambda: [InterfacesConfig(listen="127.0.0.1"), InterfacesConfig(listen="::1", freebind=True)]
-    )
-    edns_buffer_size: Union[str, EdnsBufferSizeConfig] = EdnsBufferSizeConfig()
 
-    def __post_init__(self):
-        if isinstance(self.edns_buffer_size, str):
-            bufsize = self.edns_buffer_size
-            self.edns_buffer_size = EdnsBufferSizeConfig(downstream=bufsize, upstream=bufsize)
+class NetworkStrict(DataValidator):
+    interfaces: List[InterfaceStrict]
 
-    def _validate(self):
+    def _validate(self) -> None:
         pass
diff --git a/manager/knot_resolver_manager/datamodel/options_config.py b/manager/knot_resolver_manager/datamodel/options_config.py
deleted file mode 100644 (file)
index 866b5b0..0000000
+++ /dev/null
@@ -1,34 +0,0 @@
-from typing import Optional, Union
-
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.datamodel.types import TimeUnits
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class PredictionConfig(DataclassParserValidatorMixin):
-    window: Optional[str] = None
-    _window_seconds: int = 15 * TimeUnits.minute
-    period: int = 24
-
-    def __post_init__(self):
-        if self.window:
-            self._window_seconds = TimeUnits.parse(self.window)
-
-    def get_window(self) -> int:
-        return self._window_seconds
-
-    def _validate(self):
-        pass
-
-
-@dataclass
-class OptionsConfig(DataclassParserValidatorMixin):
-    prediction: Union[bool, PredictionConfig] = False
-
-    def __post_init__(self):
-        if self.prediction is True:
-            self.prediction = PredictionConfig()
-
-    def _validate(self):
-        pass
index d703cd3b77cfca7396dd3a149d705d976962d1cd..fd70f74e20a0530d02e9adf7f7e0b5def369cd87 100644 (file)
@@ -1,12 +1,10 @@
 import logging
 import os
-from typing import Optional, Union
+from typing import Union
 
 from typing_extensions import Literal
 
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
+from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator
 
 logger = logging.getLogger(__name__)
 
@@ -29,33 +27,22 @@ def _cpu_count() -> int:
         return cpus
 
 
-@dataclass
-class ServerConfig(DataclassParserValidatorMixin):
-    hostname: Optional[str] = None
-    instances: Union[Literal["auto"], int, None] = None
-    _instances: int = 1
+class Server(DataParser):
+    workers: Union[Literal["auto"], int] = 1
     use_cache_gc: bool = True
 
-    def __post_init__(self):
-        if isinstance(self.instances, int):
-            self._instances = self.instances
-        elif self.instances == "auto":
-            self._instances = _cpu_count()
-
-    def get_instances(self) -> int:
-        # FIXME: this is a hack to make the partial updates working without a second data structure
-        # this will be unnecessary in near future
-        if isinstance(self.instances, int):
-            return self.instances
-        elif self.instances == "auto":
-            cpu_count = os.cpu_count()
-            if cpu_count is not None:
-                return cpu_count
-            else:
-                raise RuntimeError("cannot find number of system available CPUs")
-        else:
-            return 0
-
-    def _validate(self):
-        if not 0 < self._instances <= 256:
-            raise DataValidationException("number of kresd instances must be in range 1..256")
+
+class ServerStrict(DataValidator):
+    workers: int
+    use_cache_gc: bool
+
+    def _workers(self, obj: Server) -> int:
+        if isinstance(obj.workers, int):
+            return obj.workers
+        elif obj.workers == "auto":
+            return _cpu_count()
+        raise DataValidationException(f"Unexpected value: {obj.workers}")
+
+    def _validate(self) -> None:
+        if self.workers < 0:
+            raise DataValidationException("Number of workers must be non-negative")
index 3ea96df13ca7b1e7d95c37225f80456edcf25647..954ccb5ece47214dc9489568158123d4a8fc9b46 100644 (file)
@@ -1,41 +1,60 @@
 import re
+from typing import Any, Dict, Optional, Pattern, Union
+
+from knot_resolver_manager.utils import CustomValueType, DataValidationException
+
+
+class Unit(CustomValueType):
+    _re: Pattern[str]
+    _units: Dict[Optional[str], int]
+
+    def __init__(self, source_value: Any) -> None:
+        super().__init__(source_value)
+        self._value: int
+        self._value_orig: Union[str, int]
+        if isinstance(source_value, str) and type(self)._re.match(source_value):
+            self._value_orig = source_value
+            grouped = type(self)._re.search(source_value)
+            if grouped:
+                val, unit = grouped.groups()
+                if unit not in type(self)._units:
+                    raise DataValidationException(f"Used unexpected unit '{unit}' for {type(self).__name__}...")
+                self._value = int(val) * type(self)._units[unit]
+            else:
+                raise DataValidationException(f"{type(self._value)} Failed to convert: {self}")
+        elif isinstance(source_value, int):
+            if source_value < 0:
+                raise DataValidationException(f"Input value '{source_value}' is not non-negative.")
+            self._value_orig = source_value
+            self._value = source_value
+        else:
+            raise DataValidationException(
+                f"Unexpected input type for Unit type - {type(source_value)}."
+                " Cause might be invalid format or invalid type."
+            )
+
+    def __int__(self) -> int:
+        return self._value
+
+    def __str__(self) -> str:
+        """
+        Used by Jinja2. Must return only a number.
+        """
+        return str(self._value)
+
+    def __eq__(self, o: object) -> bool:
+        """
+        Two instances are equal when they represent the same size
+        regardless of their string representation.
+        """
+        return isinstance(o, Unit) and o._value == self._value
+
+
+class SizeUnit(Unit):
+    _re = re.compile(r"^([0-9]+)\s{0,1}([BKMG]){0,1}$")
+    _units = {None: 1, "B": 1, "K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
 
-from knot_resolver_manager.exceptions import DataValidationException
-
-RE_IPV6_PREFIX_96 = re.compile(r"^([0-9A-Fa-f]{1,4}:){2}:$")
-
-
-class TimeUnits:
-    second = 1
-    minute = 60
-    hour = 3600
-    day = 24 * 3600
 
+class TimeUnit(Unit):
     _re = re.compile(r"^(\d+)\s{0,1}([smhd]){0,1}$")
-    _map = {"s": second, "m": minute, "h": hour, "d": day}
-
-    @staticmethod
-    def parse(time_str: str) -> int:
-        searched = TimeUnits._re.search(time_str)
-        if searched:
-            value, unit = searched.groups()
-            return int(value) * TimeUnits._map.get(unit, 1)
-        raise DataValidationException(f"failed to parse: {time_str}")
-
-
-class SizeUnits:
-    byte = 1
-    kibibyte = 1024
-    mebibyte = 1024 ** 2
-    gibibyte = 1024 ** 3
-
-    _re = re.compile(r"^([0-9]+)\s{0,1}([BKMG]){0,1}$")
-    _map = {"B": byte, "K": kibibyte, "M": mebibyte, "G": gibibyte}
-
-    @staticmethod
-    def parse(size_str: str) -> int:
-        searched = SizeUnits._re.search(size_str)
-        if searched:
-            value, unit = searched.groups()
-            return int(value) * SizeUnits._map.get(unit, 1)
-        raise DataValidationException(f"failed to parse: {size_str}")
+    _units = {None: 1, "s": 1, "m": 60, "h": 3600, "d": 24 * 3600}
index 5b17896811fa8e402336e89a50052ef287e5d233..192cd7739f5051104dd77e4745c3bdf80587cb6c 100644 (file)
@@ -1,14 +1,2 @@
 class SubprocessControllerException(Exception):
     pass
-
-
-class ValidationException(Exception):
-    pass
-
-
-class SchemaValidationException(ValidationException):
-    pass
-
-
-class DataValidationException(ValidationException):
-    pass
index ef661e3cfba04f39363b7abedaac97a58b65f289..5acddc959e0ab1ad9a81de54bf4aedc58e47447c 100644 (file)
@@ -9,16 +9,16 @@ import knot_resolver_manager.kresd_controller
 from knot_resolver_manager import kres_id
 from knot_resolver_manager.compat.asyncio import create_task
 from knot_resolver_manager.constants import KRESD_CONFIG_FILE, WATCHDOG_INTERVAL
-from knot_resolver_manager.exceptions import ValidationException
 from knot_resolver_manager.kresd_controller.interface import (
     Subprocess,
     SubprocessController,
     SubprocessStatus,
     SubprocessType,
 )
+from knot_resolver_manager.utils import DataValidationException
 from knot_resolver_manager.utils.async_utils import writefile
 
-from .datamodel import KresConfig
+from .datamodel import KresConfig, KresConfigStrict
 
 logger = logging.getLogger(__name__)
 
@@ -85,6 +85,7 @@ class KresManager:
         self._manager_lock = asyncio.Lock()
         self._controller: SubprocessController
         self._last_used_config: Optional[KresConfig] = None
+        self._last_used_config_strict: Optional[KresConfigStrict] = None
         self._watchdog_task: Optional["Future[None]"] = None
 
     async def load_system_state(self):
@@ -140,28 +141,32 @@ class KresManager:
         await self._gc.stop()
         self._gc = None
 
-    async def _write_config(self, config: KresConfig):
-        lua_config = config.render_lua()
+    async def _write_config(self, config_strict: KresConfigStrict):
+        lua_config = config_strict.render_lua()
         await writefile(KRESD_CONFIG_FILE, lua_config)
 
     async def apply_config(self, config: KresConfig):
         async with self._manager_lock:
+            logger.debug("Validating configuration...")
+            config_strict = KresConfigStrict(config)
+
             logger.debug("Writing new config to file...")
-            await self._write_config(config)
+            await self._write_config(config_strict)
 
             logger.debug("Testing the new config with a canary process")
             try:
                 await self._spawn_new_worker()
             except SubprocessError:
                 logger.error("kresd with the new config failed to start, rejecting config")
-                last = self.get_last_used_config()
+                last = self.get_last_used_config_strict()
                 if last is not None:
                     await self._write_config(last)
-                raise ValidationException("Canary kresd instance failed. Config is invalid.")
+                raise DataValidationException("Canary kresd instance failed. Config is invalid.")
 
             logger.debug("Canary process test passed, Applying new config to all workers")
             self._last_used_config = config
-            await self._ensure_number_of_children(config.server.get_instances())
+            self._last_used_config_strict = config_strict
+            await self._ensure_number_of_children(config_strict.server.workers)
             await self._rolling_restart()
 
             if self._is_gc_running() != config.server.use_cache_gc:
@@ -183,6 +188,9 @@ class KresManager:
     def get_last_used_config(self) -> Optional[KresConfig]:
         return self._last_used_config
 
+    def get_last_used_config_strict(self) -> Optional[KresConfigStrict]:
+        return self._last_used_config_strict
+
     async def _instability_handler(self) -> None:
         logger.error("Instability callback invoked. No idea how to react, performing suicide. See you later!")
         sys.exit(1)
index 5f669986779cc439ed3d866d319b26bd9f73dfe1..8603fe67e294f68b06d3203e1527dc332ec6123c 100644 (file)
@@ -11,11 +11,10 @@ from aiohttp.web import middleware
 from aiohttp.web_response import json_response
 
 from knot_resolver_manager.constants import MANAGER_CONFIG_FILE
-from knot_resolver_manager.exceptions import ValidationException
 from knot_resolver_manager.kresd_controller import get_controller_by_name
 from knot_resolver_manager.kresd_controller.interface import SubprocessController
+from knot_resolver_manager.utils import DataValidationException, Format
 from knot_resolver_manager.utils.async_utils import readfile
-from knot_resolver_manager.utils.dataclasses_parservalidator import Format
 
 from .datamodel import KresConfig
 from .kres_manager import KresManager
@@ -83,7 +82,7 @@ async def error_handler(request: web.Request, handler: Any):
 
     try:
         return await handler(request)
-    except ValidationException as e:
+    except DataValidationException as e:
         logger.error("Failed to parse given data in API request", exc_info=True)
         return web.Response(text=f"Data validation failed: {e}", status=HTTPStatus.BAD_REQUEST)
 
index c9efbff7495c6cf6fb457100ce0b268e8cd96860..1149f96bc48f7a5bb21bb84a558344a7ceb69818 100644 (file)
@@ -1,6 +1,8 @@
 from typing import Any, Callable, Iterable, Optional, Type, TypeVar
 
-from .dataclasses_parservalidator import DataclassParserValidatorMixin
+from .custom_types import CustomValueType
+from .data_parser_validator import DataParser, DataValidator, Format
+from .exceptions import DataParsingException, DataValidationException
 from .overload import Overloaded
 
 T = TypeVar("T")
@@ -55,6 +57,11 @@ def contains_element_matching(cond: Callable[[T], bool], arr: Iterable[T]) -> bo
 __all__ = [
     "ignore_exceptions_optional",
     "ignore_exceptions",
-    "DataclassParserValidatorMixin",
+    "Format",
+    "CustomValueType",
+    "DataParser",
+    "DataValidator",
+    "DataParsingException",
+    "DataValidationException",
     "Overloaded",
 ]
diff --git a/manager/knot_resolver_manager/utils/custom_types.py b/manager/knot_resolver_manager/utils/custom_types.py
new file mode 100644 (file)
index 0000000..f8836bd
--- /dev/null
@@ -0,0 +1,31 @@
+from typing import Any
+
+
+class CustomValueType:
+    """
+    Subclasses of this class can be used as type annotations in 'DataParser'. When a value
+    is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by
+    calling the constructor of the appropriate type on the field value. The only limitation
+    is that the value MUST NOT be `None`.
+
+    Example:
+    ```
+    class A(DataParser):
+        field: MyCustomValueType
+
+    A.from_json('{"field": "value"}') == A(field=MyCustomValueType("value"))
+    ```
+
+    There is no validation done on the wrapped value. The only condition is that
+    it can't be `None`. If you want to perform any validation during creation,
+    raise a `DataValidationException` in case of errors.
+    """
+
+    def __init__(self, source_value: Any) -> None:
+        pass
+
+    def __int__(self) -> int:
+        raise NotImplementedError("CustomValueType return 'int()' value is not implemented.")
+
+    def __str__(self) -> str:
+        raise NotImplementedError("CustomValueType return 'str()' value is not implemented.")
similarity index 50%
rename from manager/knot_resolver_manager/utils/dataclasses_parservalidator.py
rename to manager/knot_resolver_manager/utils/data_parser_validator.py
index bb02918ca0cb5a9d64f7b405189c74e6efdec4a5..72213eff81d7c4c1e512935828c58eaab970b0bd 100644 (file)
@@ -1,4 +1,5 @@
 import copy
+import inspect
 import json
 import re
 from enum import Enum, auto
@@ -8,7 +9,8 @@ import yaml
 from yaml.constructor import ConstructorError
 from yaml.nodes import MappingNode
 
-from knot_resolver_manager.exceptions import SchemaValidationException
+from knot_resolver_manager.utils.custom_types import CustomValueType
+from knot_resolver_manager.utils.exceptions import DataParsingException
 from knot_resolver_manager.utils.types import (
     get_attr_type,
     get_generic_type_argument,
@@ -21,13 +23,45 @@ from knot_resolver_manager.utils.types import (
     is_union,
 )
 
-from ..compat.dataclasses import is_dataclass
 
+def is_internal_field(field_name: str) -> bool:
+    return field_name.startswith("_")
+
+
+def is_obj_type(obj: Any, types: Union[type, Tuple[Any, ...], Tuple[type, ...]]) -> bool:
+    # To check specific type we are using 'type()' instead of 'isinstance()'
+    # because for example 'bool' is instance of 'int', 'isinstance(False, int)' returns True.
+    # pylint: disable=unidiomatic-typecheck
+    if isinstance(types, Tuple):
+        return type(obj) in types
+    return type(obj) == types
+
+
+def _to_primitive(obj: Any) -> Any:
+    """
+    Convert our custom values into primitive variants for dumping.
+    """
+
+    # CustomValueType instances
+    if isinstance(obj, CustomValueType):
+        return str(obj)
+
+    # nested DataParser class instances
+    elif isinstance(obj, DataParser):
+        return obj.to_dict()
+
+    # otherwise just return, what we were given
+    else:
+        return obj
+
+
+def _validated_object_type(cls: Type[Any], obj: Any, default: Any = ..., use_default: bool = False) -> Any:
+    """
+    Given an expected type `cls` and a value object `obj`, validate the type of `obj` and return it
+    """
 
-def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> Any:
     # Disabling these checks, because I think it's much more readable as a single function
-    # and it's not that large at this point. If it got larger, then we should definitely split
-    # it
+    # and it's not that large at this point. If it got larger, then we should definitely split it
     # pylint: disable=too-many-branches,too-many-locals,too-many-statements
 
     # default values
@@ -39,57 +73,57 @@ def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> A
         if obj is None:
             return None
         else:
-            raise SchemaValidationException(f"Expected None, found {obj}")
+            raise DataParsingException(f"Expected None, found '{obj}'.")
 
     # Union[*variants] (handles Optional[T] due to the way the typing system works)
     elif is_union(cls):
         variants = get_generic_type_arguments(cls)
         for v in variants:
             try:
-                return _from_dictlike_obj(v, obj, ..., False)
-            except SchemaValidationException:
+                return _validated_object_type(v, obj)
+            except DataParsingException:
                 pass
-        raise SchemaValidationException(f"Union {cls} could not be parsed - parsing of all variants failed")
+        raise DataParsingException(f"Union {cls} could not be parsed - parsing of all variants failed.")
 
     # after this, there is no place for a None object
     elif obj is None:
-        raise SchemaValidationException(f"Unexpected None value for type {cls}")
+        raise DataParsingException(f"Unexpected None value for type {cls}")
 
     # int
     elif cls == int:
         # we don't want to make an int out of anything else than other int
-        if isinstance(obj, int):
+        # except for CustomValueType class instances
+        if is_obj_type(obj, int) or isinstance(obj, CustomValueType):
             return int(obj)
-        else:
-            raise SchemaValidationException(f"Expected int, found {type(obj)}")
+        raise DataParsingException(f"Expected int, found {type(obj)}")
 
     # str
     elif cls == str:
         # we are willing to cast any primitive value to string, but no compound values are allowed
-        if isinstance(obj, (str, float, int)):
+        if is_obj_type(obj, (str, float, int)) or isinstance(obj, CustomValueType):
             return str(obj)
-        elif isinstance(obj, bool):
-            raise SchemaValidationException(
+        elif is_obj_type(obj, bool):
+            raise DataParsingException(
                 "Expected str, found bool. Be careful, that YAML parsers consider even"
                 ' "no" and "yes" as a bool. Search for the Norway Problem for more'
                 " details. And please use quotes explicitly."
             )
         else:
-            raise SchemaValidationException(
+            raise DataParsingException(
                 f"Expected str (or number that would be cast to string), but found type {type(obj)}"
             )
 
     # bool
     elif cls == bool:
-        if isinstance(obj, bool):
+        if is_obj_type(obj, bool):
             return obj
         else:
-            raise SchemaValidationException(f"Expected bool, found {type(obj)}")
+            raise DataParsingException(f"Expected bool, found {type(obj)}")
 
     # float
     elif cls == float:
         raise NotImplementedError(
-            "Floating point values are not supported in the parser validator."
+            "Floating point values are not supported in the parser."
             " Please implement them and be careful with type coercions"
         )
 
@@ -99,51 +133,56 @@ def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> A
         if obj == expected:
             return obj
         else:
-            raise SchemaValidationException(f"Literal {cls} is not matched with the value {obj}")
+            raise DataParsingException(f"Literal {cls} is not matched with the value {obj}")
 
     # Dict[K,V]
     elif is_dict(cls):
         key_type, val_type = get_generic_type_arguments(cls)
         try:
             return {
-                _from_dictlike_obj(key_type, key, ..., False): _from_dictlike_obj(val_type, val, ..., False)
-                for key, val in obj.items()
+                _validated_object_type(key_type, key): _validated_object_type(val_type, val) for key, val in obj.items()
             }
         except AttributeError as e:
-            raise SchemaValidationException(
+            raise DataParsingException(
                 f"Expected dict-like object, but failed to access its .items() method. Value was {obj}", e
             )
 
     # List[T]
     elif is_list(cls):
         inner_type = get_generic_type_argument(cls)
-        return [_from_dictlike_obj(inner_type, val, ..., False) for val in obj]
+        return [_validated_object_type(inner_type, val) for val in obj]
 
     # Tuple[A,B,C,D,...]
     elif is_tuple(cls):
         types = get_generic_type_arguments(cls)
-        return tuple(_from_dictlike_obj(typ, val, ..., False) for typ, val in zip(types, obj))
-
-    # nested dataclass
-    elif is_dataclass(cls):
-        anot = cls.__dict__.get("__annotations__", {})
-        kwargs = {}
-        for name, python_type in anot.items():
-            # skip internal fields
-            if name.startswith("_"):
-                continue
-
-            value = obj[name] if name in obj else None
-            use_default = hasattr(cls, name)
-            default = getattr(cls, name, ...)
-            kwargs[name] = _from_dictlike_obj(python_type, value, default, use_default)
-        return cls(**kwargs)
+        return tuple(_validated_object_type(typ, val) for typ, val in zip(types, obj))
+
+    # CustomValueType subclasses
+    elif inspect.isclass(cls) and issubclass(cls, CustomValueType):
+        # no validation performed, the implementation does it in the constuctor
+        return cls(obj)
+
+    # nested DataParser subclasses
+    elif inspect.isclass(cls) and issubclass(cls, DataParser):
+        # we should return DataParser, we expect to be given a dict,
+        # because we can construct a DataParser from it
+        if isinstance(obj, dict):
+            return cls(obj)  # type: ignore
+        raise DataParsingException(f"Expected '{dict}' object, found '{type(obj)}'")
+
+    # nested DataValidator subclasses
+    elif inspect.isclass(cls) and issubclass(cls, DataValidator):
+        # we should return DataValidator, we expect to be given a DataParser,
+        # because we can construct a DataValidator from it
+        if isinstance(obj, DataParser):
+            return cls(obj)
+        raise DataParsingException(f"Expected instance of '{DataParser}' class, found '{type(obj)}'")
 
     # default error handler
     else:
-        raise SchemaValidationException(
+        raise DataParsingException(
             f"Type {cls} cannot be parsed. This is a implementation error. "
-            "Please fix your types in the dataclass or improve the parser/validator."
+            "Please fix your types in the class or improve the parser/validator."
         )
 
 
@@ -153,7 +192,7 @@ def json_raise_duplicates(pairs: List[Tuple[Any, Any]]) -> Optional[Any]:
     dict_out: Dict[Any, Any] = {}
     for key, val in pairs:
         if key in dict_out:
-            raise SchemaValidationException(f"duplicate key detected: {key}")
+            raise DataParsingException(f"Duplicate attribute key detected: {key}")
         dict_out[key] = val
     return dict_out
 
@@ -180,18 +219,12 @@ class RaiseDuplicatesLoader(yaml.SafeLoader):
 
             # check for duplicate keys
             if key in mapping:
-                raise SchemaValidationException(f"duplicate key detected: {key_node.start_mark}")
+                raise DataParsingException(f"duplicate key detected: {key_node.start_mark}")
             value = self.construct_object(value_node, deep=deep)  # type: ignore
             mapping[key] = value
         return mapping
 
 
-_T = TypeVar("_T", bound="DataclassParserValidatorMixin")
-
-
-_SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
-
-
 class Format(Enum):
     YAML = auto()
     JSON = auto()
@@ -206,6 +239,14 @@ class Format(Enum):
         else:
             raise NotImplementedError(f"Parsing of format '{self}' is not implemented")
 
+    def dict_dump(self, data: Dict[str, Any]) -> str:
+        if self is Format.YAML:
+            return yaml.safe_dump(data)  # type: ignore
+        elif self is Format.JSON:
+            return json.dumps(data)
+        else:
+            raise NotImplementedError(f"Exporting to '{self}' format is not implemented")
+
     @staticmethod
     def from_mime_type(mime_type: str) -> "Format":
         formats = {
@@ -214,41 +255,47 @@ class Format(Enum):
             "text/vnd.yaml": Format.YAML,
         }
         if mime_type not in formats:
-            raise SchemaValidationException("Unsupported MIME type")
+            raise DataParsingException("Unsupported MIME type")
         return formats[mime_type]
 
 
-class DataclassParserValidatorMixin:
-    def __init__(self, *args: Any, **kwargs: Any):
-        """
-        This constructor is useless except for typechecking. It makes sure that the dataclasses can be created with
-        any arguments whatsoever.
-        """
+_T = TypeVar("_T", bound="DataParser")
 
-    def validate(self) -> None:
-        for field_name in dir(self):
-            # skip internal fields
-            if field_name.startswith("_"):
+
+_SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
+
+
+class DataParser:
+    def __init__(self, obj: Optional[Dict[str, Any]] = None):
+        cls = self.__class__
+        annot = cls.__dict__.get("__annotations__", {})
+
+        used_keys: List[str] = []
+        for name, python_type in annot.items():
+            if is_internal_field(name):
                 continue
 
-            field = getattr(self, field_name)
-            if is_dataclass(field):
-                if not isinstance(field, DataclassParserValidatorMixin):
-                    raise SchemaValidationException(
-                        f"Nested dataclass in the field {field_name} does not include the ParserValidatorMixin"
-                    )
-                field.validate()
+            val = None
+            dash_name = name.replace("_", "-")
+            if obj and dash_name in obj:
+                val = obj[dash_name]
+                used_keys.append(dash_name)
 
-        self._validate()
+            use_default = hasattr(cls, name)
+            default = getattr(cls, name, ...)
+            value = _validated_object_type(python_type, val, default, use_default)
+            setattr(self, name, value)
 
-    def _validate(self) -> None:
-        raise NotImplementedError(f"Validation function is not implemented in class {type(self).__name__}")
+        # check for unused keys
+        if obj:
+            for key in obj:
+                if key not in used_keys:
+                    raise DataParsingException(f"Unknown attribute key '{key}'.")
 
     @classmethod
     def parse_from(cls: Type[_T], fmt: Format, text: str):
-        data = fmt.parse_to_dict(text)
-        config: _T = _from_dictlike_obj(cls, data, ..., False)
-        config.validate()
+        data_dict = fmt.parse_to_dict(text)
+        config: _T = cls(data_dict)
         return config
 
     @classmethod
@@ -259,13 +306,36 @@ class DataclassParserValidatorMixin:
     def from_json(cls: Type[_T], text: str) -> _T:
         return cls.parse_from(Format.JSON, text)
 
+    def to_dict(self) -> Dict[str, Any]:
+        cls = self.__class__
+        anot = cls.__dict__.get("__annotations__", {})
+        dict_obj: Dict[str, Any] = {}
+        for name in anot:
+            if is_internal_field(name):
+                continue
+
+            value = getattr(self, name)
+            dash_name = str(name).replace("_", "-")
+            dict_obj[dash_name] = _to_primitive(value)
+        return dict_obj
+
+    def dump(self, fmt: Format) -> str:
+        dict_data = self.to_dict()
+        return fmt.dict_dump(dict_data)
+
+    def dump_to_yaml(self) -> str:
+        return self.dump(Format.YAML)
+
+    def dump_to_json(self) -> str:
+        return self.dump(Format.JSON)
+
     def copy_with_changed_subtree(self: _T, fmt: Format, path: str, text: str) -> _T:
         cls = self.__class__
 
         # prepare and validate the path object
         path = path[:-1] if path.endswith("/") else path
         if re.match(_SUBTREE_MUTATION_PATH_PATTERN, path) is None:
-            raise SchemaValidationException("Provided object path for mutation is invalid.")
+            raise DataParsingException("Provided object path for mutation is invalid.")
         path = path[1:] if path.startswith("/") else path
 
         # now, the path variable should contain '/' separated field names
@@ -278,29 +348,63 @@ class DataclassParserValidatorMixin:
         to_mutate = copy.deepcopy(self)
         obj = to_mutate
         parent = None
-        for segment in path.split("/"):
+
+        for dash_segment in path.split("/"):
+            segment = dash_segment.replace("-", "_")
+
             if segment == "":
-                raise SchemaValidationException(f"Unexpectedly empty segment in path '{path}'")
-            elif segment.startswith("_"):
-                raise SchemaValidationException(
-                    "No, changing internal fields (starting with _) is not allowed. Nice try."
-                )
+                raise DataParsingException(f"Unexpectedly empty segment in path '{path}'")
+            elif is_internal_field(segment):
+                raise DataParsingException("No, changing internal fields (starting with _) is not allowed. Nice try.")
             elif hasattr(obj, segment):
                 parent = obj
                 obj = getattr(parent, segment)
             else:
-                raise SchemaValidationException(
-                    f"Path segment '{segment}' does not match any field on the provided parent object"
+                raise DataParsingException(
+                    f"Path segment '{dash_segment}' does not match any field on the provided parent object"
                 )
         assert parent is not None
 
         # assign the subtree
-        last_name = path.split("/")[-1]
+        last_name = path.split("/")[-1].replace("-", "_")
         data = fmt.parse_to_dict(text)
         tp = get_attr_type(parent, last_name)
-        parsed_value = _from_dictlike_obj(tp, data, ..., False)
+        parsed_value = _validated_object_type(tp, data)
         setattr(parent, last_name, parsed_value)
 
-        to_mutate.validate()
-
         return to_mutate
+
+
+class DataValidator:
+    def __init__(self, obj: DataParser):
+        cls = self.__class__
+        anot = cls.__dict__.get("__annotations__", {})
+
+        for attr_name, attr_type in anot.items():
+            if is_internal_field(attr_name):
+                continue
+
+            # use transformation function if available
+            if hasattr(self, f"_{attr_name}"):
+                value = getattr(self, f"_{attr_name}")(obj)
+            elif hasattr(obj, attr_name):
+                value = getattr(obj, attr_name)
+            else:
+                raise DataParsingException(f"DataParser object {obj} is missing '{attr_name}' attribute.")
+
+            setattr(self, attr_name, _validated_object_type(attr_type, value))
+
+        self._validate()
+
+    def validate(self) -> None:
+        for field_name in dir(self):
+            if is_internal_field(field_name):
+                continue
+
+            field = getattr(self, field_name)
+            if isinstance(field, DataValidator):
+                field.validate()
+        self._validate()
+
+    def _validate(self) -> None:
+        raise NotImplementedError(f"Validation function is not implemented in class {type(self).__name__}")
diff --git a/manager/knot_resolver_manager/utils/exceptions.py b/manager/knot_resolver_manager/utils/exceptions.py
new file mode 100644 (file)
index 0000000..f6e8150
--- /dev/null
@@ -0,0 +1,6 @@
+class DataParsingException(Exception):
+    pass
+
+
+class DataValidationException(Exception):
+    pass
index fb2246708222671ae1bcd16758906e7ab6db3123..ca19ac241690628e97a449b9796219683940df0a 100644 (file)
@@ -30,12 +30,15 @@ def is_union(tp: Any) -> bool:
 
 
 def is_literal(tp: Any) -> bool:
-    return getattr(tp, "__origin__", None) == Literal
+    return isinstance(tp, type(Literal))
 
 
 def get_generic_type_arguments(tp: Any) -> List[Any]:
     default: List[Any] = []
-    return getattr(tp, "__args__", default)
+    if is_literal(tp):
+        return getattr(tp, "__values__")
+    else:
+        return getattr(tp, "__args__", default)
 
 
 def get_generic_type_argument(tp: Any) -> Any:
diff --git a/manager/tests/datamodel/test_datamodel_types.py b/manager/tests/datamodel/test_datamodel_types.py
new file mode 100644 (file)
index 0000000..e5321f4
--- /dev/null
@@ -0,0 +1,66 @@
+from pytest import raises
+
+from knot_resolver_manager.datamodel.types import SizeUnit, TimeUnit
+from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator
+
+
+def test_size_unit():
+    assert (
+        SizeUnit(5368709120)
+        == SizeUnit("5368709120")
+        == SizeUnit("5368709120B")
+        == SizeUnit("5242880K")
+        == SizeUnit("5120M")
+        == SizeUnit("5G")
+    )
+
+    with raises(DataValidationException):
+        SizeUnit("-5368709120")
+    with raises(DataValidationException):
+        SizeUnit(-5368709120)
+    with raises(DataValidationException):
+        SizeUnit("5120MM")
+
+
+def test_time_unit():
+    assert TimeUnit("1d") == TimeUnit("24h") == TimeUnit("1440m") == TimeUnit("86400s") == TimeUnit(86400)
+
+    with raises(DataValidationException):
+        TimeUnit("-1")
+    with raises(DataValidationException):
+        TimeUnit(-24)
+    with raises(DataValidationException):
+        TimeUnit("1440mm")
+
+
+def test_parsing_units():
+    class TestClass(DataParser):
+        size: SizeUnit
+        time: TimeUnit
+
+    class TestClassStrict(DataValidator):
+        size: int
+        time: int
+
+        def _validate(self) -> None:
+            pass
+
+    yaml = """
+size: 3K
+time: 10m
+"""
+
+    obj = TestClass.from_yaml(yaml)
+    assert obj.size == SizeUnit(3 * 1024)
+    assert obj.time == TimeUnit(10 * 60)
+
+    strict = TestClassStrict(obj)
+    assert strict.size == 3 * 1024
+    assert strict.time == 10 * 60
+
+    y = obj.dump_to_yaml()
+    j = obj.dump_to_json()
+    a = TestClass.from_yaml(y)
+    b = TestClass.from_json(j)
+    assert a.size == b.size == obj.size
+    assert a.time == b.time == obj.time
diff --git a/manager/tests/test_datamodel.py b/manager/tests/test_datamodel.py
deleted file mode 100644 (file)
index a1af7dc..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-from knot_resolver_manager.datamodel import KresConfig
-
-
-def test_simple():
-    json = """
-    {
-    "server": {
-        "instances": 1
-    },
-    "lua": {
-        "script_list": [
-        "-- SPDX-License-Identifier: CC0-1.0",
-        "-- vim:syntax=lua:set ts=4 sw=4:",
-        "-- Refer to manual: https://knot-resolver.readthedocs.org/en/stable/",
-        "-- Network interface configuration","net.listen('127.0.0.1', 53, { kind = 'dns' })",
-        "net.listen('127.0.0.1', 853, { kind = 'tls' })",
-        "--net.listen('127.0.0.1', 443, { kind = 'doh2' })",
-        "net.listen('::1', 53, { kind = 'dns', freebind = true })",
-        "net.listen('::1', 853, { kind = 'tls', freebind = true })",
-        "--net.listen('::1', 443, { kind = 'doh2' })",
-        "-- Load useful modules","modules = {",
-        "'hints > iterate',  -- Load /etc/hosts and allow custom root hints",
-        "'stats',            -- Track internal statistics",
-        "'predict',          -- Prefetch expiring/frequent records",
-        "}",
-        "-- Cache size",
-        "cache.size = 100 * MB"
-        ]
-    }
-    }
-    """
-
-    config = KresConfig.from_json(json)
-
-    assert config.server.instances == 1
-    assert config.lua.script is not None
diff --git a/manager/tests/utils/test_data_parser_validator.py b/manager/tests/utils/test_data_parser_validator.py
new file mode 100644 (file)
index 0000000..5431f3d
--- /dev/null
@@ -0,0 +1,292 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+from pytest import raises
+from typing_extensions import Literal
+
+from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator, Format
+from knot_resolver_manager.utils.exceptions import DataParsingException
+
+
+def test_primitive():
+    class TestClass(DataParser):
+        i: int
+        s: str
+        b: bool
+
+    class TestClassStrict(DataValidator):
+        i: int
+        s: str
+        b: bool
+
+        def _validate(self) -> None:
+            pass
+
+    yaml = """
+i: 5
+s: test
+b: false
+"""
+
+    obj = TestClass.from_yaml(yaml)
+    assert obj.i == 5
+    assert obj.s == "test"
+    assert obj.b == False
+
+    strict = TestClassStrict(obj)
+    assert strict.i == 5
+    assert strict.s == "test"
+    assert strict.b == False
+
+    y = obj.dump_to_yaml()
+    j = obj.dump_to_json()
+    a = TestClass.from_yaml(y)
+    b = TestClass.from_json(j)
+    assert a.i == b.i == obj.i
+    assert a.s == b.s == obj.s
+    assert a.b == b.b == obj.b
+
+
+def test_parsing_primitive_exceptions():
+    class TestStr(DataParser):
+        s: str
+
+    # int and float are allowed inputs for string
+    with raises(DataParsingException):
+        TestStr.from_yaml("s: false")  # bool
+
+    class TestInt(DataParser):
+        i: int
+
+    with raises(DataParsingException):
+        TestInt.from_yaml("i: false")  # bool
+    with raises(DataParsingException):
+        TestInt.from_yaml('i: "5"')  # str
+    with raises(DataParsingException):
+        TestInt.from_yaml("i: 5.5")  # float
+
+    class TestBool(DataParser):
+        b: bool
+
+    with raises(DataParsingException):
+        TestBool.from_yaml("b: 5")  # int
+    with raises(DataParsingException):
+        TestBool.from_yaml('b: "5"')  # str
+    with raises(DataParsingException):
+        TestBool.from_yaml("b: 5.5")  # float
+
+
+def test_nested():
+    class Lower(DataParser):
+        i: int
+
+    class Upper(DataParser):
+        l: Lower
+
+    class LowerStrict(DataValidator):
+        i: int
+
+        def _validate(self) -> None:
+            pass
+
+    class UpperStrict(DataValidator):
+        l: LowerStrict
+
+        def _validate(self) -> None:
+            pass
+
+    yaml = """
+l:
+  i: 5
+"""
+
+    obj = Upper.from_yaml(yaml)
+    assert obj.l.i == 5
+
+    strict = UpperStrict(obj)
+    assert strict.l.i == 5
+
+    y = obj.dump_to_yaml()
+    j = obj.dump_to_json()
+    a = Upper.from_yaml(y)
+    b = Upper.from_json(j)
+    assert a.l.i == b.l.i == obj.l.i
+
+
+def test_simple_compount_types():
+    class TestClass(DataParser):
+        l: List[int]
+        d: Dict[str, str]
+        t: Tuple[str, int]
+        o: Optional[int]
+
+    class TestClassStrict(DataValidator):
+        l: List[int]
+        d: Dict[str, str]
+        t: Tuple[str, int]
+        o: Optional[int]
+
+        def _validate(self) -> None:
+            pass
+
+    yaml = """
+l:
+  - 1
+  - 2
+  - 3
+  - 4
+  - 5
+d:
+  something: else
+  w: all
+t:
+  - test
+  - 5
+"""
+
+    obj = TestClass.from_yaml(yaml)
+    assert obj.l == [1, 2, 3, 4, 5]
+    assert obj.d == {"something": "else", "w": "all"}
+    assert obj.t == ("test", 5)
+    assert obj.o is None
+
+    strict = TestClassStrict(obj)
+    assert strict.l == [1, 2, 3, 4, 5]
+    assert strict.d == {"something": "else", "w": "all"}
+    assert strict.t == ("test", 5)
+    assert strict.o is None
+
+    y = obj.dump_to_yaml()
+    j = obj.dump_to_json()
+    a = TestClass.from_yaml(y)
+    b = TestClass.from_json(j)
+    assert a.l == b.l == obj.l
+    assert a.d == b.d == obj.d
+    assert a.t == b.t == obj.t
+    assert a.o == b.o == obj.o
+
+
+def test_nested_compound_types():
+    class TestClass(DataParser):
+        o: Optional[Dict[str, str]]
+
+    class TestClassStrict(DataValidator):
+        o: Optional[Dict[str, str]]
+
+        def _validate(self) -> None:
+            pass
+
+    yaml = """
+o:
+  key: val
+"""
+
+    obj = TestClass.from_yaml(yaml)
+    assert obj.o == {"key": "val"}
+
+    strict = TestClassStrict(obj)
+    assert strict.o == {"key": "val"}
+
+    y = obj.dump_to_yaml()
+    j = obj.dump_to_json()
+    a = TestClass.from_yaml(y)
+    b = TestClass.from_json(j)
+    assert a.o == b.o == obj.o
+
+
+def test_nested_compount_types2():
+    class TestClass(DataParser):
+        i: int
+        o: Optional[Dict[str, str]]
+
+    class TestClassStrict(DataValidator):
+        i: int
+        o: Optional[Dict[str, str]]
+
+        def _validate(self) -> None:
+            pass
+
+    yaml = "i: 5"
+
+    obj = TestClass.from_yaml(yaml)
+    assert obj.i == 5
+    assert obj.o is None
+
+    strict = TestClassStrict(obj)
+    assert strict.i == 5
+    assert strict.o is None
+
+    y = obj.dump_to_yaml()
+    j = obj.dump_to_json()
+    a = TestClass.from_yaml(y)
+    b = TestClass.from_json(j)
+    assert a.i == b.i == obj.i
+    assert a.o == b.o == obj.o
+
+
+def test_partial_mutations():
+    class Inner(DataParser):
+        size: int = 5
+
+    class ConfData(DataParser):
+        workers: Union[Literal["auto"], int] = 1
+        lua_config: Optional[str] = None
+        inner: Inner = Inner()
+
+    class InnerStrict(DataValidator):
+        size: int
+
+        def _validate(self) -> None:
+            pass
+
+    class ConfDataStrict(DataValidator):
+        workers: int
+        lua_config: Optional[str]
+        inner: InnerStrict
+
+        def _workers(self, data: ConfData) -> int:
+            if data.workers == "auto":
+                return 8
+            else:
+                return data.workers
+
+        def _validate(self) -> None:
+            if self.workers < 0:
+                raise DataValidationException("Number of workers must be non-negative")
+
+    yaml = """
+    workers: auto
+    lua-config: something
+    """
+
+    conf = ConfData.from_yaml(yaml)
+
+    x = ConfDataStrict(conf)
+    assert x.lua_config == "something"
+    assert x.inner.size == 5
+    assert x.workers == 8
+
+    y = conf.dump_to_yaml()
+    j = conf.dump_to_json()
+    a = ConfData.from_yaml(y)
+    b = ConfData.from_json(j)
+    assert a.workers == b.workers == conf.workers
+    assert a.lua_config == b.lua_config == conf.lua_config
+    assert a.inner.size == b.inner.size == conf.inner.size
+
+    # replacement of 'lua-config' attribute
+    x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/lua-config", '"new_value"'))
+    assert x.lua_config == "new_value"
+    assert x.inner.size == 5
+    assert x.workers == 8
+
+    # replacement of the whole tree
+    x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/", '{"inner": {"size": 55}}'))
+    assert x.lua_config is None
+    assert x.workers == 1
+    assert x.inner.size == 55
+
+    # replacement of 'inner' subtree
+    x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/inner", '{"size": 33}'))
+    assert x.lua_config == "something"
+    assert x.workers == 8
+    assert x.inner.size == 33
diff --git a/manager/tests/utils/test_dataclasses_parservalidator.py b/manager/tests/utils/test_dataclasses_parservalidator.py
deleted file mode 100644 (file)
index 2771330..0000000
+++ /dev/null
@@ -1,170 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin, Format
-
-
-def test_parsing_primitive():
-    @dataclass
-    class TestClass(DataclassParserValidatorMixin):
-        i: int
-        s: str
-
-        def _validate(self):
-            pass
-
-    yaml = """i: 5
-s: "test"
-"""
-
-    obj = TestClass.from_yaml(yaml)
-
-    assert obj.i == 5
-    assert obj.s == "test"
-
-
-def test_parsing_nested():
-    @dataclass
-    class Lower(DataclassParserValidatorMixin):
-        i: int
-
-        def _validate(self):
-            pass
-
-    @dataclass
-    class Upper(DataclassParserValidatorMixin):
-        l: Lower
-
-        def _validate(self):
-            pass
-
-    yaml = """l:
-  i: 5"""
-
-    obj = Upper.from_yaml(yaml)
-    assert obj.l.i == 5
-
-
-def test_simple_compount_types():
-    @dataclass
-    class TestClass(DataclassParserValidatorMixin):
-        l: List[int]
-        d: Dict[str, str]
-        t: Tuple[str, int]
-        o: Optional[int]
-
-        def _validate(self):
-            pass
-
-    yaml = """l:
-  - 1
-  - 2
-  - 3
-  - 4
-  - 5
-d:
-  something: else
-  w: all
-t:
-  - test
-  - 5"""
-
-    obj = TestClass.from_yaml(yaml)
-
-    assert obj.l == [1, 2, 3, 4, 5]
-    assert obj.d == {"something": "else", "w": "all"}
-    assert obj.t == ("test", 5)
-    assert obj.o is None
-
-
-def test_nested_compount_types():
-    @dataclass
-    class TestClass(DataclassParserValidatorMixin):
-        o: Optional[Dict[str, str]]
-
-        def _validate(self):
-            pass
-
-    yaml = """o:
-  key: val"""
-
-    obj = TestClass.from_yaml(yaml)
-
-    assert obj.o == {"key": "val"}
-
-
-def test_nested_compount_types2():
-    @dataclass
-    class TestClass(DataclassParserValidatorMixin):
-        i: int
-        o: Optional[Dict[str, str]]
-
-        def _validate(self):
-            pass
-
-    yaml = "i: 5"
-
-    obj = TestClass.from_yaml(yaml)
-
-    assert obj.o is None
-
-
-def test_real_failing_dummy_confdata():
-    @dataclass
-    class ConfData(DataclassParserValidatorMixin):
-        num_workers: int = 1
-        lua_config: Optional[str] = None
-
-        def _validate(self):
-            if self.num_workers < 0:
-                raise Exception("Number of workers must be non-negative")
-
-    # prepare the payload
-    lua_config = "dummy"
-    config = f"""
-num_workers: 4
-lua_config: |
-  { lua_config }"""
-
-    data = ConfData.from_yaml(config)
-
-    assert type(data.num_workers) == int
-    assert data.num_workers == 4
-    assert type(data.lua_config) == str
-    assert data.lua_config == "dummy"
-
-
-def test_partial_mutations():
-    @dataclass
-    class Inner(DataclassParserValidatorMixin):
-        number: int
-
-        def _validate(self):
-            pass
-
-    @dataclass
-    class ConfData(DataclassParserValidatorMixin):
-        num_workers: int = 1
-        lua_config: Optional[str] = None
-        inner: Inner = Inner(5)
-
-        def _validate(self):
-            if self.num_workers < 0:
-                raise Exception("Number of workers must be non-negative")
-
-    data = ConfData(5, "something", Inner(10))
-
-    x = data.copy_with_changed_subtree(Format.JSON, "/lua_config", '"new_value"')
-    assert x.lua_config == "new_value"
-    assert x.num_workers == 5
-    assert x.inner.number == 10
-
-    x = data.copy_with_changed_subtree(Format.JSON, "/inner", '{"number": 55}')
-    assert x.lua_config == "something"
-    assert x.num_workers == 5
-    assert x.inner.number == 55
-
-    x = data.copy_with_changed_subtree(Format.JSON, "/", '{"inner": {"number": 55}}')
-    assert x.lua_config is None
-    assert x.num_workers == 1
-    assert x.inner.number == 55
index 824d4ea5be5290232d09ee2bec98ec5af46cf78c..580154a252d1efbab23da7e7e7aaac6ec815a246 100644 (file)
@@ -2,7 +2,7 @@ from typing import List, Union
 
 from typing_extensions import Literal
 
-from knot_resolver_manager.utils.types import LiteralEnum, is_list
+from knot_resolver_manager.utils.types import LiteralEnum, is_list, is_literal
 
 
 def test_is_list():
@@ -10,6 +10,11 @@ def test_is_list():
     assert is_list(List[int])
 
 
+def test_is_literal():
+    assert is_literal(Literal[5])
+    assert is_literal(Literal["test"])
+
+
 def test_literal_enum():
     assert LiteralEnum[5, "test"] == Union[Literal[5], Literal["test"]]
     assert LiteralEnum["str", 5] == Union[Literal["str"], Literal[5]]