]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
utils: modelling: move conversion from between '-' and '_' to _ parsing step (from...
authorVasek Sraier <git@vakabus.cz>
Mon, 20 Sep 2021 12:29:53 +0000 (14:29 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 8 Apr 2022 14:17:53 +0000 (16:17 +0200)
manager/knot_resolver_manager/utils/modelling.py
manager/knot_resolver_manager/utils/parsing.py
manager/tests/utils/test_modeling.py

index 5edd8b836b9461f284655f94644e5da0447f1334..a72def25c70a3d25eeccd8e6f782398289df485d 100644 (file)
@@ -174,7 +174,7 @@ def _validated_object_type(
         )
 
 
-TSource = Union[NoneType, Dict[Any, Any], ParsedTree, "SchemaNode"]
+TSource = Union[NoneType, ParsedTree, "SchemaNode"]
 
 
 class SchemaNode:
@@ -185,6 +185,10 @@ class SchemaNode:
         if self._PREVIOUS_SCHEMA is not None:
             source = self._PREVIOUS_SCHEMA(source, object_path=object_path)  # pylint: disable=not-callable
 
+        # make sure that all raw data checks passed on the source object
+        if isinstance(source, dict):
+            source = ParsedTree(source)
+
         cls = self.__class__
         annot = cls.__dict__.get("__annotations__", {})
 
@@ -193,20 +197,17 @@ class SchemaNode:
             if is_internal_field_name(name):
                 continue
 
-            # convert naming (used when converting from json/yaml)
-            source_name = name.replace("_", "-") if isinstance(source, dict) else name
-
             # populate field
             if not source:
                 val = None
             # we have a way how to create the value
             elif hasattr(self, f"_{name}"):
                 val = self._get_converted_value(name, source, object_path)
-                used_keys.add(source_name)  # the field might not exist, but that won't break anything
+                used_keys.add(name)  # the field might not exist, but that won't break anything
             # source just contains the value
-            elif source_name in source:
-                val = source[source_name]
-                used_keys.add(source_name)
+            elif name in source:
+                val = source[name]
+                used_keys.add(name)
             # there is a default value and in the source, the value is missing
             elif getattr(self, name, ...) is not ...:
                 val = None
@@ -215,7 +216,7 @@ class SchemaNode:
                 val = None
             # we expected a value but it was not there
             else:
-                raise SchemaException(f"Missing attribute '{source_name}'.", object_path)
+                raise SchemaException(f"Missing attribute '{name}'.", object_path)
 
             use_default = hasattr(cls, name)
             default = getattr(cls, name, ...)
@@ -223,7 +224,7 @@ class SchemaNode:
             setattr(self, name, value)
 
         # check for unused keys in case the
-        if source and isinstance(source, dict):
+        if source and not isinstance(source, SchemaNode):
             unused = source.keys() - used_keys
             if len(unused) > 0:
                 raise SchemaException(
index 897f0b5fca4e973955392c052dcfbb33407360a0..4c9fe29fb9ade936bbd8202c36612ad758592e84 100644 (file)
@@ -2,7 +2,7 @@ import copy
 import json
 import re
 from enum import Enum, auto
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, KeysView, List, Optional, Tuple, Union, cast
 
 import yaml
 from yaml.constructor import ConstructorError
@@ -20,26 +20,55 @@ class ParsedTree:
     IMMUTABLE, DO NOT MODIFY
     """
 
-    def __init__(self, dct: Dict[str, Any]):
-        self.data = dct
-
-    def to_dict(self) -> Dict[str, Any]:
+    @staticmethod
+    def _convert_to_underscores(dct: Dict[Any, Any]) -> Dict[Any, Any]:
+        assert isinstance(dct, dict)
+        res: Dict[Any, Any] = {}
+        for key in dct:
+            assert isinstance(key, str)
+
+            # rename & convert recursively
+            obj = dct[key]
+            if isinstance(obj, dict):
+                obj = ParsedTree._convert_to_underscores(cast(Dict[Any, Any], obj))
+            res[key.replace("-", "_")] = obj
+
+        return res
+
+    def __init__(self, data: Union[Dict[str, Any], str, int, bool]):
+        if isinstance(data, dict):
+            data = ParsedTree._convert_to_underscores(data)
+        self.data = data
+
+    def to_raw(self) -> Union[Dict[str, Any], str, int, bool]:
         return self.data
 
     def __getitem__(self, key: str):
+        assert isinstance(self.data, dict)
         return self.data[key]
 
     def __contains__(self, key: str):
+        assert isinstance(self.data, dict)
         return key in self.data
 
+    def __str__(self) -> str:
+        return json.dumps(self.data, sort_keys=False, indent=2)
+
+    def keys(self) -> KeysView[Any]:
+        assert isinstance(self.data, dict)
+        return self.data.keys()
+
     _SUBTREE_MUTATION_PATH_PATTERN = re.compile(r"^(/[^/]+)*/?$")
 
-    def update(self, document_path: str, data: "ParsedTree") -> "ParsedTree":
+    def update(self, path: str, data: "ParsedTree") -> "ParsedTree":
 
         # prepare and validate the path object
-        path = document_path[:-1] if document_path.endswith("/") else document_path
+        path = path[:-1] if path.endswith("/") else path
         if re.match(ParsedTree._SUBTREE_MUTATION_PATH_PATTERN, path) is None:
             raise ParsingException("Provided object path for mutation is invalid.")
+        if "_" in path:
+            raise ParsingException("Provided object path contains character '_', which is illegal")
+        path = path.replace("-", "_")
         path = path[1:] if path.startswith("/") else path
 
         # now, the path variable should contain '/' separated field names
@@ -49,12 +78,12 @@ class ParsedTree:
             return data
 
         # find the subtree we will replace in a copy of the original object
-        to_mutate = copy.deepcopy(self.to_dict())
+        to_mutate = copy.deepcopy(self.to_raw())
         obj = to_mutate
         parent = None
 
-        for dash_segment in path.split("/"):
-            segment = dash_segment.replace("-", "_")
+        for segment in path.split("/"):
+            assert isinstance(obj, dict)
 
             if segment == "":
                 raise ParsingException(f"Unexpectedly empty segment in path '{path}'")
@@ -64,7 +93,7 @@ class ParsedTree:
                 )
             elif segment in obj:
                 parent = obj
-                obj = getattr(parent, segment)
+                obj = obj[segment]
             elif segment not in obj:
                 parent = obj
                 obj = {}
@@ -72,8 +101,8 @@ class ParsedTree:
         assert parent is not None
 
         # assign the subtree
-        last_name = path.split("/")[-1].replace("-", "_")
-        parent[last_name] = data.to_dict()
+        last_name = path.split("/")[-1]
+        parent[last_name] = data.to_raw()
 
         return ParsedTree(to_mutate)
 
index 5305e6c555f1c0ceb34868f04f4aaaa06cb3d203..29ff8f70a472ff328afa96a7d10ba116410a7909 100644 (file)
@@ -162,7 +162,8 @@ def test_partial_mutations():
     assert o.workers == 8
 
     # replacement of 'lua-config' attribute
-    o = ConfSchema(d.update("/lua-config", parse_json('"new_value"')))
+    upd = d.update("/lua-config", parse_json('"new_value"'))
+    o = ConfSchema(upd)
     assert o.lua_config == "new_value"
     assert o.inner.size == 5
     assert o.workers == 8
@@ -175,7 +176,7 @@ def test_partial_mutations():
 
     # replacement of 'inner' subtree
     o = ConfSchema(d.update("/inner", parse_json('{"size": 33}')))
-    assert o.lua_config == None
+    assert o.lua_config == "something"
     assert o.workers == 8
     assert o.inner.size == 33