]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
utils: modelling: default fields assigned before everything else, conversion function...
authorVasek Sraier <git@vakabus.cz>
Tue, 21 Sep 2021 07:59:13 +0000 (09:59 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:53 +0000 (16:17 +0200)
manager/knot_resolver_manager/utils/modelling.py

index a72def25c70a3d25eeccd8e6f782398289df485d..f1b93dbd29e7c93e2e844d2bea5e1ff362836528 100644 (file)
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, Dict, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
 
 from knot_resolver_manager.exceptions import DataException, SchemaException
 from knot_resolver_manager.utils.custom_types import CustomValueType
@@ -174,56 +174,94 @@ def _validated_object_type(
         )
 
 
-TSource = Union[NoneType, ParsedTree, "SchemaNode"]
+TSource = Union[NoneType, ParsedTree, "SchemaNode", Dict[str, Any]]
 
 
 class SchemaNode:
     _PREVIOUS_SCHEMA: Optional[Type["SchemaNode"]] = None
 
-    def __init__(self, source: TSource = None, object_path: str = "/"):
-        # construct lower level schema node first if configured to do so
-        if self._PREVIOUS_SCHEMA is not None:
-            source = self._PREVIOUS_SCHEMA(source, object_path=object_path)  # pylint: disable=not-callable
+    def _assign_default_fields(self) -> Set[str]:
+        cls = self.__class__
+        annot = cls.__dict__.get("__annotations__", {})
 
-        # make sure that all raw data checks passed on the source object
-        if isinstance(source, dict):
-            source = ParsedTree(source)
+        used_keys: Set[str] = set()
+        for name in annot:
+            val = getattr(cls, name, ...)
+            if val is not ...:
+                setattr(self, name, val)
+                used_keys.add(name)
 
+        return used_keys
+
+    def _assign_field(self, name: str, python_type: Any, value: Any, object_path: str):
+        cls = self.__class__
+        use_default = hasattr(cls, name)
+        default = getattr(cls, name, ...)
+        value = _validated_object_type(python_type, value, default, use_default, object_path=f"{object_path}/{name}")
+        setattr(self, name, value)
+
+    def _assign_fields(self, source: Union[ParsedTree, "SchemaNode", NoneType], object_path: str) -> Set[str]:
+        """
+        Order of assignment:
+          1. all direct assignments
+          2. assignments with conversion method
+        """
         cls = self.__class__
         annot = cls.__dict__.get("__annotations__", {})
 
         used_keys: Set[str] = set()
+        deffered: List[Tuple[str, Any]] = []
         for name, python_type in annot.items():
             if is_internal_field_name(name):
                 continue
 
             # populate field
             if not source:
-                val = None
+                self._assign_field(name, python_type, None, object_path)
+
             # we have a way how to create the value
             elif hasattr(self, f"_{name}"):
-                val = self._get_converted_value(name, source, object_path)
-                used_keys.add(name)  # the field might not exist, but that won't break anything
+                deffered.append((name, python_type))
+
             # source just contains the value
             elif name in source:
                 val = source[name]
                 used_keys.add(name)
+                self._assign_field(name, python_type, val, object_path)
+
             # there is a default value and in the source, the value is missing
             elif getattr(self, name, ...) is not ...:
-                val = None
+                self._assign_field(name, python_type, None, object_path)
+
             # the value is optional and there is nothing
             elif is_optional(python_type):
-                val = None
+                self._assign_field(name, python_type, None, object_path)
+
             # we expected a value but it was not there
             else:
                 raise SchemaException(f"Missing attribute '{name}'.", object_path)
 
-            use_default = hasattr(cls, name)
-            default = getattr(cls, name, ...)
-            value = _validated_object_type(python_type, val, default, use_default, object_path=f"{object_path}/{name}")
-            setattr(self, name, value)
+        for name, python_type in deffered:
+            val = self._get_converted_value(name, source, object_path)
+            used_keys.add(name)  # the field might not exist, but that won't break anything
+            self._assign_field(name, python_type, val, object_path)
+
+        return used_keys
+
+    def __init__(self, source: TSource = None, object_path: str = "/"):
+        # construct lower level schema node first if configured to do so
+        if self._PREVIOUS_SCHEMA is not None:
+            source = self._PREVIOUS_SCHEMA(source, object_path=object_path)  # pylint: disable=not-callable
+
+        # make sure that all raw data checks passed on the source object
+        if isinstance(source, dict):
+            source = ParsedTree(source)
+
+        # assign fields
+        used_keys = self._assign_default_fields()
+        used_keys.update(self._assign_fields(source, object_path))
 
-        # check for unused keys in case the
+        # check for unused keys in the source object
         if source and not isinstance(source, SchemaNode):
             unused = source.keys() - used_keys
             if len(unused) > 0:
@@ -254,15 +292,5 @@ class SchemaNode:
     def __contains__(self, item: Any) -> bool:
         return hasattr(self, item)
 
-    def validate(self) -> None:
-        for field_name in dir(self):
-            if is_internal_field_name(field_name):
-                continue
-
-            field = getattr(self, field_name)
-            if isinstance(field, SchemaNode):
-                field.validate()
-        self._validate()
-
     def _validate(self) -> None:
         pass