]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
dataclasses_yaml: full dataclasses parser
authorVasek Sraier <git@vakabus.cz>
Sat, 20 Mar 2021 11:38:34 +0000 (12:38 +0100)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:51 +0000 (16:17 +0200)
manager/knot_resolver_manager/utils/__init__.py
manager/knot_resolver_manager/utils/dataclasses_yaml.py
manager/tests/utils/test_dataclasses_yaml.py

index fde1331752fccdfedc45523fad3969d4d8ca71a7..75315953c4aef9c999a39925c27dc277b272a199 100644 (file)
@@ -1,4 +1,4 @@
-from .dataclasses_yaml import dataclass_strictyaml_schema
+from .dataclasses_yaml import dataclass_strictyaml_schema, dataclass_strictyaml
 
 
-__all__ = ["dataclass_strictyaml_schema"]
+__all__ = ["dataclass_strictyaml_schema", "dataclass_strictyaml"]
index 4769ca0cb09a2ca4405c1eb5e9475799af602954..3a837fbf20a60955ccab732d3ab1eeba30d9a27f 100644 (file)
@@ -1,5 +1,16 @@
 from typing import List, Dict, Tuple, Union
-from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq
+from strictyaml import (
+    Map,
+    Str,
+    EmptyDict,
+    Int,
+    Float,
+    Seq,
+    MapPattern,
+    FixedSeq,
+    load,
+    YAML,
+)
 import strictyaml
 
 
@@ -27,6 +38,10 @@ class StrictYAMLSchemaGenerationError(Exception):
     pass
 
 
+class StrictYAMLValueMappingError(Exception):
+    pass
+
+
 def _get_strictyaml_type(python_type):
     # another already processed class
     if hasattr(python_type, _SCHEMA_FIELD_NAME):
@@ -43,7 +58,8 @@ def _get_strictyaml_type(python_type):
 
         # special case for Optional[T]
         if origin == Union and len(args) == 2 and args[1] == NoneType:
-            return strictyaml.Optional(_get_strictyaml_type(args[0]))
+            # for some weird reason, the optional wrapper is on the key, not on the value type
+            return _get_strictyaml_type(args[0])
 
         type_constructor = _TYPE_MAP[origin]
         type_arguments = [_get_strictyaml_type(a) for a in args]
@@ -75,9 +91,108 @@ def dataclass_strictyaml_schema(cls):
     else:
         fields = {}
         for name, python_type in anot.items():
+            # special case for Optional[T], because it's weird
+            # https://hitchdev.com/strictyaml/using/alpha/compound/optional-keys-with-defaults/
+            if (
+                hasattr(python_type, "__origin__")
+                and hasattr(python_type, "__args__")
+                and getattr(python_type, "__origin__") == Union
+                and len(getattr(python_type, "__args__")) == 2
+                and getattr(python_type, "__args__")[1] == NoneType
+            ):
+                name = strictyaml.Optional(name)
             fields[name] = _get_strictyaml_type(python_type)
         schema = Map(fields)
 
     setattr(cls, _SCHEMA_FIELD_NAME, schema)
 
     return cls
+
+
+def _yamlobj_to_dataclass(cls, obj: YAML):
+    # primitive values recursion helper
+    if cls in (str, int, float):
+        return cls(obj)
+
+    # assert that no other weird class gets here
+    assert hasattr(cls, _SCHEMA_FIELD_NAME)
+
+    anot = cls.__dict__.get("__annotations__", {})
+
+    kwargs = {}
+    for name, python_type in anot.items():
+        # another dataclass
+        if hasattr(python_type, _SCHEMA_FIELD_NAME):
+            kwargs[name] = _yamlobj_to_dataclass(python_type, obj[name])
+
+        # string
+        elif python_type == str:
+            kwargs[name] = obj[name].text
+
+        # numbers
+        elif python_type in (int, float):
+            kwargs[name] = obj[name]
+
+        # compound generic types
+        elif (
+            hasattr(python_type, "__origin__")
+            and hasattr(python_type, "__args__")
+            and getattr(python_type, "__origin__") in (Union, Dict, List, Tuple)
+        ):
+            origin = getattr(python_type, "__origin__")
+            args = getattr(python_type, "__args__")
+
+            # Optional[T]
+            if origin == Union and len(args) == 2 and args[1] == NoneType:
+                kwargs[name] = obj[name] if name in obj else None
+
+            # Dict[K, V]
+            elif origin == Dict and len(args) == 2:
+                kwargs[name] = {
+                    _yamlobj_to_dataclass(args[0], key): _yamlobj_to_dataclass(
+                        args[1], val
+                    )
+                    for key, val in obj[name].items()
+                }
+
+            # List[T]
+            elif origin == List and len(args) == 1:
+                kwargs[name] = [
+                    _yamlobj_to_dataclass(args[0], val)
+                    for val in obj[name]
+                    if print(args[0], val) is None
+                ]
+
+            # Tuple
+            elif origin == Tuple:
+                kwargs[name] = tuple(
+                    _yamlobj_to_dataclass(typ, val) for typ, val in zip(args, obj[name])
+                )
+
+            # unsupported compound type
+            else:
+                raise StrictYAMLValueMappingError(
+                    f"Failed to map compound map field {name} <{python_type}> into {cls}"
+                )
+
+        # unsupported type
+        else:
+            raise StrictYAMLValueMappingError(
+                f"Failed to map field {name} <{python_type}> into {cls}"
+            )
+
+    return cls(**kwargs)
+
+
+def dataclass_strictyaml(cls):
+    if not hasattr(cls, _SCHEMA_FIELD_NAME):
+        cls = dataclass_strictyaml_schema(cls)
+
+    def from_yaml(text: str) -> cls:
+        schema = getattr(cls, _SCHEMA_FIELD_NAME)
+
+        yamlobj = load(text, schema)
+        return _yamlobj_to_dataclass(cls, yamlobj)
+
+    setattr(cls, "from_yaml", from_yaml)
+    return cls
index fff6213c8c38e2b981bac73b867f3863e5845b5c..5e9a566bca3b52cbad9360158e72e32b3e1e069a 100644 (file)
@@ -1,8 +1,10 @@
+from knot_resolver_manager.utils.dataclasses_yaml import dataclass_strictyaml
 from knot_resolver_manager.utils import dataclass_strictyaml_schema
 from typing import List, Dict, Optional, Tuple
 from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq
 import strictyaml
 import pytest
+from dataclasses import dataclass
 
 
 def _schema_eq(schema1, schema2) -> bool:
@@ -81,7 +83,7 @@ def test_optional_field():
         field: Optional[int]
 
     assert _schema_eq(
-        TestClass.STRICTYAML_SCHEMA, Map({"field": strictyaml.Optional(Int())})
+        TestClass.STRICTYAML_SCHEMA, Map({strictyaml.Optional("field"): Int()})
     )
 
 
@@ -175,3 +177,98 @@ def test_combined_with_dataclass2():
         field: int
 
     assert _schema_eq(TestClass.STRICTYAML_SCHEMA, Map({"field": Int()}))
+
+
+def test_parsing_primitive():
+    @dataclass
+    @dataclass_strictyaml
+    class TestClass:
+        i: int
+        s: str
+        f: float
+
+    yaml = """i: 5
+s: "test"
+f: 3.14"""
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.i == 5
+    assert obj.s == "test"
+    assert obj.f == 3.14
+
+
+def test_parsing_nested():
+    @dataclass
+    @dataclass_strictyaml
+    class Lower:
+        i: int
+
+    @dataclass
+    @dataclass_strictyaml
+    class Upper:
+        l: Lower
+
+    yaml = """l:
+  i: 5"""
+
+    obj = Upper.from_yaml(yaml)
+    assert obj.l.i == 5
+
+
+def test_simple_compount_types():
+    @dataclass
+    @dataclass_strictyaml
+    class TestClass:
+        l: List[int]
+        d: Dict[str, str]
+        t: Tuple[str, int]
+        o: Optional[int]
+
+    yaml = """l:
+  - 1
+  - 2
+  - 3
+  - 4
+  - 5
+d:
+  something: else
+  w: all
+t:
+  - test
+  - 5"""
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.l == [1, 2, 3, 4, 5]
+    assert obj.d == {"something": "else", "w": "all"}
+    assert obj.t == ("test", 5)
+    assert obj.o is None
+
+
+def test_nested_compount_types():
+    @dataclass
+    @dataclass_strictyaml
+    class TestClass:
+        o: Optional[Dict[str, str]]
+
+    yaml = """o:
+  key: val"""
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.o == {"key": "val"}
+
+
+def test_nested_compount_types2():
+    @dataclass
+    @dataclass_strictyaml
+    class TestClass:
+        i: int
+        o: Optional[Dict[str, str]]
+
+    yaml = "i: 5"
+
+    obj = TestClass.from_yaml(yaml)
+
+    assert obj.o is None