From 27fcbb55cd2fe34f444271847392bacf9feaf972 Mon Sep 17 00:00:00 2001 From: Vasek Sraier Date: Sat, 20 Mar 2021 12:38:34 +0100 Subject: [PATCH] dataclasses_yaml: full dataclasses parser --- .../knot_resolver_manager/utils/__init__.py | 4 +- .../utils/dataclasses_yaml.py | 119 +++++++++++++++++- manager/tests/utils/test_dataclasses_yaml.py | 99 ++++++++++++++- 3 files changed, 217 insertions(+), 5 deletions(-) diff --git a/manager/knot_resolver_manager/utils/__init__.py b/manager/knot_resolver_manager/utils/__init__.py index fde133175..75315953c 100644 --- a/manager/knot_resolver_manager/utils/__init__.py +++ b/manager/knot_resolver_manager/utils/__init__.py @@ -1,4 +1,4 @@ -from .dataclasses_yaml import dataclass_strictyaml_schema +from .dataclasses_yaml import dataclass_strictyaml_schema, dataclass_strictyaml -__all__ = ["dataclass_strictyaml_schema"] +__all__ = ["dataclass_strictyaml_schema", "dataclass_strictyaml"] diff --git a/manager/knot_resolver_manager/utils/dataclasses_yaml.py b/manager/knot_resolver_manager/utils/dataclasses_yaml.py index 4769ca0cb..3a837fbf2 100644 --- a/manager/knot_resolver_manager/utils/dataclasses_yaml.py +++ b/manager/knot_resolver_manager/utils/dataclasses_yaml.py @@ -1,5 +1,16 @@ from typing import List, Dict, Tuple, Union -from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq +from strictyaml import ( + Map, + Str, + EmptyDict, + Int, + Float, + Seq, + MapPattern, + FixedSeq, + load, + YAML, +) import strictyaml @@ -27,6 +38,10 @@ class StrictYAMLSchemaGenerationError(Exception): pass +class StrictYAMLValueMappingError(Exception): + pass + + def _get_strictyaml_type(python_type): # another already processed class if hasattr(python_type, _SCHEMA_FIELD_NAME): @@ -43,7 +58,8 @@ def _get_strictyaml_type(python_type): # special case for Optional[T] if origin == Union and len(args) == 2 and args[1] == NoneType: - return strictyaml.Optional(_get_strictyaml_type(args[0])) + # for some weird reason, the optional wrapper is on the key, not on the value type + return _get_strictyaml_type(args[0]) type_constructor = _TYPE_MAP[origin] type_arguments = [_get_strictyaml_type(a) for a in args] @@ -75,9 +91,108 @@ def dataclass_strictyaml_schema(cls): else: fields = {} for name, python_type in anot.items(): + # special case for Optional[T], because it's weird + # https://hitchdev.com/strictyaml/using/alpha/compound/optional-keys-with-defaults/ + if ( + hasattr(python_type, "__origin__") + and hasattr(python_type, "__args__") + and getattr(python_type, "__origin__") == Union + and len(getattr(python_type, "__args__")) == 2 + and getattr(python_type, "__args__")[1] == NoneType + ): + name = strictyaml.Optional(name) fields[name] = _get_strictyaml_type(python_type) schema = Map(fields) setattr(cls, _SCHEMA_FIELD_NAME, schema) return cls + + +def _yamlobj_to_dataclass(cls, obj: YAML): + # primitive values recursion helper + if cls in (str, int, float): + return cls(obj) + + # assert that no other weird class gets here + assert hasattr(cls, _SCHEMA_FIELD_NAME) + + anot = cls.__dict__.get("__annotations__", {}) + + kwargs = {} + for name, python_type in anot.items(): + # another dataclass + if hasattr(python_type, _SCHEMA_FIELD_NAME): + kwargs[name] = _yamlobj_to_dataclass(python_type, obj[name]) + + # string + elif python_type == str: + kwargs[name] = obj[name].text + + # numbers + elif python_type in (int, float): + kwargs[name] = obj[name] + + # compound generic types + elif ( + hasattr(python_type, "__origin__") + and hasattr(python_type, "__args__") + and getattr(python_type, "__origin__") in (Union, Dict, List, Tuple) + ): + origin = getattr(python_type, "__origin__") + args = getattr(python_type, "__args__") + + # Optional[T] + if origin == Union and len(args) == 2 and args[1] == NoneType: + kwargs[name] = obj[name] if name in obj else None + + # Dict[K, V] + elif origin == Dict and len(args) == 2: + kwargs[name] = { + _yamlobj_to_dataclass(args[0], key): _yamlobj_to_dataclass( + args[1], val + ) + for key, val in obj[name].items() + } + + # List[T] + elif origin == List and len(args) == 1: + kwargs[name] = [ + _yamlobj_to_dataclass(args[0], val) + for val in obj[name] + if print(args[0], val) is None + ] + + # Tuple + elif origin == Tuple: + kwargs[name] = tuple( + _yamlobj_to_dataclass(typ, val) for typ, val in zip(args, obj[name]) + ) + + # unsupported compound type + else: + raise StrictYAMLValueMappingError( + f"Failed to map compound map field {name} <{python_type}> into {cls}" + ) + + # unsupported type + else: + raise StrictYAMLValueMappingError( + f"Failed to map field {name} <{python_type}> into {cls}" + ) + + return cls(**kwargs) + + +def dataclass_strictyaml(cls): + if not hasattr(cls, _SCHEMA_FIELD_NAME): + cls = dataclass_strictyaml_schema(cls) + + def from_yaml(text: str) -> cls: + schema = getattr(cls, _SCHEMA_FIELD_NAME) + + yamlobj = load(text, schema) + return _yamlobj_to_dataclass(cls, yamlobj) + + setattr(cls, "from_yaml", from_yaml) + return cls diff --git a/manager/tests/utils/test_dataclasses_yaml.py b/manager/tests/utils/test_dataclasses_yaml.py index fff6213c8..5e9a566bc 100644 --- a/manager/tests/utils/test_dataclasses_yaml.py +++ b/manager/tests/utils/test_dataclasses_yaml.py @@ -1,8 +1,10 @@ +from knot_resolver_manager.utils.dataclasses_yaml import dataclass_strictyaml from knot_resolver_manager.utils import dataclass_strictyaml_schema from typing import List, Dict, Optional, Tuple from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq import strictyaml import pytest +from dataclasses import dataclass def _schema_eq(schema1, schema2) -> bool: @@ -81,7 +83,7 @@ def test_optional_field(): field: Optional[int] assert _schema_eq( - TestClass.STRICTYAML_SCHEMA, Map({"field": strictyaml.Optional(Int())}) + TestClass.STRICTYAML_SCHEMA, Map({strictyaml.Optional("field"): Int()}) ) @@ -175,3 +177,98 @@ def test_combined_with_dataclass2(): field: int assert _schema_eq(TestClass.STRICTYAML_SCHEMA, Map({"field": Int()})) + + +def test_parsing_primitive(): + @dataclass + @dataclass_strictyaml + class TestClass: + i: int + s: str + f: float + + yaml = """i: 5 +s: "test" +f: 3.14""" + + obj = TestClass.from_yaml(yaml) + + assert obj.i == 5 + assert obj.s == "test" + assert obj.f == 3.14 + + +def test_parsing_nested(): + @dataclass + @dataclass_strictyaml + class Lower: + i: int + + @dataclass + @dataclass_strictyaml + class Upper: + l: Lower + + yaml = """l: + i: 5""" + + obj = Upper.from_yaml(yaml) + assert obj.l.i == 5 + + +def test_simple_compount_types(): + @dataclass + @dataclass_strictyaml + class TestClass: + l: List[int] + d: Dict[str, str] + t: Tuple[str, int] + o: Optional[int] + + yaml = """l: + - 1 + - 2 + - 3 + - 4 + - 5 +d: + something: else + w: all +t: + - test + - 5""" + + obj = TestClass.from_yaml(yaml) + + assert obj.l == [1, 2, 3, 4, 5] + assert obj.d == {"something": "else", "w": "all"} + assert obj.t == ("test", 5) + assert obj.o is None + + +def test_nested_compount_types(): + @dataclass + @dataclass_strictyaml + class TestClass: + o: Optional[Dict[str, str]] + + yaml = """o: + key: val""" + + obj = TestClass.from_yaml(yaml) + + assert obj.o == {"key": "val"} + + +def test_nested_compount_types2(): + @dataclass + @dataclass_strictyaml + class TestClass: + i: int + o: Optional[Dict[str, str]] + + yaml = "i: 5" + + obj = TestClass.from_yaml(yaml) + + assert obj.o is None -- 2.47.3