]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
dataclasses_yaml: fixed compound types handling (fixes integration tests broken in...
authorVasek Sraier <git@vakabus.cz>
Mon, 22 Mar 2021 19:54:15 +0000 (20:54 +0100)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:51 +0000 (16:17 +0200)
manager/knot_resolver_manager/utils/dataclasses_yaml.py
manager/pyproject.toml
manager/tests/utils/test_dataclasses_yaml_inheritance.py

index 3d8ce65e07d8ae2762c2119e50f963aa79c97cbd..115933e7b1de456349269776dcb6cb00baa4abac 100644 (file)
@@ -1,4 +1,4 @@
-from typing import List, Dict, Tuple, Type, TypeVar, Union
+from typing import Any, List, Dict, Tuple, Type, TypeVar, Union
 from strictyaml import (
     Map,
     Str,
@@ -109,76 +109,56 @@ def dataclass_strictyaml_schema(cls):
     return cls
 
 
-def _yamlobj_to_dataclass(cls, obj: YAML):
-    # primitive values recursion helper
-    if cls in (str, int, float):
+def _yamlobj_to_dataclass(cls, obj: YAML) -> Any:
+    # native values recursion helper
+    if cls in (int, float):
         return cls(obj)
+    if cls == str:
+        return str(obj.text)
+    # compount types
+    if (
+        hasattr(cls, "__origin__")
+        and hasattr(cls, "__args__")
+        and getattr(cls, "__origin__") in (Union, Dict, List, Tuple)
+    ):
+        origin = getattr(cls, "__origin__")
+        args = getattr(cls, "__args__")
 
-    # assert that no other weird class gets here
-    assert hasattr(cls, _SCHEMA_FIELD_NAME)
+        # Optional[T]
+        if origin == Union and len(args) == 2 and args[1] == NoneType:
+            return _yamlobj_to_dataclass(args[0], obj) if obj is not None else None
 
-    anot = cls.__dict__.get("__annotations__", {})
+        # Dict[K, V]
+        elif origin == Dict and len(args) == 2:
+            return {
+                _yamlobj_to_dataclass(args[0], key): _yamlobj_to_dataclass(args[1], val)
+                for key, val in obj.items()
+            }
+
+        # List[T]
+        elif origin == List and len(args) == 1:
+            return [_yamlobj_to_dataclass(args[0], val) for val in obj]
+
+        # Tuple
+        elif origin == Tuple:
+            return tuple(_yamlobj_to_dataclass(typ, val) for typ, val in zip(args, obj))
+
+    # ^ that's full list of native types
+    # the remaining code handles cases when cls is a dataclasses
 
+    # assert that no weird class without schema gets here
+    if not hasattr(cls, _SCHEMA_FIELD_NAME):
+        raise Exception(
+            f"{str(cls)} does not have a schema field and is not primitive - don't know how to parse. "
+            + "Did you forget to add @dataclass_strictyaml_schema to nested dataclass?"
+        )
+
+    anot = cls.__dict__.get("__annotations__", {})
     kwargs = {}
     for name, python_type in anot.items():
-        # another dataclass
-        if hasattr(python_type, _SCHEMA_FIELD_NAME):
-            kwargs[name] = _yamlobj_to_dataclass(python_type, obj[name])
-
-        # string
-        elif python_type == str:
-            kwargs[name] = obj[name].text
-
-        # numbers
-        elif python_type in (int, float):
-            kwargs[name] = obj[name]
-
-        # compound generic types
-        elif (
-            hasattr(python_type, "__origin__")
-            and hasattr(python_type, "__args__")
-            and getattr(python_type, "__origin__") in (Union, Dict, List, Tuple)
-        ):
-            origin = getattr(python_type, "__origin__")
-            args = getattr(python_type, "__args__")
-
-            # Optional[T]
-            if origin == Union and len(args) == 2 and args[1] == NoneType:
-                kwargs[name] = obj[name] if name in obj else None
-
-            # Dict[K, V]
-            elif origin == Dict and len(args) == 2:
-                kwargs[name] = {
-                    _yamlobj_to_dataclass(args[0], key): _yamlobj_to_dataclass(
-                        args[1], val
-                    )
-                    for key, val in obj[name].items()
-                }
-
-            # List[T]
-            elif origin == List and len(args) == 1:
-                kwargs[name] = [
-                    _yamlobj_to_dataclass(args[0], val) for val in obj[name]
-                ]
-
-            # Tuple
-            elif origin == Tuple:
-                kwargs[name] = tuple(
-                    _yamlobj_to_dataclass(typ, val) for typ, val in zip(args, obj[name])
-                )
-
-            # unsupported compound type
-            else:
-                raise StrictYAMLValueMappingError(
-                    f"Failed to map compound map field {name} <{python_type}> into {cls}"
-                )
-
-        # unsupported type
-        else:
-            raise StrictYAMLValueMappingError(
-                f"Failed to map field {name} <{python_type}> into {cls}"
-            )
-
+        kwargs[name] = _yamlobj_to_dataclass(
+            python_type, obj[name] if name in obj else None
+        )
     return cls(**kwargs)
 
 
index dd0dbb033fcfc5d89bffe17732c746015ff68e87..216051061ae4273a73903d30a75ed48f2eb7146e 100644 (file)
@@ -30,7 +30,7 @@ click = "^7.1.2"
 run = { cmd = "python -m knot_resolver_manager", help = "Run the manager" }
 test = { cmd = "pytest --cov=knot_resolver_manager --show-capture=all tests/", help = "Run tests" }
 check = { cmd = "scripts/codecheck", help = "Run static code analysis" }
-format = { cmd = "poetry run black knot_resolver_manager", help = "Run 'Black' code formater" }
+format = { cmd = "poetry run black knot_resolver_manager/ tests/", help = "Run 'Black' code formater" }
 fixdeps = { shell = "poetry install; yarn install", help = "Install/update dependencies according to configuration files"}
 clean = """
   rm -rf .coverage
@@ -78,7 +78,8 @@ disable= [
     "too-few-public-methods",
     "unused-import",  # checked by flake8,
     "bad-continuation", # conflicts with black
-    "consider-using-in", # pyright can't see through in expressions
+    "consider-using-in", # pyright can't see through in expressions,
+    "too-many-return-statements", # would prevent us from using recursive tree traversals
 ]
 
 [tool.pylint.SIMILARITIES]
index 67a247f3cb2de3c5e7e9d48422d4907715d9f8a0..3951a746165c3e3a69bede6ccfebec1291b080f1 100644 (file)
@@ -99,3 +99,30 @@ def test_nested_compount_types2():
     obj = TestClass.from_yaml(yaml)
 
     assert obj.o is None
+
+
+def test_real_failing_dummy_confdata():
+    @dataclass
+    class ConfData(StrictyamlParser):
+        num_workers: int = 1
+        lua_config: Optional[str] = None
+
+        async def validate(self) -> bool:
+            if self.num_workers < 0:
+                raise Exception("Number of workers must be non-negative")
+
+            return True
+
+    # prepare the payload
+    lua_config = "dummy"
+    config = f"""
+num_workers: 4
+lua_config: |
+  { lua_config }"""
+
+    data = ConfData.from_yaml(config)
+
+    assert type(data.num_workers) == int
+    assert data.num_workers == 4
+    assert type(data.lua_config) == str
+    assert data.lua_config == "dummy"