git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
utils: unified data modeling tools into one SchemaNode class
authorVasek Sraier <git@vakabus.cz>
Thu, 16 Sep 2021 13:29:30 +0000 (15:29 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:53 +0000 (16:17 +0200)
12 files changed:
manager/knot_resolver_manager/datamodel/config.py
manager/knot_resolver_manager/datamodel/dns64_config.py
manager/knot_resolver_manager/datamodel/dnssec_config.py
manager/knot_resolver_manager/datamodel/lua_config.py
manager/knot_resolver_manager/datamodel/network_config.py
manager/knot_resolver_manager/datamodel/options_config.py
manager/knot_resolver_manager/datamodel/server_config.py
manager/knot_resolver_manager/datamodel/types.py
manager/knot_resolver_manager/utils/__init__.py
manager/knot_resolver_manager/utils/data_parser_validator.py
manager/tests/datamodel/test_datamodel_types.py
manager/tests/utils/test_data_parser_validator.py

index 52b7c862f789ada986580f208450b66623cca1ff..96d65376d499a70fa5e4ea77bd27a0a8b5a4cc98 100644 (file)
@@ -9,7 +9,7 @@ from knot_resolver_manager.datamodel.lua_config import Lua, LuaStrict
 from knot_resolver_manager.datamodel.network_config import Network, NetworkStrict
 from knot_resolver_manager.datamodel.options_config import Options, OptionsStrict
 from knot_resolver_manager.datamodel.server_config import Server, ServerStrict
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 
 
 def _import_lua_template() -> Template:
@@ -23,7 +23,7 @@ def _import_lua_template() -> Template:
 _LUA_TEMPLATE = _import_lua_template()
 
 
-class KresConfig(DataParser):
+class KresConfig(SchemaNode):
     server: Server = Server()
     options: Options = Options()
     network: Network = Network()
@@ -32,7 +32,7 @@ class KresConfig(DataParser):
     lua: Lua = Lua()
 
 
-class KresConfigStrict(DataValidator):
+class KresConfigStrict(SchemaNode):
     server: ServerStrict
     options: OptionsStrict
     network: NetworkStrict
index f59970ee20ad8bbc907cd7bdfe41576c803474cc..b89063960de5e0accb7a17d7cd66c000a3fcc949 100644 (file)
@@ -1,10 +1,10 @@
 from knot_resolver_manager.datamodel.types import IPv6Network96
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 
 
-class Dns64(DataParser):
+class Dns64(SchemaNode):
     prefix: IPv6Network96 = IPv6Network96("64:ff9b::/96")
 
 
-class Dns64Strict(DataValidator):
+class Dns64Strict(SchemaNode):
     prefix: IPv6Network96
index 2590b5ddbeaa7db4ee32d64181b5e076d88280d3..3a689736bdaebb851bda2611725c1ed2696b0fdb 100644 (file)
@@ -1,15 +1,15 @@
 from typing import List, Optional
 
 from knot_resolver_manager.datamodel.types import TimeUnit
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 
 
-class TrustAnchorFile(DataParser):
+class TrustAnchorFile(SchemaNode):
     file: str
     read_only: bool = False
 
 
-class Dnssec(DataParser):
+class Dnssec(SchemaNode):
     trust_anchor_sentinel: bool = True
     trust_anchor_signal_query: bool = True
     time_skew_detection: bool = True
@@ -22,12 +22,12 @@ class Dnssec(DataParser):
     trust_anchors_files: Optional[List[TrustAnchorFile]] = None
 
 
-class TrustAnchorFileStrict(DataValidator):
+class TrustAnchorFileStrict(SchemaNode):
     file: str
     read_only: bool
 
 
-class DnssecStrict(DataValidator):
+class DnssecStrict(SchemaNode):
     trust_anchor_sentinel: bool
     trust_anchor_signal_query: bool
     time_skew_detection: bool
index 6a924fa20d83758508285d4a7b1bcd6b10f2273c..bea95620e56ba37597fbc7fb78778fe85ff95238 100644 (file)
@@ -1,16 +1,16 @@
 from typing import Optional
 
 from knot_resolver_manager.exceptions import ValidationException
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 
 
-class Lua(DataParser):
+class Lua(SchemaNode):
     script_only: bool = False
     script: Optional[str] = None
     script_file: Optional[str] = None
 
 
-class LuaStrict(DataValidator):
+class LuaStrict(SchemaNode):
     script_only: bool
     script: Optional[str]
     script_file: Optional[str]
index ab54d5429497421559f39867fc9c2a994550362a..78d01d8bc2431bdb4d35a499eae96f776d0c992b 100644 (file)
@@ -1,18 +1,18 @@
 from typing import List
 
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 from knot_resolver_manager.utils.types import LiteralEnum
 
 KindEnum = LiteralEnum["dns", "xdp", "dot", "doh"]
 
 
-class Interface(DataParser):
+class Interface(SchemaNode):
     listen: str
     kind: KindEnum = "dns"
     freebind: bool = False
 
 
-class InterfaceStrict(DataValidator):
+class InterfaceStrict(SchemaNode):
     address: str
     port: int
     kind: str
@@ -32,9 +32,9 @@ class InterfaceStrict(DataValidator):
         return port_map.get(obj.kind, 0)
 
 
-class Network(DataParser):
+class Network(SchemaNode):
     interfaces: List[Interface] = [Interface({"listen": "127.0.0.1"}), Interface({"listen": "::1", "freebind": True})]
 
 
-class NetworkStrict(DataValidator):
+class NetworkStrict(SchemaNode):
     interfaces: List[InterfaceStrict]
index f3c7a31cf01ab54bbee51ca3bf635447f5c8fb85..0f03efab7d1062472bae66d86d9ed8d4a519f9c5 100644 (file)
@@ -1,6 +1,6 @@
 from typing import Union
 
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 from knot_resolver_manager.utils.types import LiteralEnum
 
 from .types import TimeUnit
@@ -8,12 +8,12 @@ from .types import TimeUnit
 GlueCheckingEnum = LiteralEnum["normal", "strict", "permissive"]
 
 
-class Prediction(DataParser):
+class Prediction(SchemaNode):
     window: TimeUnit = TimeUnit("15m")
     period: int = 24
 
 
-class Options(DataParser):
+class Options(SchemaNode):
     glue_checking: GlueCheckingEnum = "normal"
     qname_minimisation: bool = True
     query_loopback: bool = False
@@ -29,12 +29,12 @@ class Options(DataParser):
     prediction: Union[bool, Prediction] = False
 
 
-class PredictionStrict(DataValidator):
+class PredictionStrict(SchemaNode):
     window: TimeUnit
     period: int
 
 
-class OptionsStrict(DataValidator):
+class OptionsStrict(SchemaNode):
     glue_checking: GlueCheckingEnum
     qname_minimisation: bool
     query_loopback: bool
index cf8a5e6c03c85cf07e580b27eab75a2d84b48d7d..69bd70f5e37349df476ad3d875516a2b82266e75 100644 (file)
@@ -7,7 +7,7 @@ from typing_extensions import Literal
 
 from knot_resolver_manager.datamodel.types import AnyPath, Listen, ListenStrict
 from knot_resolver_manager.exceptions import ValidationException
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 from knot_resolver_manager.utils.types import LiteralEnum
 
 logger = logging.getLogger(__name__)
@@ -34,33 +34,33 @@ def _cpu_count() -> int:
 BackendEnum = LiteralEnum["auto", "systemd", "supervisord"]
 
 
-class Management(DataParser):
+class Management(SchemaNode):
     listen: Listen = Listen({"unix-socket": "/tmp/manager.sock"})
     backend: BackendEnum = "auto"
     rundir: AnyPath = AnyPath(".")
 
 
-class ManagementStrict(DataValidator):
+class ManagementStrict(SchemaNode):
     listen: ListenStrict
     backend: BackendEnum
     rundir: AnyPath
 
 
-class Webmgmt(DataParser):
+class Webmgmt(SchemaNode):
     listen: Listen
     tls: bool = False
     cert_file: Optional[AnyPath] = None
     key_file: Optional[AnyPath] = None
 
 
-class WebmgmtStrict(DataValidator):
+class WebmgmtStrict(SchemaNode):
     listen: ListenStrict
     tls: bool
     cert_file: Optional[AnyPath]
     key_file: Optional[AnyPath]
 
 
-class Server(DataParser):
+class Server(SchemaNode):
     hostname: Optional[str] = None
     groupid: Optional[str] = None
     nsid: Optional[str]
@@ -71,7 +71,7 @@ class Server(DataParser):
     webmgmt: Optional[Webmgmt] = None
 
 
-class ServerStrict(DataValidator):
+class ServerStrict(SchemaNode):
     hostname: str
     groupid: Optional[str]
     nsid: Optional[str]
index 8f41ef553d3924b748b5d7733e336ad5324cbb47..e0e0ae0f242b8d3fbb12624f43c136a794b10456 100644 (file)
@@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Pattern, Union
 
 from knot_resolver_manager.exceptions import DataValidationException
 from knot_resolver_manager.utils import CustomValueType
-from knot_resolver_manager.utils.data_parser_validator import DataParser, DataValidator
+from knot_resolver_manager.utils.data_parser_validator import SchemaNode
 
 logger = logging.getLogger(__name__)
 
@@ -123,7 +123,7 @@ class AnyPath(CustomValueType):
         return str(self._value)
 
 
-class Listen(DataParser):
+class Listen(SchemaNode):
     ip: Optional[str] = None
     port: Optional[int] = None
     unix_socket: Optional[AnyPath] = None
@@ -136,7 +136,7 @@ class ListenType(Enum):
     INTERFACE_AND_PORT = auto()
 
 
-class ListenStrict(DataValidator):
+class ListenStrict(SchemaNode):
     typ: ListenType
     ip: Optional[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = None
     port: Optional[int] = None
index 19ba814ad4baf218e563e116ef228c1cbe65a4d2..18efa7de1a75f7e3fe69292869b7917e878cfd62 100644 (file)
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Iterable, Optional, Type, TypeVar
 
 from .custom_types import CustomValueType
-from .data_parser_validator import DataParser, DataValidator, Format
+from .data_parser_validator import Format, SchemaNode
 
 T = TypeVar("T")
 
@@ -55,6 +55,5 @@ def contains_element_matching(cond: Callable[[T], bool], arr: Iterable[T]) -> bo
 __all__ = [
     "Format",
     "CustomValueType",
-    "DataParser",
-    "DataValidator",
+    "SchemaNode",
 ]
index 30610ed4013549dbc73977ebf7d559d88128b966..0970107e2d0260475d07de6496a540eaaab796a4 100644 (file)
@@ -3,7 +3,7 @@ import inspect
 import json
 import re
 from enum import Enum, auto
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import yaml
 from yaml.constructor import ConstructorError
@@ -17,6 +17,7 @@ from knot_resolver_manager.exceptions import (
 )
 from knot_resolver_manager.utils.custom_types import CustomValueType
 from knot_resolver_manager.utils.types import (
+    NoneType,
     get_attr_type,
     get_generic_type_argument,
     get_generic_type_arguments,
@@ -25,6 +26,7 @@ from knot_resolver_manager.utils.types import (
     is_list,
     is_literal,
     is_none_type,
+    is_optional,
     is_tuple,
     is_union,
 )
@@ -53,7 +55,7 @@ def _to_primitive(obj: Any) -> Any:
         return obj.serialize()
 
     # nested DataParser class instances
-    elif isinstance(obj, DataParser):
+    elif isinstance(obj, SchemaNode):
         return obj.to_dict()
 
     # otherwise just return, what we were given
@@ -185,21 +187,13 @@ def _validated_object_type(
             # no validation performed, the implementation does it in the constructor
             return cls(obj, object_path=object_path)
 
-    # nested DataParser subclasses
-    elif inspect.isclass(cls) and issubclass(cls, DataParser):
+    # nested SchemaNode subclasses
+    elif inspect.isclass(cls) and issubclass(cls, SchemaNode):
         # we should return DataParser, we expect to be given a dict,
         # because we can construct a DataParser from it
-        if isinstance(obj, dict):
+        if isinstance(obj, (dict, SchemaNode)):
             return cls(obj, object_path=object_path)  # type: ignore
-        raise DataParsingException(f"Expected '{dict}' object, found '{type(obj)}'", object_path)
-
-    # nested DataValidator subclasses
-    elif inspect.isclass(cls) and issubclass(cls, DataValidator):
-        # we should return DataValidator, we expect to be given a DataParser,
-        # because we can construct a DataValidator from it
-        if isinstance(obj, DataParser):
-            return cls(obj, object_path=object_path)
-        raise DataParsingException(f"Expected instance of '{DataParser}' class, found '{type(obj)}'", object_path)
+        raise DataParsingException(f"Expected 'dict' or 'SchemaNode' object, found '{type(obj)}'", object_path)
 
     # if the object matches, just pass it through
     elif inspect.isclass(cls) and isinstance(obj, cls):
@@ -287,45 +281,97 @@ class Format(Enum):
         return formats[mime_type]
 
 
-_T = TypeVar("_T", bound="DataParser")
+_T = TypeVar("_T", bound="SchemaNode")
 
 
 _SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
 
 
-class DataParser:
-    def __init__(self, obj: Optional[Dict[Any, Any]] = None, object_path: str = "/"):
+TSource = Union[NoneType, Dict[Any, Any], "SchemaNode"]
+
+
+class SchemaNode:
+    def __init__(self, source: TSource = None, object_path: str = "/"):
         cls = self.__class__
         annot = cls.__dict__.get("__annotations__", {})
 
-        used_keys: List[str] = []
+        used_keys: Set[str] = set()
         for name, python_type in annot.items():
             if is_internal_field(name):
                 continue
 
-            val = None
-            dash_name = name.replace("_", "-")
-            if obj and dash_name in obj:
-                val = obj[dash_name]
-                used_keys.append(dash_name)
+            # convert naming (used when converting from json/yaml)
+            source_name = name.replace("_", "-") if isinstance(source, dict) else name
+
+            # populate field
+            if not source:
+                val = None
+            # we have a way to create the value
+            elif hasattr(self, f"_{name}"):
+                val = self._get_converted_value(name, source, object_path)
+                used_keys.add(source_name)  # the field might not exist, but that won't break anything
+            # source just contains the value
+            elif source_name in source:
+                val = source[source_name]
+                used_keys.add(source_name)
+            # the value is missing from the source, but there is a default
+            elif getattr(self, name, ...) is not ...:
+                val = None
+            # the value is optional and there is nothing
+            elif is_optional(python_type):
+                val = None
+            # we expected a value but it was not there
+            else:
+                raise DataValidationException(f"Missing attribute '{source_name}'.", object_path)
 
             use_default = hasattr(cls, name)
             default = getattr(cls, name, ...)
             value = _validated_object_type(python_type, val, default, use_default, object_path=f"{object_path}/{name}")
             setattr(self, name, value)
 
-        # check for unused keys
-        if obj:
-            for key in obj:
-                if key not in used_keys:
-                    additional_info = ""
-                    if "_" in key:
-                        additional_info = (
-                            " The problem might be that you are using '_', but you should be using '-' instead."
-                        )
-                    raise DataParsingException(
-                        f"Attribute '{key}' was not provided with any value." + additional_info, object_path
-                    )
+        # check for unused keys in case the source is a dict
+        if source and isinstance(source, dict):
+            unused = source.keys() - used_keys
+            if len(unused) > 0:
+                raise DataParsingException(
+                    f"Keys {unused} in your configuration object are not part of the configuration schema."
+                    " Are you using '_' instead of '-'?",
+                    object_path,
+                )
+
+        # validate the constructed value
+        self._validate()
+
+    def _get_converted_value(self, key: str, source: TSource, object_path: str) -> Any:
+        try:
+            return getattr(self, f"_{key}")(source)
+        except (ValueError, ValidationException) as e:
+            if len(e.args) > 0 and isinstance(e.args[0], str):
+                msg = e.args[0]
+            else:
+                msg = "Failed to validate value type"
+            raise DataValidationException(msg, object_path) from e
+
+    def __getitem__(self, key: str) -> Any:
+        if not hasattr(self, key):
+            raise RuntimeError(f"Object '{self}' of type '{type(self)}' does not have field named '{key}'")
+        return getattr(self, key)
+
+    def __contains__(self, item: Any) -> bool:
+        return hasattr(self, item)
+
+    def validate(self) -> None:
+        for field_name in dir(self):
+            if is_internal_field(field_name):
+                continue
+
+            field = getattr(self, field_name)
+            if isinstance(field, SchemaNode):
+                field.validate()
+        self._validate()
+
+    def _validate(self) -> None:
+        pass
 
     @classmethod
     def parse_from(cls: Type[_T], fmt: Format, text: str):
@@ -410,47 +456,3 @@ class DataParser:
         setattr(parent, last_name, parsed_value)
 
         return to_mutate
-
-
-class DataValidator:
-    def __init__(self, obj: DataParser, object_path: str = ""):
-        cls = self.__class__
-        anot = cls.__dict__.get("__annotations__", {})
-
-        for attr_name, attr_type in anot.items():
-            if is_internal_field(attr_name):
-                continue
-
-            # use transformation function if available
-            if hasattr(self, f"_{attr_name}"):
-                try:
-                    value = getattr(self, f"_{attr_name}")(obj)
-                except (ValueError, ValidationException) as e:
-                    if len(e.args) > 0 and isinstance(e.args[0], str):
-                        msg = e.args[0]
-                    else:
-                        msg = "Failed to validate value type"
-                    raise DataValidationException(msg, object_path) from e
-            elif hasattr(obj, attr_name):
-                value = getattr(obj, attr_name)
-            else:
-                raise DataValidationException(
-                    f"DataParser object {obj} is missing '{attr_name}' attribute.", object_path
-                )
-
-            setattr(self, attr_name, _validated_object_type(attr_type, value))
-
-        self._validate()
-
-    def validate(self) -> None:
-        for field_name in dir(self):
-            if is_internal_field(field_name):
-                continue
-
-            field = getattr(self, field_name)
-            if isinstance(field, DataValidator):
-                field.validate()
-        self._validate()
-
-    def _validate(self) -> None:
-        pass
index 7cc086c000a796ca4b5fc8057978307118735aef..e2af9d25f081394ddfab20f6e62c2b94f667dd62 100644 (file)
@@ -13,7 +13,7 @@ from knot_resolver_manager.datamodel.types import (
     TimeUnit,
 )
 from knot_resolver_manager.exceptions import KresdManagerException
-from knot_resolver_manager.utils import DataParser, DataValidator
+from knot_resolver_manager.utils import SchemaNode
 
 
 def test_size_unit():
@@ -41,11 +41,11 @@ def test_time_unit():
 
 
 def test_parsing_units():
-    class TestClass(DataParser):
+    class TestClass(SchemaNode):
         size: SizeUnit
         time: TimeUnit
 
-    class TestClassStrict(DataValidator):
+    class TestClassStrict(SchemaNode):
         size: int
         time: int
 
@@ -74,7 +74,7 @@ time: 10m
 
 
 def test_anypath():
-    class Data(DataParser):
+    class Data(SchemaNode):
         p: AnyPath
 
     assert str(Data.from_yaml('p: "/tmp"').p) == "/tmp"
index 090177abb5fbe2cdc64863497b285e051e27b78e..fa5dd8a262cd4522450e2c844fdcb31623ab35c4 100644 (file)
@@ -4,16 +4,16 @@ from pytest import raises
 from typing_extensions import Literal
 
 from knot_resolver_manager.exceptions import DataParsingException
-from knot_resolver_manager.utils import DataParser, DataValidator, Format
+from knot_resolver_manager.utils import Format, SchemaNode
 
 
 def test_primitive():
-    class TestClass(DataParser):
+    class TestClass(SchemaNode):
         i: int
         s: str
         b: bool
 
-    class TestClassStrict(DataValidator):
+    class TestClassStrict(SchemaNode):
         i: int
         s: str
         b: bool
@@ -47,14 +47,14 @@ b: false
 
 
 def test_parsing_primitive_exceptions():
-    class TestStr(DataParser):
+    class TestStr(SchemaNode):
         s: str
 
     # int and float are allowed inputs for string
     with raises(DataParsingException):
         TestStr.from_yaml("s: false")  # bool
 
-    class TestInt(DataParser):
+    class TestInt(SchemaNode):
         i: int
 
     with raises(DataParsingException):
@@ -64,7 +64,7 @@ def test_parsing_primitive_exceptions():
     with raises(DataParsingException):
         TestInt.from_yaml("i: 5.5")  # float
 
-    class TestBool(DataParser):
+    class TestBool(SchemaNode):
         b: bool
 
     with raises(DataParsingException):
@@ -76,19 +76,19 @@ def test_parsing_primitive_exceptions():
 
 
 def test_nested():
-    class Lower(DataParser):
+    class Lower(SchemaNode):
         i: int
 
-    class Upper(DataParser):
+    class Upper(SchemaNode):
         l: Lower
 
-    class LowerStrict(DataValidator):
+    class LowerStrict(SchemaNode):
         i: int
 
         def _validate(self) -> None:
             pass
 
-    class UpperStrict(DataValidator):
+    class UpperStrict(SchemaNode):
         l: LowerStrict
 
         def _validate(self) -> None:
@@ -113,13 +113,13 @@ l:
 
 
 def test_simple_compount_types():
-    class TestClass(DataParser):
+    class TestClass(SchemaNode):
         l: List[int]
         d: Dict[str, str]
         t: Tuple[str, int]
         o: Optional[int]
 
-    class TestClassStrict(DataValidator):
+    class TestClassStrict(SchemaNode):
         l: List[int]
         d: Dict[str, str]
         t: Tuple[str, int]
@@ -166,10 +166,10 @@ t:
 
 
 def test_nested_compound_types():
-    class TestClass(DataParser):
+    class TestClass(SchemaNode):
         o: Optional[Dict[str, str]]
 
-    class TestClassStrict(DataValidator):
+    class TestClassStrict(SchemaNode):
         o: Optional[Dict[str, str]]
 
         def _validate(self) -> None:
@@ -194,11 +194,11 @@ o:
 
 
 def test_nested_compount_types2():
-    class TestClass(DataParser):
+    class TestClass(SchemaNode):
         i: int
         o: Optional[Dict[str, str]]
 
-    class TestClassStrict(DataValidator):
+    class TestClassStrict(SchemaNode):
         i: int
         o: Optional[Dict[str, str]]
 
@@ -224,21 +224,21 @@ def test_nested_compount_types2():
 
 
 def test_partial_mutations():
-    class Inner(DataParser):
+    class Inner(SchemaNode):
         size: int = 5
 
-    class ConfData(DataParser):
+    class ConfData(SchemaNode):
         workers: Union[Literal["auto"], int] = 1
         lua_config: Optional[str] = None
         inner: Inner = Inner()
 
-    class InnerStrict(DataValidator):
+    class InnerStrict(SchemaNode):
         size: int
 
         def _validate(self) -> None:
             pass
 
-    class ConfDataStrict(DataValidator):
+    class ConfDataStrict(SchemaNode):
         workers: int
         lua_config: Optional[str]
         inner: InnerStrict