]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager: utils: modeling: support for custom generic types
authorAleš Mrázek <ales.mrazek@nic.cz>
Fri, 28 Apr 2023 15:17:23 +0000 (17:17 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 9 Jun 2023 11:54:07 +0000 (11:54 +0000)
manager/knot_resolver_manager/utils/modeling/__init__.py
manager/knot_resolver_manager/utils/modeling/base_generic_type_wrapper.py [new file with mode: 0644]
manager/knot_resolver_manager/utils/modeling/base_schema.py
manager/knot_resolver_manager/utils/modeling/base_value_type.py
manager/knot_resolver_manager/utils/modeling/types.py

index c72c60c734f197b239bfceaf888c92404f2c91bb..d16f6c12d6a4a51a633aedabf22d4f84aab15ef8 100644 (file)
@@ -1,8 +1,10 @@
+from .base_generic_type_wrapper import BaseGenericTypeWrapper
 from .base_schema import BaseSchema, ConfigSchema
 from .base_value_type import BaseValueType
 from .parsing import parse_json, parse_yaml, try_to_parse
 
 __all__ = [
+    "BaseGenericTypeWrapper",
     "BaseValueType",
     "BaseSchema",
     "ConfigSchema",
diff --git a/manager/knot_resolver_manager/utils/modeling/base_generic_type_wrapper.py b/manager/knot_resolver_manager/utils/modeling/base_generic_type_wrapper.py
new file mode 100644 (file)
index 0000000..3aee3c1
--- /dev/null
@@ -0,0 +1,9 @@
+from typing import Generic, TypeVar
+
+from .base_value_type import BaseTypeABC
+
+T = TypeVar("T")
+
+
+class BaseGenericTypeWrapper(Generic[T], BaseTypeABC):
+    pass
index 31cea7cc40d08735e5f4e3a3a0b5d91e6c557b89..32388816e5284108f0d1e6a0f978a4bb2c777caa 100644 (file)
@@ -7,15 +7,18 @@ import yaml
 
 from knot_resolver_manager.utils.functional import all_matches
 
+from .base_generic_type_wrapper import BaseGenericTypeWrapper
 from .base_value_type import BaseValueType
 from .exceptions import AggregateDataValidationError, DataDescriptionError, DataValidationError
 from .renaming import Renamed, renamed
 from .types import (
     get_generic_type_argument,
     get_generic_type_arguments,
+    get_generic_type_wrapper_argument,
     get_optional_inner_type,
     is_dict,
     is_enum,
+    is_generic_type_wrapper,
     is_internal_field_name,
     is_list,
     is_literal,
@@ -54,6 +57,7 @@ class Serializable(ABC):
             or is_literal(typ)
             or is_dict(typ)
             or is_list(typ)
+            or is_generic_type_wrapper(typ)
             or (inspect.isclass(typ) and issubclass(typ, Serializable))
             or (inspect.isclass(typ) and issubclass(typ, BaseValueType))
             or (inspect.isclass(typ) and issubclass(typ, BaseSchema))
@@ -66,8 +70,11 @@ class Serializable(ABC):
         if isinstance(obj, Serializable):
             return obj.to_dict()
 
-        elif isinstance(obj, BaseValueType):
-            return obj.serialize()
+        elif isinstance(obj, (BaseValueType, BaseGenericTypeWrapper)):
+            o = obj.serialize()
+            # if Serializable.is_serializable(o):
+            return Serializable.serialize(o)
+            # return o
 
         elif isinstance(obj, list):
             res: List[Any] = [Serializable.serialize(i) for i in cast(List[Any], obj)]
@@ -171,6 +178,10 @@ def _describe_type(typ: Type[Any]) -> Dict[Any, Any]:
     elif inspect.isclass(typ) and issubclass(typ, BaseValueType):
         return typ.json_schema()
 
+    elif is_generic_type_wrapper(typ):
+        wrapped = get_generic_type_wrapper_argument(typ)
+        return _describe_type(wrapped)
+
     elif is_none_type(typ):
         return {"type": "null"}
 
@@ -279,11 +290,15 @@ class ObjectMapper:
         inner_type = get_generic_type_argument(tp)
         errs: List[DataValidationError] = []
         res: List[Any] = []
-        for i, val in enumerate(obj):
-            try:
+
+        try:
+            for i, val in enumerate(obj):
                 res.append(self.map_object(inner_type, val, object_path=f"{object_path}[{i}]"))
-            except DataValidationError as e:
-                errs.append(e)
+        except DataValidationError as e:
+            errs.append(e)
+        except TypeError as e:
+            errs.append(DataValidationError(str(e), object_path))
+
         if len(errs) == 1:
             raise errs[0]
         elif len(errs) > 1:
@@ -465,6 +480,12 @@ class ObjectMapper:
         elif inspect.isclass(tp) and issubclass(tp, BaseValueType):
             return self.create_value_type_object(tp, obj, object_path)
 
+        # BaseGenericTypeWrapper subclasses
+        elif is_generic_type_wrapper(tp):
+            inner_type = get_generic_type_wrapper_argument(tp)
+            obj_valid = self.map_object(inner_type, obj, object_path)
+            return tp(obj_valid, object_path=object_path)  # type: ignore
+
         # nested BaseSchema subclasses
         elif inspect.isclass(tp) and issubclass(tp, BaseSchema):
             return self._create_base_schema_object(tp, obj, object_path)
index 1b07e3d4712f775f667c4f2e3b2dcb84517c7688..dff4a3fe6b9970451b935df2ebd4b253b98b48e8 100644 (file)
@@ -2,18 +2,7 @@ from abc import ABC, abstractmethod  # pylint: disable=[no-name-in-module]
 from typing import Any, Dict, Type
 
 
-class BaseValueType(ABC):
-    """
-    Subclasses of this class can be used as type annotations in 'DataParser'. When a value
-    is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by
-    calling the constructor of the appropriate type on the field value. The only limitation
-    is that the value MUST NOT be `None`.
-
-    There is no validation done on the wrapped value. The only condition is that
-    it can't be `None`. If you want to perform any validation during creation,
-    raise a `ValueError` in case of errors.
-    """
-
+class BaseTypeABC(ABC):
     @abstractmethod
     def __init__(self, source_value: Any, object_path: str = "/") -> None:
         pass
@@ -37,6 +26,19 @@ class BaseValueType(ABC):
         """
         raise NotImplementedError(f"{type(self).__name__}'s' 'serialize()' not implemented.")
 
+
+class BaseValueType(BaseTypeABC):
+    """
+    Subclasses of this class can be used as type annotations in 'DataParser'. When a value
+    is being parsed from a serialized format (e.g. JSON/YAML), an object will be created by
+    calling the constructor of the appropriate type on the field value. The only limitation
+    is that the value MUST NOT be `None`.
+
+    There is no validation done on the wrapped value. The only condition is that
+    it can't be `None`. If you want to perform any validation during creation,
+    raise a `ValueError` in case of errors.
+    """
+
     @classmethod
     @abstractmethod
     def json_schema(cls: Type["BaseValueType"]) -> Dict[Any, Any]:
index aaeded9e91c8c65ad3f7f2062988891ab6ba7380..4ce9aecca9cda45bcae252627a712c645b6451f2 100644 (file)
@@ -8,6 +8,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 from typing_extensions import Literal
 
+from .base_generic_type_wrapper import BaseGenericTypeWrapper
+
 NoneType = type(None)
 
 
@@ -46,6 +48,11 @@ def is_literal(tp: Any) -> bool:
         return getattr(tp, "__origin__", None) == Literal
 
 
+def is_generic_type_wrapper(tp: Any) -> bool:
+    orig = getattr(tp, "__origin__", None)
+    return inspect.isclass(orig) and issubclass(orig, BaseGenericTypeWrapper)
+
+
 def get_generic_type_arguments(tp: Any) -> List[Any]:
     default: List[Any] = []
     if sys.version_info.minor == 6 and is_literal(tp):
@@ -62,6 +69,17 @@ def get_generic_type_argument(tp: Any) -> Any:
     return args[0]
 
 
+def get_generic_type_wrapper_argument(tp: Type["BaseGenericTypeWrapper[Any]"]) -> Any:
+    assert hasattr(tp, "__origin__")
+    origin = getattr(tp, "__origin__")
+
+    assert hasattr(origin, "__orig_bases__")
+    orig_base: List[Any] = getattr(origin, "__orig_bases__", [])[0]
+
+    arg = get_generic_type_argument(tp)
+    return get_generic_type_argument(orig_base[arg])
+
+
 def is_none_type(tp: Any) -> bool:
     return tp is None or tp == NoneType