From: Vasek Sraier Date: Fri, 19 Mar 2021 21:13:33 +0000 (+0100) Subject: manager: dataclasses strictyaml schema support for Optional type X-Git-Tag: v6.0.0a1~210 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7a3b2badb14730d27f703bfe09ed0c4334ffdeca;p=thirdparty%2Fknot-resolver.git manager: dataclasses strictyaml schema support for Optional type --- diff --git a/manager/knot_resolver_manager/utils/__init__.py b/manager/knot_resolver_manager/utils/__init__.py index df31c5824..fde133175 100644 --- a/manager/knot_resolver_manager/utils/__init__.py +++ b/manager/knot_resolver_manager/utils/__init__.py @@ -1,4 +1,4 @@ -from .dataclasses_yaml import dataclasses_strictyaml_schema +from .dataclasses_yaml import dataclass_strictyaml_schema -__all__ = ["dataclasses_strictyaml_schema"] +__all__ = ["dataclass_strictyaml_schema"] diff --git a/manager/knot_resolver_manager/utils/dataclasses_yaml.py b/manager/knot_resolver_manager/utils/dataclasses_yaml.py index 4f308e886..4769ca0cb 100644 --- a/manager/knot_resolver_manager/utils/dataclasses_yaml.py +++ b/manager/knot_resolver_manager/utils/dataclasses_yaml.py @@ -1,5 +1,14 @@ -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Union from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq +import strictyaml + + +class _DummyType: + pass + + +NoneType = type(None) + _TYPE_MAP = { int: Int, @@ -8,9 +17,10 @@ _TYPE_MAP = { List: Seq, Dict: MapPattern, Tuple: FixedSeq, + Union: _DummyType, } -_FIELD_NAME = "STRICTYAML_SCHEMA" +_SCHEMA_FIELD_NAME = "STRICTYAML_SCHEMA" class StrictYAMLSchemaGenerationError(Exception): @@ -18,9 +28,11 @@ class StrictYAMLSchemaGenerationError(Exception): def _get_strictyaml_type(python_type): - if hasattr(python_type, _FIELD_NAME): - return getattr(python_type, _FIELD_NAME) + # another already processed class + if hasattr(python_type, _SCHEMA_FIELD_NAME): + return getattr(python_type, _SCHEMA_FIELD_NAME) + # compount types like List elif ( hasattr(python_type, "__origin__") and hasattr(python_type, "__args__") @@ -29,24 +41,33 @@ def _get_strictyaml_type(python_type): origin = getattr(python_type, "__origin__") args = getattr(python_type, "__args__") + # special case for Optional[T] + if origin == Union and len(args) == 2 and args[1] == NoneType: + return strictyaml.Optional(_get_strictyaml_type(args[0])) + type_constructor = _TYPE_MAP[origin] type_arguments = [_get_strictyaml_type(a) for a in args] print(type_constructor, type_arguments) + + # special case for Tuple if origin == Tuple: return type_constructor(type_arguments) - else: - return type_constructor(*type_arguments) + # default behaviour + return type_constructor(*type_arguments) + + # error handlers for non existent primitive types elif python_type not in _TYPE_MAP: raise StrictYAMLSchemaGenerationError( f"Type {python_type} is not supported for YAML schema generation" ) + # remaining primitive and untyped types else: return _TYPE_MAP[python_type]() -def dataclasses_strictyaml_schema(cls): +def dataclass_strictyaml_schema(cls): anot = cls.__dict__.get("__annotations__", {}) if len(anot) == 0: @@ -57,6 +78,6 @@ def dataclasses_strictyaml_schema(cls): fields[name] = _get_strictyaml_type(python_type) schema = Map(fields) - setattr(cls, _FIELD_NAME, schema) + setattr(cls, _SCHEMA_FIELD_NAME, schema) return cls diff --git a/manager/tests/utils/test_dataclasses_yaml.py b/manager/tests/utils/test_dataclasses_yaml.py index cbad3c203..fff6213c8 100644 --- a/manager/tests/utils/test_dataclasses_yaml.py +++ b/manager/tests/utils/test_dataclasses_yaml.py @@ -1,6 +1,7 @@ -from knot_resolver_manager.utils import dataclasses_strictyaml_schema -from typing import List, Dict, Tuple +from knot_resolver_manager.utils import dataclass_strictyaml_schema +from typing import List, Dict, Optional, Tuple from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq +import strictyaml import pytest @@ -12,7 +13,7 @@ def _schema_eq(schema1, schema2) -> bool: def test_empty_class(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: pass @@ -20,7 +21,7 @@ def test_empty_class(): def test_int_field(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: int @@ -28,7 +29,7 @@ def test_int_field(): def test_string_field(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: str @@ -36,7 +37,7 @@ def test_string_field(): def test_float_field(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: float @@ -44,7 +45,7 @@ def test_float_field(): def test_multiple_fields(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field1: str field2: int @@ -57,7 +58,7 @@ def test_multiple_fields(): def test_list_field(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: List[str] @@ -65,7 +66,7 @@ def test_list_field(): def test_dict_field(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: Dict[str, int] @@ -74,8 +75,18 @@ def test_dict_field(): ) +def test_optional_field(): + @dataclass_strictyaml_schema + class TestClass: + field: Optional[int] + + assert _schema_eq( + TestClass.STRICTYAML_SCHEMA, Map({"field": strictyaml.Optional(Int())}) + ) + + def test_nested_dict_list(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: Dict[str, List[int]] @@ -90,7 +101,7 @@ def test_nested_dict_key_list(): List can't be a dict key, so this should fail """ - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: Dict[List[int], List[int]] @@ -100,7 +111,7 @@ def test_nested_dict_key_list(): def test_nested_list(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: List[List[List[List[int]]]] @@ -110,7 +121,7 @@ def test_nested_list(): def test_tuple_field(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: Tuple[str, int] @@ -120,7 +131,7 @@ def test_tuple_field(): def test_nested_tuple(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: Tuple[str, Dict[str, int], List[List[int]]] @@ -131,11 +142,11 @@ def test_nested_tuple(): def test_chained_classes(): - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: int - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class CompoundClass: c: TestClass @@ -148,7 +159,7 @@ def test_combined_with_dataclass(): from dataclasses import dataclass @dataclass - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema class TestClass: field: int @@ -158,7 +169,7 @@ def test_combined_with_dataclass(): def test_combined_with_dataclass2(): from dataclasses import dataclass - @dataclasses_strictyaml_schema + @dataclass_strictyaml_schema @dataclass class TestClass: field: int