]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
parsing: yaml: support for !include in custom loader
authorAleš Mrázek <ales.mrazek@nic.cz>
Mon, 21 Jul 2025 12:37:03 +0000 (14:37 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 25 Jul 2025 21:00:35 +0000 (23:00 +0200)
python/knot_resolver/utils/modeling/parsing.py

index dc6cec6cc74effa3052bb7acede0dc258e8f6ce9..b9c76378a277c8dbf01d6af8354083ed97caf2a5 100644 (file)
@@ -1,4 +1,5 @@
 import json
+import os
 from enum import Enum, auto
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -9,6 +10,8 @@ from yaml.nodes import MappingNode
 from .exceptions import DataParsingError, DataValidationError
 from .renaming import Renamed, renamed
 
+_include_key = "!include"
+
 
 # custom hook for 'json.loads()' to detect duplicate keys in data
 # source: https://stackoverflow.com/q/14902299/12858520
@@ -21,9 +24,19 @@ def _json_raise_duplicates(pairs: List[Tuple[Any, Any]]) -> Optional[Any]:
     return dict_out
 
 
-# custom loader for 'yaml.load()' to detect duplicate keys in data
-# source: https://gist.github.com/pypt/94d747fe5180851196eb
-class _RaiseDuplicatesLoader(yaml.SafeLoader):
+class _RaiseDuplicatesIncludeLoader(yaml.SafeLoader):
+    """
+    Custom YAML Loader for 'yaml.load()'.
+    - detects duplicate keys in the data
+    - detects '!include' keys in the data
+    """
+
+    def __init__(self, stream: Any) -> None:
+        self.add_constructor(_include_key, construct_include)
+        super().__init__(stream)
+
+    # custom constructor to detect duplicate keys in data
+    # source: https://gist.github.com/pypt/94d747fe5180851196eb
     def construct_mapping(self, node: Union[MappingNode, Any], deep: bool = False) -> Dict[Any, Any]:
         if not isinstance(node, MappingNode):
             raise ConstructorError(None, None, f"expected a mapping node, but found {node.id}", node.start_mark)
@@ -49,15 +62,34 @@ class _RaiseDuplicatesLoader(yaml.SafeLoader):
         return mapping
 
 
+# custom constructor for to detect '!include' keys in the data
+# source: https://gist.github.com/joshbode/569627ced3076931b02f
+def construct_include(loader: _RaiseDuplicatesIncludeLoader, node: Any) -> Any:
+    try:
+        root = os.path.split(loader.stream.name)[0]  # type: ignore
+    except AttributeError:
+        root = os.path.curdir
+
+    file_path = os.path.abspath(os.path.join(root, loader.construct_scalar(node)))
+    extension = os.path.splitext(file_path)[1].lstrip(".")
+
+    with open(file_path, "r") as file:
+        if extension in ("yaml", "yml"):
+            return yaml.load(file, Loader=_RaiseDuplicatesIncludeLoader)
+        if extension in ("json",):
+            return json.load(file)
+        return "".join(file.readlines())
+
+
 class DataFormat(Enum):
     YAML = auto()
     JSON = auto()
 
     def parse_to_dict(self, text: str) -> Any:
         if self is DataFormat.YAML:
-            # RaiseDuplicatesLoader extends yaml.SafeLoader, so this should be safe
+            # _RaiseDuplicatesIncludeLoader extends yaml.SafeLoader, so this should be safe
             # https://python.land/data-processing/python-yaml#PyYAML_safe_load_vs_load
-            return renamed(yaml.load(text, Loader=_RaiseDuplicatesLoader))  # type: ignore
+            return renamed(yaml.load(text, Loader=_RaiseDuplicatesIncludeLoader))  # type: ignore
         if self is DataFormat.JSON:
             return renamed(json.loads(text, object_pairs_hook=_json_raise_duplicates))
         raise NotImplementedError(f"Parsing of format '{self}' is not implemented")