From: Aleš Mrázek Date: Fri, 25 Jul 2025 23:25:52 +0000 (+0200) Subject: utils/modeling/parsing.py: use global validation context X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fparsing-yaml-include;p=thirdparty%2Fknot-resolver.git utils/modeling/parsing.py: use global validation context The global validation context helps resolve relative paths. --- diff --git a/python/knot_resolver/utils/modeling/parsing.py b/python/knot_resolver/utils/modeling/parsing.py index 7f66fbd76..99419c45d 100644 --- a/python/knot_resolver/utils/modeling/parsing.py +++ b/python/knot_resolver/utils/modeling/parsing.py @@ -1,6 +1,7 @@ import json import os from enum import Enum, auto +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import yaml @@ -9,6 +10,7 @@ from yaml.nodes import MappingNode from .exceptions import DataParsingError, DataValidationError from .renaming import Renamed, renamed +from .validation_context import get_global_validation_context _include_key = "!include" @@ -65,14 +67,13 @@ class _RaiseDuplicatesIncludeLoader(yaml.SafeLoader): # custom constructor for to detect '!include' keys in the data # source: https://gist.github.com/joshbode/569627ced3076931b02f def construct_include(loader: _RaiseDuplicatesIncludeLoader, node: Any) -> Any: - try: - root = os.path.split(loader.stream.name)[0] # type: ignore - except AttributeError: - root = os.path.curdir - - file_path = os.path.abspath(os.path.join(root, loader.construct_scalar(node))) + file_path = Path(loader.construct_scalar(node)) extension = os.path.splitext(file_path)[1].lstrip(".") + context = get_global_validation_context() + if not file_path.is_absolute() and context.resolve_root: + file_path = context.resolve_root / file_path + with open(file_path, "r") as file: if extension in ("yaml", "yml"): return yaml.load(file, Loader=_RaiseDuplicatesIncludeLoader) @@ -85,7 +86,12 @@ def include_root(text: str) -> str: ntext = "" for line in iter(text.splitlines()): if line.startswith(_include_key): - file_path = line[len(_include_key) + 1 :] + file_path = Path(line[len(_include_key) + 1 :]) + + context = get_global_validation_context() + if not file_path.is_absolute() and context.resolve_root: + file_path = context.resolve_root / file_path + with open(file_path, "r") as file: include_text = file.read() ntext += include_root(include_text)