]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
dataclasses_yaml: added an option to use inheritance instead of a decorator
authorVasek Sraier <git@vakabus.cz>
Mon, 22 Mar 2021 17:50:16 +0000 (18:50 +0100)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:51 +0000 (16:17 +0200)
manager/knot_resolver_manager/configuration.py
manager/knot_resolver_manager/datamodel.py
manager/knot_resolver_manager/kresd_manager.py
manager/knot_resolver_manager/utils/__init__.py
manager/knot_resolver_manager/utils/dataclasses_yaml.py
manager/tests/utils/test_dataclasses_yaml_inheritance.py [new file with mode: 0644]

index 995bc30e89a4cd6ffe7dc1d670d1e4186726cdb8..563fc3942c18b7954e8e9ed26310963af4e5d2ce 100644 (file)
@@ -1,45 +1,11 @@
-from strictyaml import Map, Str, Int
-from strictyaml.parser import load
-from strictyaml.representation import YAML
-
 from .datamodel import ConfData
 
 
-_CONFIG_SCHEMA = Map({"lua_config": Str(), "num_workers": Int()})
-
-
-def _get_config_schema():
-    """
-    Returns a schema defined using the strictyaml library, that the manager
-    should accept at it's input.
-
-    If this function does something, that can be cached, it should cache it by
-    itself. For example, loading the schema from a file is OK, the loaded
-    parsed schema object should then however be cached in memory. The function
-    is on purpose non-async and it's expected to return very fast.
-    """
-    return _CONFIG_SCHEMA
-
-
 class ConfigValidationException(Exception):
     pass
 
 
-async def _validate_config(config):
-    """
-    Perform runtime value validation of the provided configuration object which
-    is guaranteed to follow the configuration schema returned by the
-    `get_config_schema` function.
-
-    Throws a ConfigValidationException in case any errors are found. The error
-    message should be in the error message of the exception.
-    """
-
-    if config["num_workers"] < 0:
-        raise ConfigValidationException("Number of workers must be non-negative")
-
-
 async def parse(yaml: str) -> ConfData:
     conf = ConfData.from_yaml(yaml)
     await conf.validate()
-    return conf
\ No newline at end of file
+    return conf
index 6012b8db639e31b4789d4e137f63c71fdab52f9a..32cc06f0d11ade91cd3a13af088ac8cb443297da 100644 (file)
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
+from typing import Optional
 
-from .utils import dataclass_strictyaml
+from .utils import StrictyamlParser
 
 
 class ConfDataValidationException(Exception):
@@ -8,11 +9,12 @@ class ConfDataValidationException(Exception):
 
 
 @dataclass
-@dataclass_strictyaml
-class ConfData:
+class ConfData(StrictyamlParser):
     num_workers: int = 1
-    lua_config: str = None
+    lua_config: Optional[str] = None
 
     async def validate(self) -> bool:
         if self.num_workers < 0:
             raise ConfDataValidationException("Number of workers must be non-negative")
+
+        return True
index 6ac26430fbb5c1e094d6e267bcb351f563e60567..d4c5f9b3a4fd3ea97d57b3cdcd24be7973649aa2 100644 (file)
@@ -1,7 +1,6 @@
 import asyncio
 from uuid import uuid4
 from typing import List, Optional
-from strictyaml.representation import YAML
 
 from . import compat
 from . import systemd
@@ -79,8 +78,9 @@ class KresdManager:
 
     async def _write_config(self, config: ConfData):
         # FIXME: this code is blocking!!!
-        with open("/etc/knot-resolver/kresd.conf", "w") as f:
-            f.write(config.lua_config)
+        if config.lua_config is not None:
+            with open("/etc/knot-resolver/kresd.conf", "w") as f:
+                f.write(config.lua_config)
 
     async def apply_config(self, config: ConfData):
         async with self._children_lock:
index 75315953c4aef9c999a39925c27dc277b272a199..cf8fa8974a17a15ea5314073d10ef32a55a80558 100644 (file)
@@ -1,4 +1,8 @@
-from .dataclasses_yaml import dataclass_strictyaml_schema, dataclass_strictyaml
+from .dataclasses_yaml import (
+    dataclass_strictyaml_schema,
+    dataclass_strictyaml,
+    StrictyamlParser,
+)
 
 
-__all__ = ["dataclass_strictyaml_schema", "dataclass_strictyaml"]
+__all__ = ["dataclass_strictyaml_schema", "dataclass_strictyaml", "StrictyamlParser"]
index 939bafd572a922e1518dc53dc6d630eba4319347..3d8ce65e07d8ae2762c2119e50f963aa79c97cbd 100644 (file)
@@ -1,4 +1,4 @@
-from typing import List, Dict, Tuple, Union
+from typing import List, Dict, Tuple, Type, TypeVar, Union
 from strictyaml import (
     Map,
     Str,
@@ -158,8 +158,7 @@ def _yamlobj_to_dataclass(cls, obj: YAML):
             # List[T]
             elif origin == List and len(args) == 1:
                 kwargs[name] = [
-                    _yamlobj_to_dataclass(args[0], val)
-                    for val in obj[name]
+                    _yamlobj_to_dataclass(args[0], val) for val in obj[name]
                 ]
 
             # Tuple
@@ -183,15 +182,28 @@ def _yamlobj_to_dataclass(cls, obj: YAML):
     return cls(**kwargs)
 
 
+def _from_yaml(cls, text: str):
+    schema = getattr(cls, _SCHEMA_FIELD_NAME)
+
+    yamlobj = load(text, schema)
+    return _yamlobj_to_dataclass(cls, yamlobj)
+
+
 def dataclass_strictyaml(cls):
     if not hasattr(cls, _SCHEMA_FIELD_NAME):
         cls = dataclass_strictyaml_schema(cls)
 
-    def from_yaml(text: str) -> cls:
-        schema = getattr(cls, _SCHEMA_FIELD_NAME)
+    setattr(cls, "from_yaml", classmethod(_from_yaml))
+    return cls
+
 
-        yamlobj = load(text, schema)
-        return _yamlobj_to_dataclass(cls, yamlobj)
+_T = TypeVar("_T", bound="StrictyamlParser")
 
-    setattr(cls, "from_yaml", from_yaml)
-    return cls
+
+class StrictyamlParser:
+    @classmethod
+    def from_yaml(cls: Type[_T], text: str) -> _T:
+        if not hasattr(cls, _SCHEMA_FIELD_NAME):
+            dataclass_strictyaml_schema(cls)
+
+        return _from_yaml(cls, text)
diff --git a/manager/tests/utils/test_dataclasses_yaml_inheritance.py b/manager/tests/utils/test_dataclasses_yaml_inheritance.py
new file mode 100644 (file)
index 0000000..67a247f
--- /dev/null
@@ -0,0 +1,101 @@
+from knot_resolver_manager.utils.dataclasses_yaml import (
+    StrictyamlParser,
+    dataclass_strictyaml,
+)
+from knot_resolver_manager.utils import dataclass_strictyaml_schema
+from typing import List, Dict, Optional, Tuple
+from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq
+import strictyaml
+import pytest
+from dataclasses import dataclass
+
+
+def test_parsing_primitive():
+    @dataclass
+    @dataclass_strictyaml
+    class TestClass:
+        i: int
+        s: str
+        f: float
+
+    yaml = """i: 5
+s: "test"
+f: 3.14"""
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.i == 5
+    assert obj.s == "test"
+    assert obj.f == 3.14
+
+
+def test_parsing_nested():
+    @dataclass
+    @dataclass_strictyaml_schema
+    class Lower:
+        i: int
+
+    @dataclass
+    class Upper(StrictyamlParser):
+        l: Lower
+
+    yaml = """l:
+  i: 5"""
+
+    obj = Upper.from_yaml(yaml)
+    assert obj.l.i == 5
+
+
+def test_simple_compount_types():
+    @dataclass
+    class TestClass(StrictyamlParser):
+        l: List[int]
+        d: Dict[str, str]
+        t: Tuple[str, int]
+        o: Optional[int]
+
+    yaml = """l:
+  - 1
+  - 2
+  - 3
+  - 4
+  - 5
+d:
+  something: else
+  w: all
+t:
+  - test
+  - 5"""
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.l == [1, 2, 3, 4, 5]
+    assert obj.d == {"something": "else", "w": "all"}
+    assert obj.t == ("test", 5)
+    assert obj.o is None
+
+
+def test_nested_compount_types():
+    @dataclass
+    class TestClass(StrictyamlParser):
+        o: Optional[Dict[str, str]]
+
+    yaml = """o:
+  key: val"""
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.o == {"key": "val"}
+
+
+def test_nested_compount_types2():
+    @dataclass
+    class TestClass(StrictyamlParser):
+        i: int
+        o: Optional[Dict[str, str]]
+
+    yaml = "i: 5"
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.o is None