]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
utils/modeling/parsing.py: use global validation context docs-parsing-yaml-jpfrdq/deployments/7283 parsing-yaml-include
authorAleš Mrázek <ales.mrazek@nic.cz>
Fri, 25 Jul 2025 23:25:52 +0000 (01:25 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Fri, 25 Jul 2025 23:26:20 +0000 (01:26 +0200)
The global validation context helps resolve relative paths.

python/knot_resolver/utils/modeling/parsing.py

index 7f66fbd76f467dddbbff2308efecde656b334857..99419c45de94c4180845581bec2ff87ab650feca 100644 (file)
@@ -1,6 +1,7 @@
 import json
 import os
 from enum import Enum, auto
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import yaml
@@ -9,6 +10,7 @@ from yaml.nodes import MappingNode
 
 from .exceptions import DataParsingError, DataValidationError
 from .renaming import Renamed, renamed
+from .validation_context import get_global_validation_context
 
 _include_key = "!include"
 
@@ -65,14 +67,13 @@ class _RaiseDuplicatesIncludeLoader(yaml.SafeLoader):
 # custom constructor for to detect '!include' keys in the data
 # source: https://gist.github.com/joshbode/569627ced3076931b02f
 def construct_include(loader: _RaiseDuplicatesIncludeLoader, node: Any) -> Any:
-    try:
-        root = os.path.split(loader.stream.name)[0]  # type: ignore
-    except AttributeError:
-        root = os.path.curdir
-
-    file_path = os.path.abspath(os.path.join(root, loader.construct_scalar(node)))
+    file_path = Path(loader.construct_scalar(node))
     extension = os.path.splitext(file_path)[1].lstrip(".")
 
+    context = get_global_validation_context()
+    if not file_path.is_absolute() and context.resolve_root:
+        file_path = context.resolve_root / file_path
+
     with open(file_path, "r") as file:
         if extension in ("yaml", "yml"):
             return yaml.load(file, Loader=_RaiseDuplicatesIncludeLoader)
@@ -85,7 +86,12 @@ def include_root(text: str) -> str:
     ntext = ""
     for line in iter(text.splitlines()):
         if line.startswith(_include_key):
-            file_path = line[len(_include_key) + 1 :]
+            file_path = Path(line[len(_include_key) + 1 :])
+
+            context = get_global_validation_context()
+            if not file_path.is_absolute() and context.resolve_root:
+                file_path = context.resolve_root / file_path
+
             with open(file_path, "r") as file:
                 include_text = file.read()
             ntext += include_root(include_text)