]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager: dataclasses strictyaml schema support for Optional type
authorVasek Sraier <git@vakabus.cz>
Fri, 19 Mar 2021 21:13:33 +0000 (22:13 +0100)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:51 +0000 (16:17 +0200)
manager/knot_resolver_manager/utils/__init__.py
manager/knot_resolver_manager/utils/dataclasses_yaml.py
manager/tests/utils/test_dataclasses_yaml.py

index df31c58244fb4a34f7198eae07a60f33b0250321..fde1331752fccdfedc45523fad3969d4d8ca71a7 100644 (file)
@@ -1,4 +1,4 @@
-from .dataclasses_yaml import dataclasses_strictyaml_schema
+from .dataclasses_yaml import dataclass_strictyaml_schema
 
 
-__all__ = ["dataclasses_strictyaml_schema"]
+__all__ = ["dataclass_strictyaml_schema"]
index 4f308e886ad9bdcb09473aa338a8a60667b6ff54..4769ca0cb09a2ca4405c1eb5e9475799af602954 100644 (file)
@@ -1,5 +1,14 @@
-from typing import List, Dict, Tuple
+from typing import List, Dict, Tuple, Union
 from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq
+import strictyaml
+
+
+class _DummyType:
+    pass
+
+
+NoneType = type(None)
+
 
 _TYPE_MAP = {
     int: Int,
@@ -8,9 +17,10 @@ _TYPE_MAP = {
     List: Seq,
     Dict: MapPattern,
     Tuple: FixedSeq,
+    Union: _DummyType,
 }
 
-_FIELD_NAME = "STRICTYAML_SCHEMA"
+_SCHEMA_FIELD_NAME = "STRICTYAML_SCHEMA"
 
 
 class StrictYAMLSchemaGenerationError(Exception):
@@ -18,9 +28,11 @@ class StrictYAMLSchemaGenerationError(Exception):
 
 
 def _get_strictyaml_type(python_type):
-    if hasattr(python_type, _FIELD_NAME):
-        return getattr(python_type, _FIELD_NAME)
+    # another already processed class
+    if hasattr(python_type, _SCHEMA_FIELD_NAME):
+        return getattr(python_type, _SCHEMA_FIELD_NAME)
 
+    # compount types like List
     elif (
         hasattr(python_type, "__origin__")
         and hasattr(python_type, "__args__")
@@ -29,24 +41,33 @@ def _get_strictyaml_type(python_type):
         origin = getattr(python_type, "__origin__")
         args = getattr(python_type, "__args__")
 
+        # special case for Optional[T]
+        if origin == Union and len(args) == 2 and args[1] == NoneType:
+            return strictyaml.Optional(_get_strictyaml_type(args[0]))
+
         type_constructor = _TYPE_MAP[origin]
         type_arguments = [_get_strictyaml_type(a) for a in args]
         print(type_constructor, type_arguments)
+
+        # special case for Tuple
         if origin == Tuple:
             return type_constructor(type_arguments)
-        else:
-            return type_constructor(*type_arguments)
 
+        # default behaviour
+        return type_constructor(*type_arguments)
+
+    # error handlers for non existent primitive types
     elif python_type not in _TYPE_MAP:
         raise StrictYAMLSchemaGenerationError(
             f"Type {python_type} is not supported for YAML schema generation"
         )
 
+    # remaining primitive and untyped types
     else:
         return _TYPE_MAP[python_type]()
 
 
-def dataclasses_strictyaml_schema(cls):
+def dataclass_strictyaml_schema(cls):
     anot = cls.__dict__.get("__annotations__", {})
 
     if len(anot) == 0:
@@ -57,6 +78,6 @@ def dataclasses_strictyaml_schema(cls):
             fields[name] = _get_strictyaml_type(python_type)
         schema = Map(fields)
 
-    setattr(cls, _FIELD_NAME, schema)
+    setattr(cls, _SCHEMA_FIELD_NAME, schema)
 
     return cls
index cbad3c203769d1c39665c419aae801ed021e0e1b..fff6213c8c38e2b981bac73b867f3863e5845b5c 100644 (file)
@@ -1,6 +1,7 @@
-from knot_resolver_manager.utils import dataclasses_strictyaml_schema
-from typing import List, Dict, Tuple
+from knot_resolver_manager.utils import dataclass_strictyaml_schema
+from typing import List, Dict, Optional, Tuple
 from strictyaml import Map, Str, EmptyDict, Int, Float, Seq, MapPattern, FixedSeq
+import strictyaml
 import pytest
 
 
@@ -12,7 +13,7 @@ def _schema_eq(schema1, schema2) -> bool:
 
 
 def test_empty_class():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         pass
 
@@ -20,7 +21,7 @@ def test_empty_class():
 
 
 def test_int_field():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: int
 
@@ -28,7 +29,7 @@ def test_int_field():
 
 
 def test_string_field():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: str
 
@@ -36,7 +37,7 @@ def test_string_field():
 
 
 def test_float_field():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: float
 
@@ -44,7 +45,7 @@ def test_float_field():
 
 
 def test_multiple_fields():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field1: str
         field2: int
@@ -57,7 +58,7 @@ def test_multiple_fields():
 
 
 def test_list_field():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: List[str]
 
@@ -65,7 +66,7 @@ def test_list_field():
 
 
 def test_dict_field():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: Dict[str, int]
 
@@ -74,8 +75,18 @@ def test_dict_field():
     )
 
 
+def test_optional_field():
+    @dataclass_strictyaml_schema
+    class TestClass:
+        field: Optional[int]
+
+    assert _schema_eq(
+        TestClass.STRICTYAML_SCHEMA, Map({"field": strictyaml.Optional(Int())})
+    )
+
+
 def test_nested_dict_list():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: Dict[str, List[int]]
 
@@ -90,7 +101,7 @@ def test_nested_dict_key_list():
     List can't be a dict key, so this should fail
     """
 
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: Dict[List[int], List[int]]
 
@@ -100,7 +111,7 @@ def test_nested_dict_key_list():
 
 
 def test_nested_list():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: List[List[List[List[int]]]]
 
@@ -110,7 +121,7 @@ def test_nested_list():
 
 
 def test_tuple_field():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: Tuple[str, int]
 
@@ -120,7 +131,7 @@ def test_tuple_field():
 
 
 def test_nested_tuple():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: Tuple[str, Dict[str, int], List[List[int]]]
 
@@ -131,11 +142,11 @@ def test_nested_tuple():
 
 
 def test_chained_classes():
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: int
 
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class CompoundClass:
         c: TestClass
 
@@ -148,7 +159,7 @@ def test_combined_with_dataclass():
     from dataclasses import dataclass
 
     @dataclass
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     class TestClass:
         field: int
 
@@ -158,7 +169,7 @@ def test_combined_with_dataclass():
 def test_combined_with_dataclass2():
     from dataclasses import dataclass
 
-    @dataclasses_strictyaml_schema
+    @dataclass_strictyaml_schema
     @dataclass
     class TestClass:
         field: int