network:
  interfaces:
    - listen: 127.0.0.1@5353
-  edns_buffer_size:
-    downstream: 4K
-options:
-  prediction: true
-cache:
-  storage: .
-  size_max: 100M
-logging:
-  level: 5
server:
-  instances: 1
+  workers: 1
        print(response.text)

    def set_num_workers(self, n: int):
-        response = requests.post(self._create_url("/config/server/instances"), data=str(n))
+        response = requests.post(self._create_url("/config/server/workers"), data=str(n))
        print(response.text)

    def wait_for_initialization(self, timeout_sec: float = 5, time_step: float = 0.4):
-from .config import KresConfig
+from .config import KresConfig, KresConfigStrict
-__all__ = [
- "KresConfig",
-]
+__all__ = ["KresConfig", "KresConfigStrict"]
+++ /dev/null
-from typing import Optional
-
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.datamodel.types import SizeUnits
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class CacheConfig(DataclassParserValidatorMixin):
- storage: str = "/var/cache/knot-resolver"
- size_max: Optional[str] = None
- _size_max_bytes: int = 100 * SizeUnits.mebibyte
-
- def __post_init__(self):
- if self.size_max:
- self._size_max_bytes = SizeUnits.parse(self.size_max)
-
- def get_size_max(self) -> int:
- return self._size_max_bytes
-
- def _validate(self):
- pass
import pkgutil
-from typing import Optional, Text, Union
+from typing import Optional, Text
from jinja2 import Environment, Template
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-from .cache_config import CacheConfig
-from .dns64_config import Dns64Config
-from .logging_config import LoggingConfig
-from .lua_config import LuaConfig
-from .network_config import NetworkConfig
-from .options_config import OptionsConfig
-from .server_config import ServerConfig
+from knot_resolver_manager.datamodel.lua_config import Lua, LuaStrict
+from knot_resolver_manager.datamodel.network_config import Network, NetworkStrict
+from knot_resolver_manager.datamodel.server_config import Server, ServerStrict
+from knot_resolver_manager.utils import DataParser, DataValidator
def _import_lua_template() -> Template:
_LUA_TEMPLATE = _import_lua_template()
-@dataclass
-class KresConfig(DataclassParserValidatorMixin):
- # pylint: disable=too-many-instance-attributes
- server: ServerConfig = ServerConfig()
- network: NetworkConfig = NetworkConfig()
- options: OptionsConfig = OptionsConfig()
- cache: CacheConfig = CacheConfig()
- # DNS64 is disabled by default
- dns64: Union[bool, Dns64Config] = False
- logging: LoggingConfig = LoggingConfig()
- lua: Optional[LuaConfig] = None
-
- def __post_init__(self):
- # if DNS64 is enabled with defaults
- if self.dns64 is True:
- self.dns64 = Dns64Config()
-
- def _validate(self):
- pass
+class KresConfig(DataParser):
+ server: Server = Server()
+ network: Network = Network()
+ lua: Optional[Lua] = None
+
+
+class KresConfigStrict(DataValidator):
+ server: ServerStrict
+ network: NetworkStrict
+ lua: Optional[LuaStrict]
def render_lua(self) -> Text:
return _LUA_TEMPLATE.render(cfg=self)
+
+ def _validate(self) -> None:
+ pass
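
Taken together, parsing and validation are now two explicit steps. A minimal sketch of the intended flow (assuming the package is importable; the YAML content is illustrative):

from knot_resolver_manager.datamodel import KresConfig, KresConfigStrict

# step 1: lenient parse -- applies defaults, keeps values like "auto" as-is
config = KresConfig.from_yaml("server:\n  workers: auto\n")

# step 2: strict validation -- runs transformations such as workers -> CPU count
strict = KresConfigStrict(config)
lua_snippet = strict.render_lua()  # the Lua config is rendered from the strict object
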
+++ /dev/null
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-from .types import RE_IPV6_PREFIX_96
-
-
-@dataclass
-class Dns64Config(DataclassParserValidatorMixin):
- prefix: str = "64:ff9b::"
-
- def _validate(self):
- if not bool(RE_IPV6_PREFIX_96.match(self.prefix)):
- raise DataValidationException("'dns64.prefix' must be valid IPv6 /96 prefix")
+++ /dev/null
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class DnssecConfig(DataclassParserValidatorMixin):
- def _validate(self):
- pass
+++ /dev/null
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class StaticHintsConfig(DataclassParserValidatorMixin):
- def _validate(self):
- pass
+++ /dev/null
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class LoggingConfig(DataclassParserValidatorMixin):
- level: int = 3
-
- def _validate(self):
- if not 0 <= self.level <= 7:
- raise DataValidationException("logging 'level' must be in range 0..7")
-from typing import List, Optional
+from typing import List, Optional, Union
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
+from knot_resolver_manager.utils import DataParser, DataValidator
-@dataclass
-class LuaConfig(DataclassParserValidatorMixin):
- script_list: Optional[List[str]] = None
- script: Optional[str] = None
+class Lua(DataParser):
+ script: Optional[Union[List[str], str]] = None
+ script_file: Optional[str] = None
- def __post_init__(self):
- # Concatenate array to single string
- if self.script_list is not None:
- self.script = "\n".join(self.script_list)
- def _validate(self):
- if self.script is None:
- raise DataValidationException("Lua script not specified")
+class LuaStrict(DataValidator):
+ script: Optional[str]
+ script_file: Optional[str]
+
+ def _script(self, lua: Lua) -> Optional[str]:
+ if isinstance(lua.script, List):
+ return "\n".join(lua.script)
+ return lua.script
+
+ def _validate(self) -> None:
+ pass
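
For illustration, a list-valued script is joined into a single string only at validation time; a small sketch of the expected behaviour (assuming the package is importable):

from knot_resolver_manager.datamodel.lua_config import Lua, LuaStrict

lua = Lua({"script": ["-- line one", "-- line two"]})
strict = LuaStrict(lua)
assert strict.script == "-- line one\n-- line two"  # joined by the _script() transformation
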
-{% if cfg.server.hostname %}
--- server.hostname
-hostname('{{ cfg.server.hostname }}')
-{% endif %}
-
-- network.interfaces
{% for item in cfg.network.interfaces %}
-net.listen('{{ item.get_address() }}', {{ item.get_port() if item.get_port() else 'nil' }}, {
+net.listen('{{ item.address }}', {{ item.port }}, {
kind = '{{ item.kind if item.kind != 'dot' else 'tls' }}',
freebind = {{ 'true' if item.freebind else 'false'}}
})
{% endfor %}
--- network.edns-buffer-size
-net.bufsize({{ cfg.network.edns_buffer_size.get_downstream() }}, {{ cfg.network.edns_buffer_size.get_upstream() }})
-
--- modules
-modules = {
- 'hints > iterate', -- Load /etc/hosts and allow custom root hints",
- 'stats', -- Track internal statistics",
-{% if cfg.options.prediction %}
- predict = { -- Prefetch expiring/frequent records"
- window = {{ cfg.options.prediction.get_window() }},
- period = {{ cfg.options.prediction.period }}
- },
-{% endif %}
-{% if cfg.dns64 %}
- dns64 = '{{ cfg.dns64.prefix }}', -- dns64
+{% if cfg.lua %}
+-- lua section
+{% if cfg.lua.script_file %}
+{% import cfg.lua.script_file as script_file %}
+-- lua.script-file
+{{ script_file }}
{% endif %}
-}
-
--- cache
-cache.open({{ cfg.cache.get_size_max() }}, 'lmdb://{{ cfg.cache.storage }}')
-
--- logging level
-verbose({{ 'true' if cfg.logging.level > 3 else 'false'}})
{% if cfg.lua.script %}
--- lua
+-- lua.script
{{ cfg.lua.script }}
+{% endif %}
{% endif %}
\ No newline at end of file
-from typing import List, Optional, Union
+from typing import List
-from knot_resolver_manager.compat.dataclasses import dataclass, field
-from knot_resolver_manager.datamodel.types import SizeUnits
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
+from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils.types import LiteralEnum
+KindEnum = LiteralEnum["dns", "xdp", "dot", "doh"]
-@dataclass
-class InterfacesConfig(DataclassParserValidatorMixin):
+
+class Interface(DataParser):
listen: str
- kind: str = "dns"
+ kind: KindEnum = "dns"
freebind: bool = False
- _address: Optional[str] = None
- _port: Optional[int] = None
- _kind_port_map = {"dns": 53, "xdp": 53, "dot": 853, "doh": 443}
-
- def __post_init__(self):
- # split 'address@port'
- if "@" in self.listen:
- address, port = self.listen.split("@", maxsplit=1)
- self._address = address
- self._port = int(port)
- else:
- # if port number not specified
- self._address = self.listen
- # set port number based on 'kind'
- self._port = self._kind_port_map.get(self.kind)
-
- def get_address(self) -> Optional[str]:
- return self._address
- def get_port(self) -> Optional[int]:
- return self._port
-
- def _validate(self):
- pass
+class InterfaceStrict(DataValidator):
+ address: str
+ port: int
+ kind: str
+ freebind: bool
-@dataclass
-class EdnsBufferSizeConfig(DataclassParserValidatorMixin):
- downstream: Optional[str] = None
- upstream: Optional[str] = None
- _downstream_bytes: int = 1232
- _upstream_bytes: int = 1232
+ def _address(self, obj: Interface) -> str:
+ if "@" in obj.listen:
+ address = obj.listen.split("@", maxsplit=1)[0]
+ return address
+ return obj.listen
- def __post_init__(self):
- if self.downstream:
- self._downstream_bytes = SizeUnits.parse(self.downstream)
- if self.upstream:
- self._upstream_bytes = SizeUnits.parse(self.upstream)
+ def _port(self, obj: Interface) -> int:
+ port_map = {"dns": 53, "xdp": 53, "dot": 853, "doh": 443}
+ if "@" in obj.listen:
+ port = obj.listen.split("@", maxsplit=1)[1]
+ return int(port)
+ return port_map.get(obj.kind, 0)
- def _validate(self):
+ def _validate(self) -> None:
pass
- def get_downstream(self) -> int:
- return self._downstream_bytes
-
- def get_upstream(self) -> int:
- return self._upstream_bytes
+class Network(DataParser):
+ interfaces: List[Interface] = [Interface({"listen": "127.0.0.1"}), Interface({"listen": "::1", "freebind": True})]
-@dataclass
-class NetworkConfig(DataclassParserValidatorMixin):
- interfaces: List[InterfacesConfig] = field(
- default_factory=lambda: [InterfacesConfig(listen="127.0.0.1"), InterfacesConfig(listen="::1", freebind=True)]
- )
- edns_buffer_size: Union[str, EdnsBufferSizeConfig] = EdnsBufferSizeConfig()
- def __post_init__(self):
- if isinstance(self.edns_buffer_size, str):
- bufsize = self.edns_buffer_size
- self.edns_buffer_size = EdnsBufferSizeConfig(downstream=bufsize, upstream=bufsize)
+class NetworkStrict(DataValidator):
+ interfaces: List[InterfaceStrict]
- def _validate(self):
+ def _validate(self) -> None:
pass
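
The 'listen' value still carries both address and optional port; the split now happens in the strict layer. A rough sketch of how an interface entry resolves (assuming the package is importable):

from knot_resolver_manager.datamodel.network_config import Interface, InterfaceStrict

iface = InterfaceStrict(Interface({"listen": "127.0.0.1@5353", "kind": "dot"}))
assert iface.address == "127.0.0.1"
assert iface.port == 5353  # an explicit '@port' wins over the kind-based default

defaulted = InterfaceStrict(Interface({"listen": "::1", "kind": "dot"}))
assert defaulted.port == 853  # no '@port' given, so the 'dot' default applies
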
+++ /dev/null
-from typing import Optional, Union
-
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.datamodel.types import TimeUnits
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
-
-
-@dataclass
-class PredictionConfig(DataclassParserValidatorMixin):
- window: Optional[str] = None
- _window_seconds: int = 15 * TimeUnits.minute
- period: int = 24
-
- def __post_init__(self):
- if self.window:
- self._window_seconds = TimeUnits.parse(self.window)
-
- def get_window(self) -> int:
- return self._window_seconds
-
- def _validate(self):
- pass
-
-
-@dataclass
-class OptionsConfig(DataclassParserValidatorMixin):
- prediction: Union[bool, PredictionConfig] = False
-
- def __post_init__(self):
- if self.prediction is True:
- self.prediction = PredictionConfig()
-
- def _validate(self):
- pass
import logging
import os
-from typing import Optional, Union
+from typing import Union
from typing_extensions import Literal
-from knot_resolver_manager.compat.dataclasses import dataclass
-from knot_resolver_manager.exceptions import DataValidationException
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin
+from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator
logger = logging.getLogger(__name__)
return cpus
-@dataclass
-class ServerConfig(DataclassParserValidatorMixin):
- hostname: Optional[str] = None
- instances: Union[Literal["auto"], int, None] = None
- _instances: int = 1
+class Server(DataParser):
+ workers: Union[Literal["auto"], int] = 1
use_cache_gc: bool = True
- def __post_init__(self):
- if isinstance(self.instances, int):
- self._instances = self.instances
- elif self.instances == "auto":
- self._instances = _cpu_count()
-
- def get_instances(self) -> int:
- # FIXME: this is a hack to make the partial updates working without a second data structure
- # this will be unnecessary in near future
- if isinstance(self.instances, int):
- return self.instances
- elif self.instances == "auto":
- cpu_count = os.cpu_count()
- if cpu_count is not None:
- return cpu_count
- else:
- raise RuntimeError("cannot find number of system available CPUs")
- else:
- return 0
-
- def _validate(self):
- if not 0 < self._instances <= 256:
- raise DataValidationException("number of kresd instances must be in range 1..256")
+
+class ServerStrict(DataValidator):
+ workers: int
+ use_cache_gc: bool
+
+ def _workers(self, obj: Server) -> int:
+ if isinstance(obj.workers, int):
+ return obj.workers
+ elif obj.workers == "auto":
+ return _cpu_count()
+ raise DataValidationException(f"Unexpected value: {obj.workers}")
+
+ def _validate(self) -> None:
+ if self.workers < 0:
+ raise DataValidationException("Number of workers must be non-negative")
import re
+from typing import Any, Dict, Optional, Pattern, Union
+
+from knot_resolver_manager.utils import CustomValueType, DataValidationException
+
+
+class Unit(CustomValueType):
+ _re: Pattern[str]
+ _units: Dict[Optional[str], int]
+
+ def __init__(self, source_value: Any) -> None:
+ super().__init__(source_value)
+ self._value: int
+ self._value_orig: Union[str, int]
+ if isinstance(source_value, str) and type(self)._re.match(source_value):
+ self._value_orig = source_value
+ grouped = type(self)._re.search(source_value)
+ if grouped:
+ val, unit = grouped.groups()
+ if unit not in type(self)._units:
+ raise DataValidationException(f"Used unexpected unit '{unit}' for {type(self).__name__}...")
+ self._value = int(val) * type(self)._units[unit]
+ else:
+ raise DataValidationException(f"{type(self._value)} Failed to convert: {self}")
+ elif isinstance(source_value, int):
+ if source_value < 0:
+ raise DataValidationException(f"Input value '{source_value}' is not non-negative.")
+ self._value_orig = source_value
+ self._value = source_value
+ else:
+ raise DataValidationException(
+ f"Unexpected input type for Unit type - {type(source_value)}."
+ " Cause might be invalid format or invalid type."
+ )
+
+ def __int__(self) -> int:
+ return self._value
+
+ def __str__(self) -> str:
+ """
+ Used by Jinja2. Must return only a number.
+ """
+ return str(self._value)
+
+ def __eq__(self, o: object) -> bool:
+ """
+ Two instances are equal when they represent the same value
+ regardless of their string representation.
+ """
+ return isinstance(o, type(self)) and o._value == self._value
+
+
+class SizeUnit(Unit):
+ _re = re.compile(r"^([0-9]+)\s{0,1}([BKMG]){0,1}$")
+ _units = {None: 1, "B": 1, "K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
-from knot_resolver_manager.exceptions import DataValidationException
-
-RE_IPV6_PREFIX_96 = re.compile(r"^([0-9A-Fa-f]{1,4}:){2}:$")
-
-
-class TimeUnits:
- second = 1
- minute = 60
- hour = 3600
- day = 24 * 3600
+class TimeUnit(Unit):
_re = re.compile(r"^(\d+)\s{0,1}([smhd]){0,1}$")
- _map = {"s": second, "m": minute, "h": hour, "d": day}
-
- @staticmethod
- def parse(time_str: str) -> int:
- searched = TimeUnits._re.search(time_str)
- if searched:
- value, unit = searched.groups()
- return int(value) * TimeUnits._map.get(unit, 1)
- raise DataValidationException(f"failed to parse: {time_str}")
-
-
-class SizeUnits:
- byte = 1
- kibibyte = 1024
- mebibyte = 1024 ** 2
- gibibyte = 1024 ** 3
-
- _re = re.compile(r"^([0-9]+)\s{0,1}([BKMG]){0,1}$")
- _map = {"B": byte, "K": kibibyte, "M": mebibyte, "G": gibibyte}
-
- @staticmethod
- def parse(size_str: str) -> int:
- searched = SizeUnits._re.search(size_str)
- if searched:
- value, unit = searched.groups()
- return int(value) * SizeUnits._map.get(unit, 1)
- raise DataValidationException(f"failed to parse: {size_str}")
+ _units = {None: 1, "s": 1, "m": 60, "h": 3600, "d": 24 * 3600}
class SubprocessControllerException(Exception):
pass
-
-
-class ValidationException(Exception):
- pass
-
-
-class SchemaValidationException(ValidationException):
- pass
-
-
-class DataValidationException(ValidationException):
- pass
from knot_resolver_manager import kres_id
from knot_resolver_manager.compat.asyncio import create_task
from knot_resolver_manager.constants import KRESD_CONFIG_FILE, WATCHDOG_INTERVAL
-from knot_resolver_manager.exceptions import ValidationException
from knot_resolver_manager.kresd_controller.interface import (
Subprocess,
SubprocessController,
SubprocessStatus,
SubprocessType,
)
+from knot_resolver_manager.utils import DataValidationException
from knot_resolver_manager.utils.async_utils import writefile
-from .datamodel import KresConfig
+from .datamodel import KresConfig, KresConfigStrict
logger = logging.getLogger(__name__)
self._manager_lock = asyncio.Lock()
self._controller: SubprocessController
self._last_used_config: Optional[KresConfig] = None
+ self._last_used_config_strict: Optional[KresConfigStrict] = None
self._watchdog_task: Optional["Future[None]"] = None
async def load_system_state(self):
await self._gc.stop()
self._gc = None
- async def _write_config(self, config: KresConfig):
- lua_config = config.render_lua()
+ async def _write_config(self, config_strict: KresConfigStrict):
+ lua_config = config_strict.render_lua()
await writefile(KRESD_CONFIG_FILE, lua_config)
async def apply_config(self, config: KresConfig):
async with self._manager_lock:
+ logger.debug("Validating configuration...")
+ config_strict = KresConfigStrict(config)
+
logger.debug("Writing new config to file...")
- await self._write_config(config)
+ await self._write_config(config_strict)
logger.debug("Testing the new config with a canary process")
try:
await self._spawn_new_worker()
except SubprocessError:
logger.error("kresd with the new config failed to start, rejecting config")
- last = self.get_last_used_config()
+ last = self.get_last_used_config_strict()
if last is not None:
await self._write_config(last)
- raise ValidationException("Canary kresd instance failed. Config is invalid.")
+ raise DataValidationException("Canary kresd instance failed. Config is invalid.")
logger.debug("Canary process test passed, Applying new config to all workers")
self._last_used_config = config
- await self._ensure_number_of_children(config.server.get_instances())
+ self._last_used_config_strict = config_strict
+ await self._ensure_number_of_children(config_strict.server.workers)
await self._rolling_restart()
if self._is_gc_running() != config.server.use_cache_gc:
def get_last_used_config(self) -> Optional[KresConfig]:
return self._last_used_config
+ def get_last_used_config_strict(self) -> Optional[KresConfigStrict]:
+ return self._last_used_config_strict
+
async def _instability_handler(self) -> None:
logger.error("Instability callback invoked. No idea how to react, performing suicide. See you later!")
sys.exit(1)
from aiohttp.web_response import json_response
from knot_resolver_manager.constants import MANAGER_CONFIG_FILE
-from knot_resolver_manager.exceptions import ValidationException
from knot_resolver_manager.kresd_controller import get_controller_by_name
from knot_resolver_manager.kresd_controller.interface import SubprocessController
+from knot_resolver_manager.utils import DataValidationException, Format
from knot_resolver_manager.utils.async_utils import readfile
-from knot_resolver_manager.utils.dataclasses_parservalidator import Format
from .datamodel import KresConfig
from .kres_manager import KresManager
try:
return await handler(request)
- except ValidationException as e:
+ except DataValidationException as e:
logger.error("Failed to parse given data in API request", exc_info=True)
return web.Response(text=f"Data validation failed: {e}", status=HTTPStatus.BAD_REQUEST)
from typing import Any, Callable, Iterable, Optional, Type, TypeVar
-from .dataclasses_parservalidator import DataclassParserValidatorMixin
+from .custom_types import CustomValueType
+from .data_parser_validator import DataParser, DataValidator, Format
+from .exceptions import DataParsingException, DataValidationException
from .overload import Overloaded
T = TypeVar("T")
__all__ = [
"ignore_exceptions_optional",
"ignore_exceptions",
- "DataclassParserValidatorMixin",
+ "Format",
+ "CustomValueType",
+ "DataParser",
+ "DataValidator",
+ "DataParsingException",
+ "DataValidationException",
"Overloaded",
]
--- /dev/null
+from typing import Any
+
+
+class CustomValueType:
+ """
+ Subclasses of this class can be used as type annotations in 'DataParser'. When a value
+ is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by
+ calling the constructor of the appropriate type on the field value. The only limitation
+ is that the value MUST NOT be `None`.
+
+ Example:
+ ```
+ class A(DataParser):
+ field: MyCustomValueType
+
+ A.from_json('{"field": "value"}') == A(field=MyCustomValueType("value"))
+ ```
+
+ There is no validation done on the wrapped value. The only condition is that
+ it can't be `None`. If you want to perform any validation during creation,
+ raise a `DataValidationException` in case of errors.
+ """
+
+ def __init__(self, source_value: Any) -> None:
+ pass
+
+ def __int__(self) -> int:
+ raise NotImplementedError("CustomValueType return 'int()' value is not implemented.")
+
+ def __str__(self) -> str:
+ raise NotImplementedError("CustomValueType return 'str()' value is not implemented.")
import copy
+import inspect
import json
import re
from enum import Enum, auto
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
-from knot_resolver_manager.exceptions import SchemaValidationException
+from knot_resolver_manager.utils.custom_types import CustomValueType
+from knot_resolver_manager.utils.exceptions import DataParsingException
from knot_resolver_manager.utils.types import (
get_attr_type,
get_generic_type_argument,
is_union,
)
-from ..compat.dataclasses import is_dataclass
+def is_internal_field(field_name: str) -> bool:
+ return field_name.startswith("_")
+
+
+def is_obj_type(obj: Any, types: Union[type, Tuple[Any, ...], Tuple[type, ...]]) -> bool:
+ # To check for a specific type we use 'type()' instead of 'isinstance()',
+ # because 'bool' is a subclass of 'int', so 'isinstance(False, int)' returns True.
+ # pylint: disable=unidiomatic-typecheck
+ if isinstance(types, tuple):
+ return type(obj) in types
+ return type(obj) == types
+
+
+def _to_primitive(obj: Any) -> Any:
+ """
+ Convert our custom values into primitive variants for dumping.
+ """
+
+ # CustomValueType instances
+ if isinstance(obj, CustomValueType):
+ return str(obj)
+
+ # nested DataParser class instances
+ elif isinstance(obj, DataParser):
+ return obj.to_dict()
+
+ # otherwise just return, what we were given
+ else:
+ return obj
+
+
+def _validated_object_type(cls: Type[Any], obj: Any, default: Any = ..., use_default: bool = False) -> Any:
+ """
+ Given an expected type `cls` and a value object `obj`, validate the type of `obj` and return it.
+ """
-def _from_dictlike_obj(cls: Any, obj: Any, default: Any, use_default: bool) -> Any:
# Disabling these checks, because I think it's much more readable as a single function
- # and it's not that large at this point. If it got larger, then we should definitely split
- # it
+ # and it's not that large at this point. If it got larger, then we should definitely split it
# pylint: disable=too-many-branches,too-many-locals,too-many-statements
# default values
if obj is None:
return None
else:
- raise SchemaValidationException(f"Expected None, found {obj}")
+ raise DataParsingException(f"Expected None, found '{obj}'.")
# Union[*variants] (handles Optional[T] due to the way the typing system works)
elif is_union(cls):
variants = get_generic_type_arguments(cls)
for v in variants:
try:
- return _from_dictlike_obj(v, obj, ..., False)
- except SchemaValidationException:
+ return _validated_object_type(v, obj)
+ except DataParsingException:
pass
- raise SchemaValidationException(f"Union {cls} could not be parsed - parsing of all variants failed")
+ raise DataParsingException(f"Union {cls} could not be parsed - parsing of all variants failed.")
# after this, there is no place for a None object
elif obj is None:
- raise SchemaValidationException(f"Unexpected None value for type {cls}")
+ raise DataParsingException(f"Unexpected None value for type {cls}")
# int
elif cls == int:
# we don't want to make an int out of anything else than other int
- if isinstance(obj, int):
+ # except for CustomValueType class instances
+ if is_obj_type(obj, int) or isinstance(obj, CustomValueType):
return int(obj)
- else:
- raise SchemaValidationException(f"Expected int, found {type(obj)}")
+ raise DataParsingException(f"Expected int, found {type(obj)}")
# str
elif cls == str:
# we are willing to cast any primitive value to string, but no compound values are allowed
- if isinstance(obj, (str, float, int)):
+ if is_obj_type(obj, (str, float, int)) or isinstance(obj, CustomValueType):
return str(obj)
- elif isinstance(obj, bool):
- raise SchemaValidationException(
+ elif is_obj_type(obj, bool):
+ raise DataParsingException(
"Expected str, found bool. Be careful, that YAML parsers consider even"
' "no" and "yes" as a bool. Search for the Norway Problem for more'
" details. And please use quotes explicitly."
)
else:
- raise SchemaValidationException(
+ raise DataParsingException(
f"Expected str (or number that would be cast to string), but found type {type(obj)}"
)
# bool
elif cls == bool:
- if isinstance(obj, bool):
+ if is_obj_type(obj, bool):
return obj
else:
- raise SchemaValidationException(f"Expected bool, found {type(obj)}")
+ raise DataParsingException(f"Expected bool, found {type(obj)}")
# float
elif cls == float:
raise NotImplementedError(
- "Floating point values are not supported in the parser validator."
+ "Floating point values are not supported in the parser."
" Please implement them and be careful with type coercions"
)
if obj == expected:
return obj
else:
- raise SchemaValidationException(f"Literal {cls} is not matched with the value {obj}")
+ raise DataParsingException(f"Literal {cls} is not matched with the value {obj}")
# Dict[K,V]
elif is_dict(cls):
key_type, val_type = get_generic_type_arguments(cls)
try:
return {
- _from_dictlike_obj(key_type, key, ..., False): _from_dictlike_obj(val_type, val, ..., False)
- for key, val in obj.items()
+ _validated_object_type(key_type, key): _validated_object_type(val_type, val) for key, val in obj.items()
}
except AttributeError as e:
- raise SchemaValidationException(
+ raise DataParsingException(
f"Expected dict-like object, but failed to access its .items() method. Value was {obj}", e
)
# List[T]
elif is_list(cls):
inner_type = get_generic_type_argument(cls)
- return [_from_dictlike_obj(inner_type, val, ..., False) for val in obj]
+ return [_validated_object_type(inner_type, val) for val in obj]
# Tuple[A,B,C,D,...]
elif is_tuple(cls):
types = get_generic_type_arguments(cls)
- return tuple(_from_dictlike_obj(typ, val, ..., False) for typ, val in zip(types, obj))
-
- # nested dataclass
- elif is_dataclass(cls):
- anot = cls.__dict__.get("__annotations__", {})
- kwargs = {}
- for name, python_type in anot.items():
- # skip internal fields
- if name.startswith("_"):
- continue
-
- value = obj[name] if name in obj else None
- use_default = hasattr(cls, name)
- default = getattr(cls, name, ...)
- kwargs[name] = _from_dictlike_obj(python_type, value, default, use_default)
- return cls(**kwargs)
+ return tuple(_validated_object_type(typ, val) for typ, val in zip(types, obj))
+
+ # CustomValueType subclasses
+ elif inspect.isclass(cls) and issubclass(cls, CustomValueType):
+ # no validation performed, the implementation does it in the constructor
+ return cls(obj)
+
+ # nested DataParser subclasses
+ elif inspect.isclass(cls) and issubclass(cls, DataParser):
+ # we should return DataParser, we expect to be given a dict,
+ # because we can construct a DataParser from it
+ if isinstance(obj, dict):
+ return cls(obj) # type: ignore
+ raise DataParsingException(f"Expected '{dict}' object, found '{type(obj)}'")
+
+ # nested DataValidator subclasses
+ elif inspect.isclass(cls) and issubclass(cls, DataValidator):
+ # we should return DataValidator, we expect to be given a DataParser,
+ # because we can construct a DataValidator from it
+ if isinstance(obj, DataParser):
+ return cls(obj)
+ raise DataParsingException(f"Expected instance of '{DataParser}' class, found '{type(obj)}'")
# default error handler
else:
- raise SchemaValidationException(
+ raise DataParsingException(
f"Type {cls} cannot be parsed. This is a implementation error. "
- "Please fix your types in the dataclass or improve the parser/validator."
+ "Please fix your types in the class or improve the parser/validator."
)
dict_out: Dict[Any, Any] = {}
for key, val in pairs:
if key in dict_out:
- raise SchemaValidationException(f"duplicate key detected: {key}")
+ raise DataParsingException(f"Duplicate attribute key detected: {key}")
dict_out[key] = val
return dict_out
# check for duplicate keys
if key in mapping:
- raise SchemaValidationException(f"duplicate key detected: {key_node.start_mark}")
+ raise DataParsingException(f"duplicate key detected: {key_node.start_mark}")
value = self.construct_object(value_node, deep=deep) # type: ignore
mapping[key] = value
return mapping
-_T = TypeVar("_T", bound="DataclassParserValidatorMixin")
-
-
-_SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
-
-
class Format(Enum):
YAML = auto()
JSON = auto()
else:
raise NotImplementedError(f"Parsing of format '{self}' is not implemented")
+ def dict_dump(self, data: Dict[str, Any]) -> str:
+ if self is Format.YAML:
+ return yaml.safe_dump(data) # type: ignore
+ elif self is Format.JSON:
+ return json.dumps(data)
+ else:
+ raise NotImplementedError(f"Exporting to '{self}' format is not implemented")
+
@staticmethod
def from_mime_type(mime_type: str) -> "Format":
formats = {
"text/vnd.yaml": Format.YAML,
}
if mime_type not in formats:
- raise SchemaValidationException("Unsupported MIME type")
+ raise DataParsingException("Unsupported MIME type")
return formats[mime_type]
-class DataclassParserValidatorMixin:
- def __init__(self, *args: Any, **kwargs: Any):
- """
- This constructor is useless except for typechecking. It makes sure that the dataclasses can be created with
- any arguments whatsoever.
- """
+_T = TypeVar("_T", bound="DataParser")
- def validate(self) -> None:
- for field_name in dir(self):
- # skip internal fields
- if field_name.startswith("_"):
+
+_SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
+
+
+class DataParser:
+ def __init__(self, obj: Optional[Dict[str, Any]] = None):
+ cls = self.__class__
+ annot = cls.__dict__.get("__annotations__", {})
+
+ used_keys: List[str] = []
+ for name, python_type in annot.items():
+ if is_internal_field(name):
continue
- field = getattr(self, field_name)
- if is_dataclass(field):
- if not isinstance(field, DataclassParserValidatorMixin):
- raise SchemaValidationException(
- f"Nested dataclass in the field {field_name} does not include the ParserValidatorMixin"
- )
- field.validate()
+ val = None
+ dash_name = name.replace("_", "-")
+ if obj and dash_name in obj:
+ val = obj[dash_name]
+ used_keys.append(dash_name)
- self._validate()
+ use_default = hasattr(cls, name)
+ default = getattr(cls, name, ...)
+ value = _validated_object_type(python_type, val, default, use_default)
+ setattr(self, name, value)
- def _validate(self) -> None:
- raise NotImplementedError(f"Validation function is not implemented in class {type(self).__name__}")
+ # check for unused keys
+ if obj:
+ for key in obj:
+ if key not in used_keys:
+ raise DataParsingException(f"Unknown attribute key '{key}'.")
@classmethod
def parse_from(cls: Type[_T], fmt: Format, text: str):
- data = fmt.parse_to_dict(text)
- config: _T = _from_dictlike_obj(cls, data, ..., False)
- config.validate()
+ data_dict = fmt.parse_to_dict(text)
+ config: _T = cls(data_dict)
return config
@classmethod
def from_json(cls: Type[_T], text: str) -> _T:
return cls.parse_from(Format.JSON, text)
+ def to_dict(self) -> Dict[str, Any]:
+ cls = self.__class__
+ anot = cls.__dict__.get("__annotations__", {})
+ dict_obj: Dict[str, Any] = {}
+ for name in anot:
+ if is_internal_field(name):
+ continue
+
+ value = getattr(self, name)
+ dash_name = str(name).replace("_", "-")
+ dict_obj[dash_name] = _to_primitive(value)
+ return dict_obj
+
+ def dump(self, fmt: Format) -> str:
+ dict_data = self.to_dict()
+ return fmt.dict_dump(dict_data)
+
+ def dump_to_yaml(self) -> str:
+ return self.dump(Format.YAML)
+
+ def dump_to_json(self) -> str:
+ return self.dump(Format.JSON)
+
def copy_with_changed_subtree(self: _T, fmt: Format, path: str, text: str) -> _T:
cls = self.__class__
# prepare and validate the path object
path = path[:-1] if path.endswith("/") else path
if re.match(_SUBTREE_MUTATION_PATH_PATTERN, path) is None:
- raise SchemaValidationException("Provided object path for mutation is invalid.")
+ raise DataParsingException("Provided object path for mutation is invalid.")
path = path[1:] if path.startswith("/") else path
# now, the path variable should contain '/' separated field names
to_mutate = copy.deepcopy(self)
obj = to_mutate
parent = None
- for segment in path.split("/"):
+
+ for dash_segment in path.split("/"):
+ segment = dash_segment.replace("-", "_")
+
if segment == "":
- raise SchemaValidationException(f"Unexpectedly empty segment in path '{path}'")
- elif segment.startswith("_"):
- raise SchemaValidationException(
- "No, changing internal fields (starting with _) is not allowed. Nice try."
- )
+ raise DataParsingException(f"Unexpectedly empty segment in path '{path}'")
+ elif is_internal_field(segment):
+ raise DataParsingException("No, changing internal fields (starting with _) is not allowed. Nice try.")
elif hasattr(obj, segment):
parent = obj
obj = getattr(parent, segment)
else:
- raise SchemaValidationException(
- f"Path segment '{segment}' does not match any field on the provided parent object"
+ raise DataParsingException(
+ f"Path segment '{dash_segment}' does not match any field on the provided parent object"
)
assert parent is not None
# assign the subtree
- last_name = path.split("/")[-1]
+ last_name = path.split("/")[-1].replace("-", "_")
data = fmt.parse_to_dict(text)
tp = get_attr_type(parent, last_name)
- parsed_value = _from_dictlike_obj(tp, data, ..., False)
+ parsed_value = _validated_object_type(tp, data)
setattr(parent, last_name, parsed_value)
- to_mutate.validate()
-
return to_mutate
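
Partial mutations address a dash-named subtree and reparse only that part, leaving the original object untouched; a brief sketch against the new data model (assuming the package is importable; values are illustrative):

from knot_resolver_manager.datamodel import KresConfig
from knot_resolver_manager.utils import Format

config = KresConfig.from_yaml("server:\n  workers: 2\n")
updated = config.copy_with_changed_subtree(Format.JSON, "/server/workers", "4")
assert updated.server.workers == 4
assert config.server.workers == 2  # the source object is deep-copied, not mutated
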
+
+
+class DataValidator:
+ def __init__(self, obj: DataParser):
+ cls = self.__class__
+ anot = cls.__dict__.get("__annotations__", {})
+
+ for attr_name, attr_type in anot.items():
+ if is_internal_field(attr_name):
+ continue
+
+ # use transformation function if available
+ if hasattr(self, f"_{attr_name}"):
+ value = getattr(self, f"_{attr_name}")(obj)
+ elif hasattr(obj, attr_name):
+ value = getattr(obj, attr_name)
+ else:
+ raise DataParsingException(f"DataParser object {obj} is missing '{attr_name}' attribute.")
+
+ setattr(self, attr_name, _validated_object_type(attr_type, value))
+
+ self._validate()
+
+ def validate(self) -> None:
+ for field_name in dir(self):
+ if is_internal_field(field_name):
+ continue
+
+ field = getattr(self, field_name)
+ if isinstance(field, DataValidator):
+ field.validate()
+ self._validate()
+
+ def _validate(self) -> None:
+ raise NotImplementedError(f"Validation function is not implemented in class {type(self).__name__}")
--- /dev/null
+class DataParsingException(Exception):
+ pass
+
+
+class DataValidationException(Exception):
+ pass
def is_literal(tp: Any) -> bool:
- return getattr(tp, "__origin__", None) == Literal
+ return isinstance(tp, type(Literal))
def get_generic_type_arguments(tp: Any) -> List[Any]:
default: List[Any] = []
- return getattr(tp, "__args__", default)
+ if is_literal(tp):
+ return getattr(tp, "__values__")
+ else:
+ return getattr(tp, "__args__", default)
def get_generic_type_argument(tp: Any) -> Any:
--- /dev/null
+from pytest import raises
+
+from knot_resolver_manager.datamodel.types import SizeUnit, TimeUnit
+from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator
+
+
+def test_size_unit():
+ assert (
+ SizeUnit(5368709120)
+ == SizeUnit("5368709120")
+ == SizeUnit("5368709120B")
+ == SizeUnit("5242880K")
+ == SizeUnit("5120M")
+ == SizeUnit("5G")
+ )
+
+ with raises(DataValidationException):
+ SizeUnit("-5368709120")
+ with raises(DataValidationException):
+ SizeUnit(-5368709120)
+ with raises(DataValidationException):
+ SizeUnit("5120MM")
+
+
+def test_time_unit():
+ assert TimeUnit("1d") == TimeUnit("24h") == TimeUnit("1440m") == TimeUnit("86400s") == TimeUnit(86400)
+
+ with raises(DataValidationException):
+ TimeUnit("-1")
+ with raises(DataValidationException):
+ TimeUnit(-24)
+ with raises(DataValidationException):
+ TimeUnit("1440mm")
+
+
+def test_parsing_units():
+ class TestClass(DataParser):
+ size: SizeUnit
+ time: TimeUnit
+
+ class TestClassStrict(DataValidator):
+ size: int
+ time: int
+
+ def _validate(self) -> None:
+ pass
+
+ yaml = """
+size: 3K
+time: 10m
+"""
+
+ obj = TestClass.from_yaml(yaml)
+ assert obj.size == SizeUnit(3 * 1024)
+ assert obj.time == TimeUnit(10 * 60)
+
+ strict = TestClassStrict(obj)
+ assert strict.size == 3 * 1024
+ assert strict.time == 10 * 60
+
+ y = obj.dump_to_yaml()
+ j = obj.dump_to_json()
+ a = TestClass.from_yaml(y)
+ b = TestClass.from_json(j)
+ assert a.size == b.size == obj.size
+ assert a.time == b.time == obj.time
+++ /dev/null
-from knot_resolver_manager.datamodel import KresConfig
-
-
-def test_simple():
- json = """
- {
- "server": {
- "instances": 1
- },
- "lua": {
- "script_list": [
- "-- SPDX-License-Identifier: CC0-1.0",
- "-- vim:syntax=lua:set ts=4 sw=4:",
- "-- Refer to manual: https://knot-resolver.readthedocs.org/en/stable/",
- "-- Network interface configuration","net.listen('127.0.0.1', 53, { kind = 'dns' })",
- "net.listen('127.0.0.1', 853, { kind = 'tls' })",
- "--net.listen('127.0.0.1', 443, { kind = 'doh2' })",
- "net.listen('::1', 53, { kind = 'dns', freebind = true })",
- "net.listen('::1', 853, { kind = 'tls', freebind = true })",
- "--net.listen('::1', 443, { kind = 'doh2' })",
- "-- Load useful modules","modules = {",
- "'hints > iterate', -- Load /etc/hosts and allow custom root hints",
- "'stats', -- Track internal statistics",
- "'predict', -- Prefetch expiring/frequent records",
- "}",
- "-- Cache size",
- "cache.size = 100 * MB"
- ]
- }
- }
- """
-
- config = KresConfig.from_json(json)
-
- assert config.server.instances == 1
- assert config.lua.script is not None
--- /dev/null
+from typing import Dict, List, Optional, Tuple, Union
+
+from pytest import raises
+from typing_extensions import Literal
+
+from knot_resolver_manager.utils import DataParser, DataValidationException, DataValidator, Format
+from knot_resolver_manager.utils.exceptions import DataParsingException
+
+
+def test_primitive():
+ class TestClass(DataParser):
+ i: int
+ s: str
+ b: bool
+
+ class TestClassStrict(DataValidator):
+ i: int
+ s: str
+ b: bool
+
+ def _validate(self) -> None:
+ pass
+
+ yaml = """
+i: 5
+s: test
+b: false
+"""
+
+ obj = TestClass.from_yaml(yaml)
+ assert obj.i == 5
+ assert obj.s == "test"
+ assert obj.b is False
+
+ strict = TestClassStrict(obj)
+ assert strict.i == 5
+ assert strict.s == "test"
+ assert strict.b is False
+
+ y = obj.dump_to_yaml()
+ j = obj.dump_to_json()
+ a = TestClass.from_yaml(y)
+ b = TestClass.from_json(j)
+ assert a.i == b.i == obj.i
+ assert a.s == b.s == obj.s
+ assert a.b == b.b == obj.b
+
+
+def test_parsing_primitive_exceptions():
+ class TestStr(DataParser):
+ s: str
+
+ # int and float are allowed inputs for string
+ with raises(DataParsingException):
+ TestStr.from_yaml("s: false") # bool
+
+ class TestInt(DataParser):
+ i: int
+
+ with raises(DataParsingException):
+ TestInt.from_yaml("i: false") # bool
+ with raises(DataParsingException):
+ TestInt.from_yaml('i: "5"') # str
+ with raises(DataParsingException):
+ TestInt.from_yaml("i: 5.5") # float
+
+ class TestBool(DataParser):
+ b: bool
+
+ with raises(DataParsingException):
+ TestBool.from_yaml("b: 5") # int
+ with raises(DataParsingException):
+ TestBool.from_yaml('b: "5"') # str
+ with raises(DataParsingException):
+ TestBool.from_yaml("b: 5.5") # float
+
+
+def test_nested():
+ class Lower(DataParser):
+ i: int
+
+ class Upper(DataParser):
+ l: Lower
+
+ class LowerStrict(DataValidator):
+ i: int
+
+ def _validate(self) -> None:
+ pass
+
+ class UpperStrict(DataValidator):
+ l: LowerStrict
+
+ def _validate(self) -> None:
+ pass
+
+ yaml = """
+l:
+ i: 5
+"""
+
+ obj = Upper.from_yaml(yaml)
+ assert obj.l.i == 5
+
+ strict = UpperStrict(obj)
+ assert strict.l.i == 5
+
+ y = obj.dump_to_yaml()
+ j = obj.dump_to_json()
+ a = Upper.from_yaml(y)
+ b = Upper.from_json(j)
+ assert a.l.i == b.l.i == obj.l.i
+
+
+def test_simple_compound_types():
+ class TestClass(DataParser):
+ l: List[int]
+ d: Dict[str, str]
+ t: Tuple[str, int]
+ o: Optional[int]
+
+ class TestClassStrict(DataValidator):
+ l: List[int]
+ d: Dict[str, str]
+ t: Tuple[str, int]
+ o: Optional[int]
+
+ def _validate(self) -> None:
+ pass
+
+ yaml = """
+l:
+ - 1
+ - 2
+ - 3
+ - 4
+ - 5
+d:
+ something: else
+ w: all
+t:
+ - test
+ - 5
+"""
+
+ obj = TestClass.from_yaml(yaml)
+ assert obj.l == [1, 2, 3, 4, 5]
+ assert obj.d == {"something": "else", "w": "all"}
+ assert obj.t == ("test", 5)
+ assert obj.o is None
+
+ strict = TestClassStrict(obj)
+ assert strict.l == [1, 2, 3, 4, 5]
+ assert strict.d == {"something": "else", "w": "all"}
+ assert strict.t == ("test", 5)
+ assert strict.o is None
+
+ y = obj.dump_to_yaml()
+ j = obj.dump_to_json()
+ a = TestClass.from_yaml(y)
+ b = TestClass.from_json(j)
+ assert a.l == b.l == obj.l
+ assert a.d == b.d == obj.d
+ assert a.t == b.t == obj.t
+ assert a.o == b.o == obj.o
+
+
+def test_nested_compound_types():
+ class TestClass(DataParser):
+ o: Optional[Dict[str, str]]
+
+ class TestClassStrict(DataValidator):
+ o: Optional[Dict[str, str]]
+
+ def _validate(self) -> None:
+ pass
+
+ yaml = """
+o:
+ key: val
+"""
+
+ obj = TestClass.from_yaml(yaml)
+ assert obj.o == {"key": "val"}
+
+ strict = TestClassStrict(obj)
+ assert strict.o == {"key": "val"}
+
+ y = obj.dump_to_yaml()
+ j = obj.dump_to_json()
+ a = TestClass.from_yaml(y)
+ b = TestClass.from_json(j)
+ assert a.o == b.o == obj.o
+
+
+def test_nested_compound_types2():
+ class TestClass(DataParser):
+ i: int
+ o: Optional[Dict[str, str]]
+
+ class TestClassStrict(DataValidator):
+ i: int
+ o: Optional[Dict[str, str]]
+
+ def _validate(self) -> None:
+ pass
+
+ yaml = "i: 5"
+
+ obj = TestClass.from_yaml(yaml)
+ assert obj.i == 5
+ assert obj.o is None
+
+ strict = TestClassStrict(obj)
+ assert strict.i == 5
+ assert strict.o is None
+
+ y = obj.dump_to_yaml()
+ j = obj.dump_to_json()
+ a = TestClass.from_yaml(y)
+ b = TestClass.from_json(j)
+ assert a.i == b.i == obj.i
+ assert a.o == b.o == obj.o
+
+
+def test_partial_mutations():
+ class Inner(DataParser):
+ size: int = 5
+
+ class ConfData(DataParser):
+ workers: Union[Literal["auto"], int] = 1
+ lua_config: Optional[str] = None
+ inner: Inner = Inner()
+
+ class InnerStrict(DataValidator):
+ size: int
+
+ def _validate(self) -> None:
+ pass
+
+ class ConfDataStrict(DataValidator):
+ workers: int
+ lua_config: Optional[str]
+ inner: InnerStrict
+
+ def _workers(self, data: ConfData) -> int:
+ if data.workers == "auto":
+ return 8
+ else:
+ return data.workers
+
+ def _validate(self) -> None:
+ if self.workers < 0:
+ raise DataValidationException("Number of workers must be non-negative")
+
+ yaml = """
+ workers: auto
+ lua-config: something
+ """
+
+ conf = ConfData.from_yaml(yaml)
+
+ x = ConfDataStrict(conf)
+ assert x.lua_config == "something"
+ assert x.inner.size == 5
+ assert x.workers == 8
+
+ y = conf.dump_to_yaml()
+ j = conf.dump_to_json()
+ a = ConfData.from_yaml(y)
+ b = ConfData.from_json(j)
+ assert a.workers == b.workers == conf.workers
+ assert a.lua_config == b.lua_config == conf.lua_config
+ assert a.inner.size == b.inner.size == conf.inner.size
+
+ # replacement of 'lua-config' attribute
+ x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/lua-config", '"new_value"'))
+ assert x.lua_config == "new_value"
+ assert x.inner.size == 5
+ assert x.workers == 8
+
+ # replacement of the whole tree
+ x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/", '{"inner": {"size": 55}}'))
+ assert x.lua_config is None
+ assert x.workers == 1
+ assert x.inner.size == 55
+
+ # replacement of 'inner' subtree
+ x = ConfDataStrict(conf.copy_with_changed_subtree(Format.JSON, "/inner", '{"size": 33}'))
+ assert x.lua_config == "something"
+ assert x.workers == 8
+ assert x.inner.size == 33
+++ /dev/null
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-from knot_resolver_manager.utils.dataclasses_parservalidator import DataclassParserValidatorMixin, Format
-
-
-def test_parsing_primitive():
- @dataclass
- class TestClass(DataclassParserValidatorMixin):
- i: int
- s: str
-
- def _validate(self):
- pass
-
- yaml = """i: 5
-s: "test"
-"""
-
- obj = TestClass.from_yaml(yaml)
-
- assert obj.i == 5
- assert obj.s == "test"
-
-
-def test_parsing_nested():
- @dataclass
- class Lower(DataclassParserValidatorMixin):
- i: int
-
- def _validate(self):
- pass
-
- @dataclass
- class Upper(DataclassParserValidatorMixin):
- l: Lower
-
- def _validate(self):
- pass
-
- yaml = """l:
- i: 5"""
-
- obj = Upper.from_yaml(yaml)
- assert obj.l.i == 5
-
-
-def test_simple_compount_types():
- @dataclass
- class TestClass(DataclassParserValidatorMixin):
- l: List[int]
- d: Dict[str, str]
- t: Tuple[str, int]
- o: Optional[int]
-
- def _validate(self):
- pass
-
- yaml = """l:
- - 1
- - 2
- - 3
- - 4
- - 5
-d:
- something: else
- w: all
-t:
- - test
- - 5"""
-
- obj = TestClass.from_yaml(yaml)
-
- assert obj.l == [1, 2, 3, 4, 5]
- assert obj.d == {"something": "else", "w": "all"}
- assert obj.t == ("test", 5)
- assert obj.o is None
-
-
-def test_nested_compount_types():
- @dataclass
- class TestClass(DataclassParserValidatorMixin):
- o: Optional[Dict[str, str]]
-
- def _validate(self):
- pass
-
- yaml = """o:
- key: val"""
-
- obj = TestClass.from_yaml(yaml)
-
- assert obj.o == {"key": "val"}
-
-
-def test_nested_compount_types2():
- @dataclass
- class TestClass(DataclassParserValidatorMixin):
- i: int
- o: Optional[Dict[str, str]]
-
- def _validate(self):
- pass
-
- yaml = "i: 5"
-
- obj = TestClass.from_yaml(yaml)
-
- assert obj.o is None
-
-
-def test_real_failing_dummy_confdata():
- @dataclass
- class ConfData(DataclassParserValidatorMixin):
- num_workers: int = 1
- lua_config: Optional[str] = None
-
- def _validate(self):
- if self.num_workers < 0:
- raise Exception("Number of workers must be non-negative")
-
- # prepare the payload
- lua_config = "dummy"
- config = f"""
-num_workers: 4
-lua_config: |
- { lua_config }"""
-
- data = ConfData.from_yaml(config)
-
- assert type(data.num_workers) == int
- assert data.num_workers == 4
- assert type(data.lua_config) == str
- assert data.lua_config == "dummy"
-
-
-def test_partial_mutations():
- @dataclass
- class Inner(DataclassParserValidatorMixin):
- number: int
-
- def _validate(self):
- pass
-
- @dataclass
- class ConfData(DataclassParserValidatorMixin):
- num_workers: int = 1
- lua_config: Optional[str] = None
- inner: Inner = Inner(5)
-
- def _validate(self):
- if self.num_workers < 0:
- raise Exception("Number of workers must be non-negative")
-
- data = ConfData(5, "something", Inner(10))
-
- x = data.copy_with_changed_subtree(Format.JSON, "/lua_config", '"new_value"')
- assert x.lua_config == "new_value"
- assert x.num_workers == 5
- assert x.inner.number == 10
-
- x = data.copy_with_changed_subtree(Format.JSON, "/inner", '{"number": 55}')
- assert x.lua_config == "something"
- assert x.num_workers == 5
- assert x.inner.number == 55
-
- x = data.copy_with_changed_subtree(Format.JSON, "/", '{"inner": {"number": 55}}')
- assert x.lua_config is None
- assert x.num_workers == 1
- assert x.inner.number == 55
from typing_extensions import Literal
-from knot_resolver_manager.utils.types import LiteralEnum, is_list
+from knot_resolver_manager.utils.types import LiteralEnum, is_list, is_literal
def test_is_list():
assert is_list(List[int])
+def test_is_literal():
+ assert is_literal(Literal[5])
+ assert is_literal(Literal["test"])
+
+
def test_literal_enum():
assert LiteralEnum[5, "test"] == Union[Literal[5], Literal["test"]]
assert LiteralEnum["str", 5] == Union[Literal["str"], Literal[5]]