+++ /dev/null
-from .constants import VERSION
-from .exceptions import KresBaseError
-
-__version__ = VERSION
-
-__all__ = ["KresBaseError"]
+++ /dev/null
-from pathlib import Path
-
-from knot_resolver.datamodel.globals import Context, set_global_validation_context
-
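-# pre-set a default validation context: rooted in the current directory, non-strict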
-set_global_validation_context(Context(Path("."), False))
+++ /dev/null
-from knot_resolver.client.main import main
-
-if __name__ == "__main__":
- main()
+++ /dev/null
-import argparse
-
-from knot_resolver.client.command import CommandArgs
-
-KRES_CLIENT_NAME = "kresctl"
-
-
-class KresClient:
- def __init__(
- self,
- namespace: argparse.Namespace,
- parser: argparse.ArgumentParser,
- prompt: str = KRES_CLIENT_NAME,
- ) -> None:
- self.path = None
- self.prompt = prompt
- self.namespace = namespace
- self.parser = parser
-
- def execute(self) -> None:
- if hasattr(self.namespace, "command"):
- args = CommandArgs(self.namespace, self.parser)
- command = args.command(self.namespace)
- command.run(args)
- else:
- self.parser.print_help()
-
- def _prompt_format(self) -> str:
-        bold = "\033[1m"
-        white = "\033[38;5;255m"
-        reset = "\033[0;0m"
-
-        prompt = f"{bold}[{self.prompt} {white}{self.path}{reset}{bold}]" if self.path else f"{bold}{self.prompt}"
- return f"{prompt}> {reset}"
-
- def interactive(self) -> None:
- try:
- while True:
- pass
- # TODO: not working yet
- # cmd = input(f"{self._prompt_format()}")
- # namespace = self.parser.parse_args(cmd.split(" "))
- # namespace.interactive = True
- # namespace.socket = self.namespace.socket
- # self.namespace = namespace
- # self.execute()
- except KeyboardInterrupt:
- pass
+++ /dev/null
-import argparse
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple, Type, TypeVar
-from urllib.parse import quote
-
-from knot_resolver.constants import API_SOCK_FILE, CONFIG_FILE
-from knot_resolver.datamodel.types import IPAddressPort, WritableFilePath
-from knot_resolver.utils.modeling import parsing
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-from knot_resolver.utils.requests import SocketDesc
-
-T = TypeVar("T", bound=Type["Command"])
-
-CompWords = Dict[str, Optional[str]]
-
-COMP_DIRNAMES = "#dirnames#"
-COMP_FILENAMES = "#filenames#"
-COMP_NOSPACE = "#nospace#"
-
-_registered_commands: List[Type["Command"]] = []
-
-
-def get_mutually_exclusive_args(parser: argparse.ArgumentParser) -> List[Set[str]]:
- groups: List[Set[str]] = []
-
- for group in parser._mutually_exclusive_groups: # noqa: SLF001
- group_args: Set[str] = set()
- for action in group._group_actions: # noqa: SLF001
- if action.option_strings:
- group_args.update(action.option_strings)
- if group_args:
- groups.append(group_args)
- return groups
-
-
-def get_parser_action(name: str, parser_actions: List[argparse.Action]) -> Optional[argparse.Action]:
- for action in parser_actions:
- if (action.choices and name in action.choices) or (action.option_strings and name in action.option_strings):
- return action
- return None
-
-
-def get_subparser_command(subparser: argparse.ArgumentParser) -> Optional["Command"]:
- if "command" in subparser._defaults: # noqa: SLF001
- return subparser._defaults["command"] # noqa: SLF001
- return None
-
-
-def comp_get_actions_words(parser_actions: List[argparse.Action]) -> CompWords:
- words: CompWords = {}
- for action in parser_actions:
- if isinstance(action, argparse._SubParsersAction) and action.choices: # noqa: SLF001
- for choice, parser in action.choices.items():
- words[choice] = parser.description if isinstance(parser, argparse.ArgumentParser) else None
- elif action.option_strings:
- for opt in action.option_strings:
- words[opt] = action.help
- elif not action.option_strings and action.choices:
- for choice in action.choices:
- words[choice] = action.help
- elif not action.option_strings and not action.choices:
- words[COMP_DIRNAMES] = None
- words[COMP_FILENAMES] = None
- return words
-
-
-def comp_get_words(args: List[str], parser: argparse.ArgumentParser) -> CompWords: # noqa: C901, PLR0912
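-    # Walk the arguments typed so far and narrow the completion candidates:
-    # options already used (and their mutually exclusive alternatives) are dropped,
-    # --help/--version stop completion, value-taking options offer their choices
-    # (or file/directory names), and subcommands delegate to their own parser.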
- words: CompWords = comp_get_actions_words(parser._actions) # noqa: SLF001
- nargs = len(args)
-
- skip_arg = False
- for i, arg in enumerate(args):
- action: Optional[argparse.Action] = get_parser_action(arg, parser._actions) # noqa: SLF001
-
- if skip_arg:
- skip_arg = False
- continue
-
- if not action:
- continue
-
- if i + 1 >= nargs:
- continue
-
- # remove exclusive arguments from words
- for exclusive_args in get_mutually_exclusive_args(parser):
- if arg in exclusive_args:
- for earg in exclusive_args:
- if earg in words.keys():
- del words[earg]
- # remove alternative arguments from words
- for opt in action.option_strings:
- if opt in words.keys():
- del words[opt]
-
-        # if action is HelpAction or VersionAction
- if isinstance(action, (argparse._HelpAction, argparse._VersionAction)): # noqa: SLF001
- words = {}
- break
-
-        # if action is a StoreConstAction (which includes StoreTrueAction and StoreFalseAction)
- if isinstance(action, argparse._StoreConstAction): # noqa: SLF001
- continue
-
- # if action is StoreAction
- if isinstance(action, argparse._StoreAction): # noqa: SLF001
- if i + 2 >= nargs:
- choices = {}
- if action.choices:
- for choice in action.choices:
- choices[choice] = action.help
- else:
- choices[COMP_DIRNAMES] = None
- choices[COMP_FILENAMES] = None
- words = choices
- skip_arg = True
- continue
-
- # if action is SubParserAction
- if isinstance(action, argparse._SubParsersAction): # noqa: SLF001
- subparser: Optional[argparse.ArgumentParser] = action.choices.get(arg, None)
-
- command = get_subparser_command(subparser) if subparser else None
- if command and subparser:
- return command.completion(args[i + 1 :], subparser)
- if subparser:
- return comp_get_words(args[i + 1 :], subparser)
- return {}
-
- return words
-
-
-def register_command(cls: T) -> T:
- _registered_commands.append(cls)
- return cls
-
-
-def get_help_command() -> Type["Command"]:
- for command in _registered_commands:
- if command.__name__ == "HelpCommand":
- return command
- raise ValueError("missing HelpCommand")
-
-
-def install_commands_parsers(parser: argparse.ArgumentParser) -> None:
- subparsers = parser.add_subparsers(help="command type")
- for command in _registered_commands:
- subparser, typ = command.register_args_subparser(subparsers)
- subparser.set_defaults(command=typ, subparser=subparser)
-
-
-def get_socket_from_config(config: Path, optional_file: bool) -> Optional[SocketDesc]:
- try:
- with open(config, "r", encoding="utf8") as f:
- data = parsing.try_to_parse(f.read())
-
- mkey = "management"
- if mkey in data:
- management = data[mkey]
-
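-            # "unix-socket" takes precedence over "interface" when both are present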
- skey = "unix-socket"
- if skey in management:
- sock = WritableFilePath(management[skey], object_path=f"/{mkey}/{skey}")
- return SocketDesc(
- f'http+unix://{quote(str(sock), safe="")}/',
- f'Key "/management/unix-socket" in "{config}" file',
- )
- ikey = "interface"
- if ikey in data[mkey]:
- ip = IPAddressPort(management[ikey], object_path=f"/{mkey}/{ikey}")
- return SocketDesc(
- f"http://{ip.addr}:{ip.port}",
- f'Key "/management/interface" in "{config}" file',
- )
- except ValueError as e:
- raise DataValidationError(*e.args) from e # pylint: disable=no-value-for-parameter
- except OSError:
- if not optional_file:
- raise
- return None
- else:
- return None
-
-
-def determine_socket(namespace: argparse.Namespace) -> SocketDesc:
- # 1) socket from '--socket' argument
- if len(namespace.socket) > 0:
- return SocketDesc(namespace.socket[0], "--socket argument")
-
- socket: Optional[SocketDesc] = None
- # 2) socket from config file ('--config' argument)
- if len(namespace.config) > 0:
- socket = get_socket_from_config(namespace.config[0], False)
- # 3) socket from config file (default config file constant)
- else:
- socket = get_socket_from_config(CONFIG_FILE, True)
-
- if socket:
- return socket
- # 4) socket default
- return SocketDesc(str(API_SOCK_FILE), f'Default value "{API_SOCK_FILE}"')
-
-
-class CommandArgs:
- def __init__(self, namespace: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
- self.namespace = namespace
- self.parser = parser
- self.subparser: argparse.ArgumentParser = namespace.subparser
- self.command: Type["Command"] = namespace.command
-
- self.socket: SocketDesc = determine_socket(namespace)
-
-
-class Command(ABC):
- @staticmethod
- @abstractmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- raise NotImplementedError()
-
- @abstractmethod
- def __init__(self, namespace: argparse.Namespace) -> None: # pylint: disable=[unused-argument]
- super().__init__()
-
- @abstractmethod
- def run(self, args: CommandArgs) -> None:
- raise NotImplementedError()
-
- @staticmethod
- @abstractmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- raise NotImplementedError()
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema
-from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataValidationError
-from knot_resolver.utils.modeling.parsing import DataFormat, parse_json
-from knot_resolver.utils.requests import request
-
-
-class CacheOperations(Enum):
- CLEAR = 0
-
-
-@register_command
-class CacheCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.operation: Optional[CacheOperations] = namespace.operation if hasattr(namespace, "operation") else None
- self.output_format: DataFormat = (
- namespace.output_format if hasattr(namespace, "output_format") else DataFormat.YAML
- )
-
- # CLEAR operation
- self.clear_dict: Dict[str, Any] = {}
- if hasattr(namespace, "exact_name"):
- self.clear_dict["exact-name"] = namespace.exact_name
- if hasattr(namespace, "name"):
- self.clear_dict["name"] = namespace.name
- if hasattr(namespace, "rr_type"):
- self.clear_dict["rr-type"] = namespace.rr_type
- if hasattr(namespace, "chunk_size"):
- self.clear_dict["chunk-size"] = namespace.chunk_size
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- cache_parser = subparser.add_parser("cache", help="Performs operations on the cache of the running resolver.")
-
- config_subparsers = cache_parser.add_subparsers(help="operation type")
-
- # 'clear' operation
- clear_subparser = config_subparsers.add_parser(
- "clear", help="Purge cache records that match specified criteria."
- )
- clear_subparser.set_defaults(operation=CacheOperations.CLEAR, exact_name=False)
- clear_subparser.add_argument(
- "--exact-name",
- help="If set, only records with the same name are purged.",
- action="store_true",
- dest="exact_name",
- )
- clear_subparser.add_argument(
- "--rr-type",
- help="Optional, the resource record type to purge. It is supported only with the '--exact-name' flag set.",
- action="store",
- type=str,
- )
- clear_subparser.add_argument(
- "--chunk-size",
- help="Optional, the number of records to remove in one round; the default is 100."
- " The purpose is not to block the resolver for long."
- " The resolver repeats the cache clearing after one millisecond until all matching data is cleared.",
- action="store",
- type=int,
- default=100,
- )
- clear_subparser.add_argument(
- "name",
- type=str,
- nargs="?",
- help="Optional, subtree name to purge; if omitted,"
- " the entire cache is purged (and all other parameters are ignored).",
- default=None,
- )
-
- output_format = clear_subparser.add_mutually_exclusive_group()
- output_format_default = DataFormat.YAML
- output_format.add_argument(
- "--json",
- help="Set JSON as the output format.",
- const=DataFormat.JSON,
- action="store_const",
- dest="output_format",
- default=output_format_default,
- )
- output_format.add_argument(
- "--yaml",
- help="Set YAML as the output format. YAML is the default.",
- const=DataFormat.YAML,
- action="store_const",
- dest="output_format",
- default=output_format_default,
- )
-
- return cache_parser, CacheCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- if not self.operation:
- args.subparser.print_help()
- sys.exit()
-
- if self.operation == CacheOperations.CLEAR:
- try:
- validated = CacheClearRPCSchema(self.clear_dict)
- except (AggregateDataValidationError, DataValidationError) as e:
- print(e, file=sys.stderr)
- sys.exit(1)
-
- body: str = DataFormat.JSON.dict_dump(validated.get_unparsed_data())
- response = request(args.socket, "POST", "cache/clear", body)
- body_dict = parse_json(response.body)
-
- if response.status != 200:
- print(response, file=sys.stderr)
- sys.exit(1)
- print(self.output_format.dict_dump(body_dict, indent=4))
+++ /dev/null
-# noqa: INP001
-import argparse
-from enum import Enum
-from typing import List, Tuple, Type
-
-from knot_resolver.client.command import (
- Command,
- CommandArgs,
- CompWords,
- comp_get_words,
- register_command,
-)
-
-
-class Shells(Enum):
- BASH = 0
- FISH = 1
-
-
-@register_command
-class CompletionCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.shell: Shells = namespace.shell
- self.args: List[str] = namespace.args
- if namespace.extra is not None:
- self.args.append("--")
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- completion = subparser.add_parser(
- "completion",
- help="commands auto-completion",
- )
-
- shells_dest = "shell"
- shells = completion.add_mutually_exclusive_group()
- shells.add_argument("--bash", action="store_const", dest=shells_dest, const=Shells.BASH, default=Shells.BASH)
- shells.add_argument("--fish", action="store_const", dest=shells_dest, const=Shells.FISH)
-
- completion.add_argument("--args", help="arguments to complete", nargs=argparse.REMAINDER, default=[])
-
- return completion, CompletionCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- words: CompWords = {}
-
- parser = args.parser
- if parser:
- words = comp_get_words(self.args, args.parser)
-
- # print completion words
- # based on required bash/fish shell format
- if self.shell == Shells.BASH:
- print(" ".join(words))
- elif self.shell == Shells.FISH:
- # TODO: FISH completion implementation
- pass
- else:
- raise ValueError(f"unexpected value of {Shells}: {self.shell}")
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from enum import Enum
-from typing import List, Literal, Optional, Tuple, Type
-
-from knot_resolver.client.command import COMP_NOSPACE, Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.utils.modeling.parsing import DataFormat, parse_json, try_to_parse
-from knot_resolver.utils.requests import request
-
-
-class Operations(Enum):
- SET = 0
- DELETE = 1
- GET = 2
-
-
-def operation_to_method(operation: Operations) -> Literal["PUT", "GET", "DELETE"]:
- if operation == Operations.SET:
- return "PUT"
- if operation == Operations.DELETE:
- return "DELETE"
- return "GET"
-
-
-@register_command
-class ConfigCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.path: str = str(namespace.path) if hasattr(namespace, "path") else ""
- self.format: DataFormat = namespace.format if hasattr(namespace, "format") else DataFormat.JSON
- self.operation: Optional[Operations] = namespace.operation if hasattr(namespace, "operation") else None
- self.file: Optional[str] = namespace.file if hasattr(namespace, "file") else None
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- config = subparser.add_parser("config", help="Performs operations on the running resolver's configuration.")
- path_help = "Optional, path (JSON pointer, RFC6901) to the configuration resources. "
- " By default, the entire configuration is selected."
-
- config_subparsers = config.add_subparsers(help="operation type")
-
- # GET operation
- get_op = config_subparsers.add_parser("get", help="Get current configuration from the resolver.")
- get_op.set_defaults(operation=Operations.GET, format=DataFormat.YAML)
-
- get_op.add_argument(
- "-p",
- "--path",
- help=path_help,
- action="store",
- type=str,
- default="",
- )
- get_op.add_argument(
- "file",
- help="Optional, path to the file where to save exported configuration data."
- " If not specified, data will be printed.",
- type=str,
- nargs="?",
- )
-
- get_formats = get_op.add_mutually_exclusive_group()
- get_formats.add_argument(
- "--json",
- help="Get configuration data in JSON format.",
- const=DataFormat.JSON,
- action="store_const",
- dest="format",
- )
- get_formats.add_argument(
- "--yaml",
- help="Get configuration data in YAML format, default.",
- const=DataFormat.YAML,
- action="store_const",
- dest="format",
- )
-
- # SET operation
- set_op = config_subparsers.add_parser("set", help="Set new configuration for the resolver.")
- set_op.set_defaults(operation=Operations.SET)
-
- set_op.add_argument(
- "-p",
- "--path",
- help=path_help,
- action="store",
- type=str,
- default="",
- )
-
- value_or_file = set_op.add_mutually_exclusive_group()
- value_or_file.add_argument(
- "file",
- help="Optional, path to file with new configuration.",
- type=str,
- nargs="?",
- )
- value_or_file.add_argument(
- "value",
- help="Optional, new configuration value.",
- type=str,
- nargs="?",
- )
-
- # DELETE operation
- delete_op = config_subparsers.add_parser(
- "delete", help="Delete given configuration property or list item at the given index."
- )
- delete_op.set_defaults(operation=Operations.DELETE)
- delete_op.add_argument(
- "-p",
- "--path",
- help=path_help,
- action="store",
- type=str,
- default="",
- )
- return config, ConfigCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- nargs = len(args)
-
- if nargs > 1 and args[-2] in ["-p", "--path"]:
- words: CompWords = {}
- words[COMP_NOSPACE] = None
-
- path = args[-1]
- path_nodes = path.split("/")
-
- prefix = ""
- properties = KresConfig.json_schema()["properties"]
- is_list = False
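-            # walk the configuration JSON schema along the typed path; completion
-            # candidates are the properties of the deepest node that is still valid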
- for i, node in enumerate(path_nodes):
- # first node is empty string
- if i == 0:
- continue
-
- if node in properties:
- is_list = False
- if "properties" in properties[node]:
- properties = properties[node]["properties"]
- prefix += f"/{node}"
- continue
- if "items" in properties[node]:
- properties = properties[node]["items"]["properties"]
- prefix += f"/{node}"
- is_list = True
- continue
- del words[COMP_NOSPACE]
- break
- if is_list and node.isnumeric():
- prefix += f"/{node}"
- continue
-
- for key in properties.keys():
- words[f"{prefix}/{key}"] = properties[key]["description"]
-
- return words
-
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- if not self.operation:
- args.subparser.print_help()
- sys.exit()
-
- new_config = None
- path = f"v1/config{self.path}"
- method = operation_to_method(self.operation)
-
- if self.operation == Operations.SET:
- if self.file:
- try:
- with open(self.file, "r") as f:
- new_config = f.read()
- except FileNotFoundError:
- new_config = self.file
- else:
-                # when neither a file nor a value is given, prompt for the new configuration on stdin
- new_config = input("Type new configuration: ")
-
- body = DataFormat.JSON.dict_dump(try_to_parse(new_config)) if new_config else None
- response = request(args.socket, method, path, body)
-
- if response.status != 200:
- print(response, file=sys.stderr)
- sys.exit(1)
-
- if self.operation == Operations.GET and self.file:
- with open(self.file, "w") as f:
- f.write(self.format.dict_dump(parse_json(response.body), indent=4))
- print(f"saved to: {self.file}")
- elif response.body:
- print(self.format.dict_dump(parse_json(response.body), indent=4))
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.constants import CONFIG_FILE
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.datamodel.globals import Context, reset_global_validation_context, set_global_validation_context
-from knot_resolver.utils.modeling import try_to_parse
-from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError
-from knot_resolver.utils.modeling.parsing import data_combine
-
-
-@register_command
-class ConvertCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.input_file: str = namespace.input_file
- self.output_file: Optional[str] = namespace.output_file
- self.strict: bool = namespace.strict
- self.type: str = namespace.type
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- convert = subparser.add_parser("convert", help="Converts JSON or YAML configuration to Lua script.")
- convert.set_defaults(strict=False)
- convert.add_argument(
- "--strict",
- help="Enable strict rules during validation, e.g. path/file existence and permissions.",
- action="store_true",
- dest="strict",
- )
- convert.add_argument(
- "--type", help="The type of Lua script to generate", choices=["worker", "policy-loader"], default="worker"
- )
- convert.add_argument(
- "-o",
- "--output",
- type=str,
- nargs="?",
- help="Optional, output file for converted configuration in Lua script."
- " If not specified, converted configuration is printed.",
- dest="output_file",
- default=None,
- )
- convert.add_argument(
- "input_file",
- type=str,
- nargs="*",
- help="File or combination of files with configuration in YAML or JSON format.",
- default=[CONFIG_FILE],
- )
- return convert, ConvertCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- data: Dict[str, Any] = {}
- try:
- for file in self.input_file:
- with open(file, "r") as f:
- raw = f.read()
- parsed = try_to_parse(raw)
- data = data_combine(data, parsed)
-
- set_global_validation_context(Context(Path(Path(self.input_file[0]).parent), self.strict))
- if self.type == "worker":
- lua = KresConfig(data).render_kresd_lua()
- elif self.type == "policy-loader":
- lua = KresConfig(data).render_policy_loader_lua()
- else:
- raise ValueError(f"Invalid self.type={self.type}")
- reset_global_validation_context()
- except (DataParsingError, DataValidationError) as e:
- print(e, file=sys.stderr)
- sys.exit(1)
-
- if self.output_file:
- with open(self.output_file, "w") as f:
- f.write(lua)
- else:
- print(lua)
+++ /dev/null
-# noqa: INP001
-import argparse
-import json
-import os
-import sys
-from pathlib import Path
-from typing import List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.utils import which
-from knot_resolver.utils.requests import request
-
-PROCS_TYPE = List
-
-
-@register_command
-class DebugCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- self.proc_type: Optional[str] = namespace.proc_type
- self.sudo: bool = namespace.sudo
- self.gdb: str = namespace.gdb
- self.print_only: bool = namespace.print_only
- self.gdb_args: List[str] = namespace.extra if namespace.extra is not None else []
- super().__init__(namespace)
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- debug = subparser.add_parser(
- "debug",
- help="Run GDB on the manager's subprocesses",
- )
- debug.add_argument(
- "--sudo",
- dest="sudo",
- help="Run GDB with sudo",
- action="store_true",
- default=False,
- )
- debug.add_argument(
- "--gdb",
- help="Custom GDB executable (may be a command on PATH, or an absolute path)",
- type=str,
- default=None,
- )
- debug.add_argument(
- "--print-only",
- help="Prints the GDB command line into stderr as a Python array, does not execute GDB",
- action="store_true",
- default=False,
- )
- debug.add_argument(
- "proc_type",
- help="Optional, the type of process to debug. May be 'kresd', 'gc', or 'all'.",
- choices=["kresd", "gc", "all"],
- type=str,
- nargs="?",
- default="kresd",
- )
- return debug, DebugCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None: # noqa: C901, PLR0912, PLR0915
- if self.gdb is None:
- try:
- gdb_cmd = str(which.which("gdb"))
- except RuntimeError:
- print("Could not find 'gdb' in $PATH. Is GDB installed?", file=sys.stderr)
- sys.exit(1)
- elif "/" not in self.gdb:
- try:
- gdb_cmd = str(which.which(self.gdb))
- except RuntimeError:
- print(f"Could not find '{self.gdb}' in $PATH.", file=sys.stderr)
- sys.exit(1)
- else:
- gdb_cmd_path = Path(self.gdb).absolute()
- if not gdb_cmd_path.exists():
- print(f"Could not find '{self.gdb}'.", file=sys.stderr)
- sys.exit(1)
- gdb_cmd = str(gdb_cmd_path)
-
- response = request(args.socket, "GET", f"processes/{self.proc_type}")
- if response.status != 200:
- print(response, file=sys.stderr)
- sys.exit(1)
-
- procs = json.loads(response.body)
- if not isinstance(procs, PROCS_TYPE):
- print(
- f"Unexpected response type '{type(procs).__name__}' from manager. Expected '{PROCS_TYPE.__name__}'",
- file=sys.stderr,
- )
- sys.exit(1)
-        if len(procs) == 0:
-            print(
-                f"There are no processes of type '{self.proc_type}' available to debug",
-                file=sys.stderr,
-            )
-            sys.exit(1)
-
- exec_args = []
-
- # Put `sudo --` at the beginning of the command.
- if self.sudo:
- try:
- sudo_cmd = str(which.which("sudo"))
- except RuntimeError:
- print("Could not find 'sudo' in $PATH. Is sudo installed?", file=sys.stderr)
- sys.exit(1)
- exec_args.extend([sudo_cmd, "--"])
-
- # Attach GDB to processes - the processes are attached using the `add-inferior` and `attach` GDB
- # commands. This way, we can debug multiple processes.
- exec_args.extend([gdb_cmd, "--"])
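-        # "set detach-on-fork off" keeps forked children under GDB's control,
-        # "set schedule-multiple on" makes execution commands resume all attached inferiors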
- exec_args.extend(["-init-eval-command", "set detach-on-fork off"])
- exec_args.extend(["-init-eval-command", "set schedule-multiple on"])
- exec_args.extend(["-init-eval-command", f'attach {procs[0]["pid"]}'])
- inferior = 2
- for proc in procs[1:]:
- exec_args.extend(["-init-eval-command", "add-inferior"])
- exec_args.extend(["-init-eval-command", f"inferior {inferior}"])
- exec_args.extend(["-init-eval-command", f'attach {proc["pid"]}'])
- inferior += 1
-
- num_inferiors = inferior - 1
- if num_inferiors > 1:
- # Now we switch back to the first process and add additional provided GDB arguments.
- exec_args.extend(["-init-eval-command", "inferior 1"])
- exec_args.extend(
- [
- "-init-eval-command",
- "echo \\n\\nYou are now debugging multiple Knot Resolver processes. To switch between "
- "them, use the 'inferior <n>' command, where <n> is an integer from 1 to "
- f"{num_inferiors}.\\n\\n",
- ]
- )
- exec_args.extend(self.gdb_args)
-
- if self.print_only:
- print(f"{exec_args}")
- else:
- os.execl(*exec_args)
+++ /dev/null
-# noqa: INP001
-import argparse
-from typing import List, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-
-
-@register_command
-class HelpCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
-
- def run(self, args: CommandArgs) -> None:
- args.parser.print_help()
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- stop = subparser.add_parser("help", help="show this help message and exit")
- return stop, HelpCommand
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from typing import List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.utils.modeling.parsing import DataFormat, parse_json
-from knot_resolver.utils.requests import request
-
-
-@register_command
-class MetricsCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- self.file: Optional[str] = namespace.file
- self.prometheus: bool = namespace.prometheus
-
- super().__init__(namespace)
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- metrics = subparser.add_parser(
- "metrics",
- help="Get aggregated metrics from the running resolver"
- " in JSON format (default) or optionally in Prometheus format."
- "\nThe 'prometheus-client' Python package needs to be installed if you wish to use the Prometheus format."
- "\nRequires a connection to the management HTTP API.",
- )
-
- metrics.add_argument(
- "--prometheus",
- help="Get metrics in Prometheus format if dependencies are met in the resolver.",
- action="store_true",
- default=False,
- )
-
- metrics.add_argument(
- "file",
- help="Optional. The file into which metrics will be exported."
- "\nIf not specified, the metrics are printed into stdout.",
- nargs="?",
- default=None,
- )
- return metrics, MetricsCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- response = request(args.socket, "GET", "metrics/prometheus" if self.prometheus else "metrics/json")
-
- if response.status == 200:
- if self.prometheus:
- metrics = response.body
- else:
- metrics = DataFormat.JSON.dict_dump(parse_json(response.body), indent=4)
-
- if self.file:
- with open(self.file, "w") as f:
- f.write(metrics)
- else:
- print(metrics)
- else:
- print(response, file=sys.stderr)
- if self.prometheus and response.status == 404:
- print("Prometheus is unavailable due to missing optional dependencies", file=sys.stderr)
- sys.exit(1)
+++ /dev/null
-# noqa: INP001
-import argparse
-import copy
-import sys
-from typing import Any, Dict, List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.constants import VERSION
-from knot_resolver.utils.modeling.exceptions import DataParsingError
-from knot_resolver.utils.modeling.parsing import DataFormat, try_to_parse
-
-
-def _remove(config: Dict[str, Any], path: str) -> Optional[Any]:
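-    # "path" is a JSON-pointer-like string ("/a/b/c"); returns the removed value,
-    # or None if the path does not exist in the config dict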
- keys = path.split("/")
- last = keys[-1]
-
- current = config
- for key in keys[1:-1]:
- if key in current:
- current = current[key]
- else:
- return None
- if isinstance(current, dict) and last in current:
- val = copy.copy(current[last])
- del current[last]
- print(f"removed {path}")
- return val
- return None
-
-
-def _add(config: Dict[str, Any], path: str, val: Any, rewrite: bool = False) -> None:
- keys = path.split("/")
- last = keys[-1]
-
- current = config
- for key in keys[1:-1]:
- if key not in current or key in current and not isinstance(current[key], dict):
- current[key] = {}
- current = current[key]
-
- if rewrite or last not in current:
- current[last] = val
- print(f"added {path}")
-
-
-def _rename(config: Dict[str, Any], path: str, new_path: str) -> None:
- val: Optional[Any] = _remove(config, path)
- if val:
- _add(config, new_path, val)
-
-
-@register_command
-class MigrateCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.input_file: str = namespace.input_file
- self.output_file: Optional[str] = namespace.output_file
- self.output_format: DataFormat = namespace.output_format
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- migrate = subparser.add_parser("migrate", help="Migrates JSON or YAML configuration to the newer version.")
-
- migrate.set_defaults(output_format=DataFormat.YAML)
- output_formats = migrate.add_mutually_exclusive_group()
- output_formats.add_argument(
- "--json",
- help="Get migrated configuration data in JSON format.",
- const=DataFormat.JSON,
- action="store_const",
- dest="output_format",
- )
- output_formats.add_argument(
- "--yaml",
- help="Get migrated configuration data in YAML format, default.",
- const=DataFormat.YAML,
- action="store_const",
- dest="output_format",
- )
-
- migrate.add_argument(
- "input_file",
- type=str,
- help="File with configuration in YAML or JSON format.",
- )
- migrate.add_argument(
- "output_file",
- type=str,
- nargs="?",
- help="Optional, output file for migrated configuration in desired output format."
- " If not specified, migrated configuration is printed.",
- default=None,
- )
- return migrate, MigrateCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None: # noqa: C901, PLR0912, PLR0915
- with open(self.input_file, "r") as f:
- data = f.read()
-
- try:
- parsed = try_to_parse(data)
- except DataParsingError as e:
- print(e, file=sys.stderr)
- sys.exit(1)
-
- new = parsed.copy()
-
- # REMOVE
- _remove(new, "/dnssec/refresh-time")
- _remove(new, "/dnssec/hold-down-time")
- _remove(new, "/dnssec/time-skew-detection")
- _remove(new, "/dnssec/keep-removed")
- _remove(new, "/local-data/root-fallback-addresses")
- _remove(new, "/local-data/root-fallback-addresses-files")
- _remove(new, "/logging/debugging")
- _remove(new, "/max-workers")
- _remove(new, "/network/tls/auto-discovery")
- _remove(new, "/webmgmt")
-
- # RENAME/MOVE
- cache_key = "cache"
- if cache_key in new:
- gc_key = "garbage-collector"
- if gc_key in new[cache_key]:
- gc = new[cache_key][gc_key]
- if gc is False:
- _add(new, "/cache/garbage-collector/enable", False)
- else:
- _add(new, "/cache/garbage-collector/enable", True)
- prefetch_key = "prefetch"
- if prefetch_key in new[cache_key]:
- prediction_key = "prediction"
- if prediction_key in new[cache_key][prefetch_key]:
- prediction = new[cache_key][prefetch_key][prediction_key]
- if prediction is None:
- _add(new, "/cache/prefetch/prediction/enable", False)
- else:
- _add(new, "/cache/prefetch/prediction/enable", True)
- _rename(new, "/defer/enabled", "/defer/enable")
- dns64_key = "dns64"
- if dns64_key in new:
- if new[dns64_key] is False:
- _add(new, "/dns64/enable", False, rewrite=True)
- else:
- _add(new, "/dns64/enable", True, rewrite=True)
- _rename(new, "/dns64/rev-ttl", "/dns64/reverse-ttl")
- dnssec_key = "dnssec"
- if dnssec_key in new:
- if new[dnssec_key] is False:
- _add(new, "/dnssec/enable", False, rewrite=True)
- else:
-                # DNSSEC is enabled by default
- pass
- _rename(new, "/dnssec/trust-anchor-sentinel", "/dnssec/sentinel")
- _rename(new, "/dnssec/trust-anchor-signal-query", "/dnssec/signal-query")
- logging_key = "logging"
- if logging_key in new:
- dnstap_key = "dnstap"
- if dnstap_key in new[logging_key]:
- dnstap = new[logging_key][dnstap_key]
- if dnstap is None:
- _add(new, "/logging/dnstap/enable", False)
- else:
- _add(new, "/logging/dnstap/enable", True)
-
- _rename(new, "/logging/dnssec-bogus", "/dnssec/log-bogus")
- _rename(new, "/monitoring/enabled", "/monitoring/metrics")
- monitoring_key = "monitoring"
- if monitoring_key in new:
- graphite_key = "graphite"
- if graphite_key in new[monitoring_key]:
- graphite = new[monitoring_key][graphite_key]
- if graphite is False:
- _add(new, "/monitoring/graphite/enable", False)
- else:
- _add(new, "/monitoring/graphite/enable", True)
- network_key = "network"
- if network_key in new:
- proxy_protocol_key = "proxy-protocol"
- if proxy_protocol_key in new[network_key]:
- proxy_protocol = new[network_key][proxy_protocol_key]
- if proxy_protocol is None:
- _add(new, "/network/proxy-protocol/enable", False)
- else:
- _add(new, "/network/proxy-protocol/enable", True)
- _rename(new, "/network/tls/files-watchdog", "/network/tls/watchdog")
- rate_limiting_key = "rate-limiting"
- if rate_limiting_key in new:
- _add(new, "/rate-limiting/enable", True)
-
- # remove empty dicts
- new = {k: v for k, v in new.items() if v}
-
- dumped = self.output_format.dict_dump(new)
- if self.output_file:
- with open(self.output_file, "w") as f:
- f.write(dumped)
- else:
- print(f"\nNew migrated configuration (v{VERSION}):")
- print("---")
- print(dumped)
+++ /dev/null
-# noqa: INP001
-import argparse
-import json
-import sys
-from typing import Iterable, List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
-from knot_resolver.utils.requests import request
-
-PROCESSES_TYPE = Iterable
-
-
-@register_command
-class PidsCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- self.proc_type: Optional[str] = namespace.proc_type
-        self.json: bool = namespace.json
-
- super().__init__(namespace)
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- pids = subparser.add_parser("pids", help="List the PIDs of the Manager's subprocesses")
- pids.add_argument(
- "proc_type",
- help="Optional, the type of process to query. May be 'kresd', 'gc', or 'all' (default).",
- nargs="?",
- default="all",
- )
- pids.add_argument(
- "--json",
- help="Optional, makes the output more verbose, in JSON.",
- action="store_true",
- default=False,
- )
- return pids, PidsCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return {}
-
- def run(self, args: CommandArgs) -> None:
- response = request(args.socket, "GET", f"processes/{self.proc_type}")
-
- if response.status == 200:
- processes = json.loads(response.body)
- if isinstance(processes, PROCESSES_TYPE):
- if self.json:
- print(json.dumps(processes, indent=2))
- else:
- for p in processes:
- print(p["pid"])
-
- else:
- print(
- f"Unexpected response type '{type(processes).__name__}'"
- f" from manager. Expected '{PROCESSES_TYPE.__name__}'",
- file=sys.stderr,
- )
- sys.exit(1)
- else:
- print(response, file=sys.stderr)
- sys.exit(1)
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from typing import List, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
-from knot_resolver.utils.requests import request
-
-
-@register_command
-class ReloadCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.force: bool = namespace.force
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- reload = subparser.add_parser(
- "reload",
- help="Tells the resolver to reload YAML configuration file."
- " Old processes are replaced by new ones (with updated configuration) using rolling restarts."
- " So there will be no DNS service unavailability during reload operation.",
- )
- reload.add_argument(
- "--force",
- help="Force a reload, even if the configuration hasn't changed.",
- action="store_true",
- default=False,
- )
- return reload, ReloadCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return {}
-
- def run(self, args: CommandArgs) -> None:
- response = request(args.socket, "POST", "reload/force" if self.force else "reload")
-
- if response.status != 200:
- print(response, file=sys.stderr)
- sys.exit(1)
+++ /dev/null
-# noqa: INP001
-import argparse
-import json
-import sys
-from typing import List, Optional, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.datamodel import kres_config_json_schema
-from knot_resolver.utils.requests import request
-
-
-@register_command
-class SchemaCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.live: bool = namespace.live
- self.file: Optional[str] = namespace.file
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- schema = subparser.add_parser(
- "schema", help="Shows JSON-schema repersentation of the Knot Resolver's configuration."
- )
- schema.add_argument(
- "-l",
- "--live",
- help="Get configuration JSON-schema from the running resolver. Requires connection to the management API.",
- action="store_true",
- default=False,
- )
- schema.add_argument("file", help="Optional, file where to export JSON-schema.", nargs="?", default=None)
-
- return schema, SchemaCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- if self.live:
- response = request(args.socket, "GET", "schema")
- if response.status != 200:
- print(response, file=sys.stderr)
- sys.exit(1)
- schema = response.body
- else:
- schema = json.dumps(kres_config_json_schema(), indent=4)
-
- if self.file:
- with open(self.file, "w") as f:
- f.write(schema)
- else:
- print(schema)
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from typing import List, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, register_command
-from knot_resolver.utils.requests import request
-
-
-@register_command
-class StopCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- stop = subparser.add_parser(
- "stop", help="Tells the resolver to shutdown everthing. No process will run after this command."
- )
- return stop, StopCommand
-
- def run(self, args: CommandArgs) -> None:
- response = request(args.socket, "POST", "stop")
-
- if response.status != 200:
- print(response, file=sys.stderr)
- sys.exit(1)
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return {}
+++ /dev/null
-# noqa: INP001
-import argparse
-import sys
-from pathlib import Path
-from typing import Any, Dict, List, Tuple, Type
-
-from knot_resolver.client.command import Command, CommandArgs, CompWords, comp_get_words, register_command
-from knot_resolver.constants import CONFIG_FILE
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.datamodel.globals import Context, reset_global_validation_context, set_global_validation_context
-from knot_resolver.utils.modeling import try_to_parse
-from knot_resolver.utils.modeling.exceptions import DataParsingError, DataValidationError
-from knot_resolver.utils.modeling.parsing import data_combine
-
-
-@register_command
-class ValidateCommand(Command):
- def __init__(self, namespace: argparse.Namespace) -> None:
- super().__init__(namespace)
- self.input_file: str = namespace.input_file
- self.strict: bool = namespace.strict
-
- @staticmethod
- def register_args_subparser(
- subparser: "argparse._SubParsersAction[argparse.ArgumentParser]",
- ) -> Tuple[argparse.ArgumentParser, "Type[Command]"]:
- validate = subparser.add_parser("validate", help="Validates configuration in JSON or YAML format.")
- validate.set_defaults(strict=False)
- validate.add_argument(
- "--strict",
- help="Enable strict rules during validation, e.g. paths/files existence and permissions.",
- action="store_true",
- dest="strict",
- )
- validate.add_argument(
- "input_file",
- type=str,
- nargs="*",
- help="File or combination of files with the declarative configuration in YAML or JSON format.",
- default=[CONFIG_FILE],
- )
-
- return validate, ValidateCommand
-
- @staticmethod
- def completion(args: List[str], parser: argparse.ArgumentParser) -> CompWords:
- return comp_get_words(args, parser)
-
- def run(self, args: CommandArgs) -> None:
- data: Dict[str, Any] = {}
- try:
- for file in self.input_file:
- with open(file, "r") as f:
- raw = f.read()
- parsed = try_to_parse(raw)
- data = data_combine(data, parsed)
-
- set_global_validation_context(Context(Path(self.input_file[0]).parent, self.strict))
- KresConfig(data)
- reset_global_validation_context()
- except (DataParsingError, DataValidationError) as e:
- print(e, file=sys.stderr)
- sys.exit(1)
- if not self.strict:
- print(
- "Basic validation was successful."
- "\nIf you want more strict validation, you can use the '--strict' switch."
- "\nDuring strict validation, the existence and access rights of paths are also checked."
- "\n\nHowever, if you are using an additional file system permission control mechanism,"
- "\nsuch as access control lists (ACLs), this validation will likely fail."
- "\nThis is because the validation runs under a different user/group than the resolver itself"
- "\nand attempts to access the configured paths directly."
- )
+++ /dev/null
-import argparse
-import importlib
-import os
-import sys
-
-from knot_resolver.constants import VERSION
-
-from .client import KRES_CLIENT_NAME, KresClient
-from .command import install_commands_parsers
-
-
-def auto_import_commands() -> None:
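-    # import every module from the commands/ directory so that their
-    # @register_command decorators register the command classes they define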
- prefix = f"{'.'.join(__name__.split('.')[:-1])}.commands."
- for module_name in os.listdir(os.path.dirname(__file__) + "/commands"):
- if module_name[-3:] != ".py":
- continue
- importlib.import_module(f"{prefix}{module_name[:-3]}")
-
-
-def create_main_argument_parser() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser(
- KRES_CLIENT_NAME,
- description="Knot Resolver command-line utility that serves as a client for"
- " communicating with the Knot Resolver management API."
- " The utility also provides tools to work with the resolver's"
- " declarative configuration (validate, convert, ...).",
- )
- parser.add_argument(
- "-V",
- "--version",
- action="version",
- version=VERSION,
- help="Get version",
- )
- # parser.add_argument(
- # "-i",
- # "--interactive",
- # action="store_true",
- # help="Use the utility in interactive mode.",
- # default=False,
- # required=False,
- # )
- config_or_socket = parser.add_mutually_exclusive_group()
- config_or_socket.add_argument(
- "-s",
- "--socket",
- action="store",
- type=str,
- help="Optional, path to the resolver's management API, unix-domain socket, or network interface."
- " Cannot be used together with '--config'.",
- default=[],
- nargs=1,
- required=False,
- )
- config_or_socket.add_argument(
- "-c",
- "--config",
- action="store",
- type=str,
- help="Optional, path to the resolver's declarative configuration to retrieve the management API configuration."
- " Cannot be used together with '--socket'.",
- default=[],
- nargs=1,
- required=False,
- )
- return parser
-
-
-def main() -> None:
- auto_import_commands()
- parser = create_main_argument_parser()
- install_commands_parsers(parser)
-
- # TODO: This is broken with unpatched versions of poethepoet, because they drop the `--` pseudo-argument.
- # Patch submitted at <https://github.com/nat-n/poethepoet/pull/163>.
- try:
- pa_index = sys.argv.index("--", 1)
- argv_to_parse = sys.argv[1:pa_index]
- argv_extra = sys.argv[(pa_index + 1) :]
- except ValueError:
- argv_to_parse = sys.argv[1:]
- argv_extra = None
-
- namespace = parser.parse_args(argv_to_parse)
- if hasattr(namespace, "extra"):
- raise TypeError("'extra' is already an attribute - this is disallowed for commands")
- namespace.extra = argv_extra
-
- client = KresClient(namespace, parser)
- client.execute()
-
- # if namespace.interactive or len(vars(namespace)) == 2:
- # client.interactive()
- # else:
- # client.execute()
+++ /dev/null
-"""
-The module contains autodetection logic for available controllers.
-
-Because we have to catch import errors, the imports are located in functions which are invoked at the end of this file.
-During development we supported multiple subprocess controllers; by now everything has converged on supervisord alone.
-The interface remains, however, so that different controllers can be added in the future.
-"""
-
-import asyncio
-import logging
-from typing import List, Optional
-
-from knot_resolver.controller.interface import SubprocessController
-from knot_resolver.datamodel.config_schema import KresConfig
-
-logger = logging.getLogger(__name__)
-
-"""
-List of all available subprocess controllers, in order of priority.
-It is filled dynamically from modules that import successfully.
-"""
-_registered_controllers: List[SubprocessController] = []
-
-
-def try_supervisord() -> None:
- """Attempt to load supervisord controllers."""
- try:
- from knot_resolver.controller.supervisord import SupervisordSubprocessController
-
- _registered_controllers.append(SupervisordSubprocessController())
- except ImportError:
- logger.error("Failed to import modules related to supervisord service manager", exc_info=True)
-
-
-async def get_best_controller_implementation(config: KresConfig) -> SubprocessController:
- logger.info("Starting service manager auto-selection...")
-
- if len(_registered_controllers) == 0:
- logger.error("No controllers are available! Did you install all dependencies?")
- raise LookupError("No service managers available!")
-
- # check all controllers concurrently
- res = await asyncio.gather(*(cont.is_controller_available(config) for cont in _registered_controllers))
- logger.info(
- "Available subprocess controllers are %s",
- str(tuple((str(c) for r, c in zip(res, _registered_controllers) if r))),
- )
-
- # take the first one on the list which is available
- for avail, controller in zip(res, _registered_controllers):
- if avail:
- logger.info("Selected controller '%s'", str(controller))
- return controller
-
- # or fail
- raise LookupError("Can't find any available service manager!")
-
-
-def list_controller_names() -> List[str]:
- """
- Return a list of names of registered controllers.
-
-    The listed controllers are not necessarily functional.
- """
- return [str(controller) for controller in sorted(_registered_controllers, key=str)]
-
-
-async def get_controller_by_name(config: KresConfig, name: str) -> SubprocessController:
- logger.debug("Subprocess controller selected manualy by the user, testing feasibility...")
-
- controller: Optional[SubprocessController] = None
- for c in sorted(_registered_controllers, key=str):
- if str(c).startswith(name):
- if str(c) != name:
- logger.debug("Assuming '%s' is a shortcut for '%s'", name, str(c))
- controller = c
- break
-
- if controller is None:
- logger.error("Subprocess controller with name '%s' was not found", name)
- raise LookupError(f"No subprocess controller named '{name}' found")
-
- if await controller.is_controller_available(config):
- logger.info("Selected controller '%s'", str(controller))
- return controller
- raise LookupError("The selected subprocess controller is not available for use on this system.")
-
-
-# run the imports on module load
-try_supervisord()
+++ /dev/null
-from typing import List
-
-from knot_resolver import KresBaseError
-
-
-class KresSubprocessControllerError(KresBaseError):
- """Class for errors that are raised in the controller module."""
-
-
-class KresSubprocessControllerExec(Exception): # noqa: N818
- """
- Custom non-error exception that indicates the need for exec().
-
- Raised by the controller (supervisord) and caught by the controlled process (manager).
- The exception says that the process needs to perform a re-exec during startup.
- This ensures that the process runs under the controller (supervisord) in a process tree hierarchy.
- """
-
- def __init__(self, exec_args: List[str], *args: object) -> None:
- self.exec_args = exec_args
- super().__init__(*args)
+++ /dev/null
-import asyncio
-import itertools
-import json
-import logging
-import struct
-import sys
-from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
-from enum import Enum, auto
-from pathlib import Path
-from typing import Dict, Iterable, Optional, Type, TypeVar
-from weakref import WeakValueDictionary
-
-from knot_resolver.controller.exceptions import KresSubprocessControllerError
-from knot_resolver.controller.registered_workers import register_worker, unregister_worker
-from knot_resolver.datamodel.config_schema import KresConfig
-from knot_resolver.manager.constants import kresd_config_file, policy_loader_config_file
-
-logger = logging.getLogger(__name__)
-
-
-class SubprocessType(Enum):
- KRESD = auto()
- POLICY_LOADER = auto()
- GC = auto()
-
-
-class SubprocessStatus(Enum):
- RUNNING = auto()
- FATAL = auto()
- EXITED = auto()
- UNKNOWN = auto()
-
-
-T = TypeVar("T", bound="KresID")
-
-
-class KresID:
- """ID object used for identifying subprocesses."""
-
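-    # IDs are interned per subprocess type; the WeakValueDictionary lets a number be
-    # reused once the corresponding KresID object is no longer referenced anywhere.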
- _used: "Dict[SubprocessType, WeakValueDictionary[int, KresID]]" = {k: WeakValueDictionary() for k in SubprocessType}
-
- @classmethod
- def alloc(cls: Type[T], typ: SubprocessType) -> T:
- # find free ID closest to zero
- for i in itertools.count(start=0, step=1):
- if i not in cls._used[typ]:
- return cls.new(typ, i)
-
- raise RuntimeError("Reached an end of an infinite loop. How?")
-
- @classmethod
- def new(cls: "Type[T]", typ: SubprocessType, n: int) -> "T":
- if n in cls._used[typ]:
-            # Ignoring typing here, because I can't find a way to make the _used dict
-            # typed based on the subclass. I am not even sure it differs between subclasses;
-            # it is probably still the same dict. But we don't really care about that here.
- return cls._used[typ][n] # type: ignore[return-value]
- val = cls(typ, n, _i_know_what_i_am_doing=True)
- cls._used[typ][n] = val
- return val
-
- def __init__(self, typ: SubprocessType, n: int, _i_know_what_i_am_doing: bool = False) -> None:
- if not _i_know_what_i_am_doing:
- raise RuntimeError("Don't do this. You seem to have no idea what it does")
-
- self._id = n
- self._type = typ
-
- @property
- def subprocess_type(self) -> SubprocessType:
- return self._type
-
- def __repr__(self) -> str:
- return f"KresID({self})"
-
- def __hash__(self) -> int:
- return self._id
-
- def __eq__(self, o: object) -> bool:
- if isinstance(o, KresID):
- return self._type == o._type and self._id == o._id
- return False
-
- def __str__(self) -> str:
- """Return string representation of the ID usable directly in the underlying service supervisor."""
- raise NotImplementedError()
-
- @staticmethod
- def from_string(val: str) -> "KresID":
- """Inverse of __str__."""
- raise NotImplementedError()
-
- def __int__(self) -> int:
- return self._id
-
-
-class Subprocess(ABC):
- """One SubprocessInstance corresponds to one manager's subprocess."""
-
- def __init__(self, config: KresConfig, kresid: KresID) -> None:
- self._id = kresid
- self._config = config
- self._registered_worker: bool = False
- self._pid: Optional[int] = None
-
- self._config_file: Optional[Path] = None
- if self.type is SubprocessType.KRESD:
- self._config_file = kresd_config_file(self._config, self.id)
- elif self.type is SubprocessType.POLICY_LOADER:
- self._config_file = policy_loader_config_file(self._config)
-
- def _render_lua(self) -> Optional[str]:
- if self.type is SubprocessType.KRESD:
- return self._config.render_kresd_lua()
- if self.type is SubprocessType.POLICY_LOADER:
- return self._config.render_policy_loader_lua()
- return None
-
- def _write_config(self) -> None:
- config_lua = self._render_lua()
- if config_lua and self._config_file:
- with open(self._config_file, "w", encoding="utf8") as file:
- file.write(config_lua)
-
- def _unlink_config(self) -> None:
- if self._config_file:
- self._config_file.unlink(missing_ok=True)
-
- async def start(self, new_config: Optional[KresConfig] = None) -> None:
- if new_config:
- self._config = new_config
- self._write_config()
-
- try:
- await self._start()
- if self.type is SubprocessType.KRESD:
- register_worker(self)
- self._registered_worker = True
- except KresSubprocessControllerError:
- self._unlink_config()
- raise
-
- async def apply_new_config(self, new_config: KresConfig) -> None:
- self._config = new_config
-
- # update config file
- logger.debug(f"Writing config file for {self.id}")
- self._write_config()
-
- # update runtime status
- logger.debug(f"Restarting {self.id}")
- await self._restart()
-
- async def stop(self) -> None:
- if self._registered_worker:
- unregister_worker(self)
- await self._stop()
- await self.cleanup()
-
- async def cleanup(self) -> None:
- """
- Remove temporary files and all traces of this instance running.
-
- It is NOT SAFE to call this while the kresd is running,
- because it will break automatic restarts (at the very least).
- """
- self._unlink_config()
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, type(self)) and o.type == self.type and o.id == self.id
-
- def __hash__(self) -> int:
- return hash(type(self)) ^ hash(self.type) ^ hash(self.id)
-
- @abstractmethod
- async def _start(self) -> None:
- pass
-
- @abstractmethod
- async def _stop(self) -> None:
- pass
-
- @abstractmethod
- async def _restart(self) -> None:
- pass
-
- @abstractmethod
- async def get_pid(self) -> int:
- pass
-
- @abstractmethod
- def status(self) -> SubprocessStatus:
- pass
-
- @property
- def type(self) -> SubprocessType:
- return self.id.subprocess_type
-
- @property
- def id(self) -> KresID:
- return self._id
-
- async def command(self, cmd: str) -> object:
- if not self._registered_worker:
- raise RuntimeError("the command cannot be sent to a process other than the kresd worker")
-
- reader: asyncio.StreamReader
- writer: Optional[asyncio.StreamWriter] = None
-
- try:
- reader, writer = await asyncio.open_unix_connection(f"./control/{int(self.id)}")
-
- # drop prompt
- _ = await reader.read(2)
-
- # switch to JSON mode
- writer.write("__json\n".encode("utf8"))
-
- # write command
- writer.write(cmd.encode("utf8"))
- writer.write(b"\n")
- await writer.drain()
-
- # read result
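- # (in __json mode, kresd prefixes each reply with a 4-byte big-endian length followed by the payload)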
- (msg_len,) = struct.unpack(">I", await reader.read(4))
- result_bytes = await reader.readexactly(msg_len)
-
- try:
- return json.loads(result_bytes.decode("utf8"))
- except json.JSONDecodeError:
- return result_bytes.decode("utf8")
-
- finally:
- if writer is not None:
- writer.close()
-
- # proper closing of the socket is only implemented in later versions of python
- if sys.version_info >= (3, 7):
- await writer.wait_closed()
-
-
-class SubprocessController(ABC):
- """
- The common Subprocess Controller interface.
-
- This is what KresManager requires and what has to be implemented by all controllers.
- """
-
- @abstractmethod
- async def is_controller_available(self, config: KresConfig) -> bool:
- """Return bool, whether the controller is available with the given config."""
-
- @abstractmethod
- async def initialize_controller(self, config: KresConfig) -> None:
- """
- Initialize the Subprocess Controller.
-
- Should be called when we want to really start using the controller with a specific configuration.
- """
-
- @abstractmethod
- async def get_all_running_instances(self) -> Iterable[Subprocess]:
- """Must NOT be called before initialize_controller()."""
-
- @abstractmethod
- async def shutdown_controller(self) -> None:
- """
- Shut down the Subprocess Controller.
-
- Allows us to stop the service manager process or simply cleanup,
- so that we don't reuse the same resources in a new run.
-
- Must NOT be called before initialize_controller()
- """
-
- @abstractmethod
- async def create_subprocess(self, subprocess_config: KresConfig, subprocess_type: SubprocessType) -> Subprocess:
- """
- Return a Subprocess object which can be operated on.
-
- The subprocess is not started or in any way active after this call.
- That has to be performed manually using the returned object itself.
-
- Must NOT be called before initialize_controller()
- """
-
- @abstractmethod
- async def get_subprocess_status(self) -> Dict[KresID, SubprocessStatus]:
- """
- Get the status of running subprocesses as seen by the controller.
-
- This method actively polls for information.
-
- Must NOT be called before initialize_controller()
- """
+++ /dev/null
-import asyncio
-import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple
-
-from .exceptions import KresSubprocessControllerError
-
-if TYPE_CHECKING:
- from knot_resolver.controller.interface import KresID, Subprocess
-
-
-logger = logging.getLogger(__name__)
-
-
-_REGISTERED_WORKERS: "Dict[KresID, Subprocess]" = {}
-
-
-def get_registered_workers_kresids() -> "List[KresID]":
- return list(_REGISTERED_WORKERS.keys())
-
-
-async def command_single_registered_worker(cmd: str) -> "Tuple[KresID, object]":
- for sub in _REGISTERED_WORKERS.values():
- return sub.id, await sub.command(cmd)
- raise KresSubprocessControllerError(
- "Unable to execute the command. There is no kresd worker running to execute the command."
- "Try start/restart the resolver.",
- )
-
-
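-# Illustrative usage sketch (the quoted Lua expression below is a placeholder, not a documented kresd command):
-#   results = await command_registered_workers("<some lua expression>")  # one entry per registered kresd worker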
-async def command_registered_workers(cmd: str) -> "Dict[KresID, object]":
- async def single_pair(sub: "Subprocess") -> "Tuple[KresID, object]":
- return sub.id, await sub.command(cmd)
-
- pairs = await asyncio.gather(*(single_pair(inst) for inst in _REGISTERED_WORKERS.values()))
- return dict(pairs)
-
-
-def unregister_worker(subprocess: "Subprocess") -> None:
- """Unregister kresd worker "Subprocess" from the list."""
- del _REGISTERED_WORKERS[subprocess.id]
-
-
-def register_worker(subprocess: "Subprocess") -> None:
- """Register kresd worker "Subprocess" on the list."""
- _REGISTERED_WORKERS[subprocess.id] = subprocess
+++ /dev/null
-import logging
-from os import getppid, kill # pylint: disable=[no-name-in-module]
-from pathlib import Path
-from typing import Any, Dict, Iterable, NoReturn, Optional, Union, cast
-from xmlrpc.client import Fault, ServerProxy
-
-import supervisor.xmlrpc # type: ignore[import]
-
-from knot_resolver.controller.exceptions import KresSubprocessControllerError, KresSubprocessControllerExec
-from knot_resolver.controller.interface import (
- KresID,
- Subprocess,
- SubprocessController,
- SubprocessStatus,
- SubprocessType,
-)
-from knot_resolver.controller.supervisord.config_file import SupervisordKresID, write_config_file
-from knot_resolver.datamodel.config_schema import KresConfig, workers_max_count
-from knot_resolver.manager.constants import supervisord_config_file, supervisord_pid_file, supervisord_sock_file
-from knot_resolver.utils import which
-from knot_resolver.utils.async_utils import call, readfile
-from knot_resolver.utils.compat.asyncio import async_in_a_thread
-
-logger = logging.getLogger(__name__)
-
-
-async def _start_supervisord(config: KresConfig) -> None:
- logger.debug("Writing supervisord config")
- await write_config_file(config)
- logger.debug("Starting supervisord")
- res = await call(["supervisord", "--configuration", str(supervisord_config_file(config).absolute())])
- if res != 0:
- raise KresSubprocessControllerError(f"Supervisord exited with exit code {res}")
-
-
-async def _exec_supervisord(config: KresConfig) -> NoReturn:
- logger.debug("Writing supervisord config")
- await write_config_file(config)
- logger.debug("Execing supervisord")
- raise KresSubprocessControllerExec(
- [
- str(which.which("supervisord")),
- "supervisord",
- "--configuration",
- str(supervisord_config_file(config).absolute()),
- ]
- )
-
-
-async def _reload_supervisord(config: KresConfig) -> None:
- await write_config_file(config)
- try:
- supervisord = _create_supervisord_proxy(config)
- supervisord.reloadConfig()
- except Fault as e:
- raise KresSubprocessControllerError(f"supervisord reload failed: {e}") from e
-
-
-@async_in_a_thread
-def _stop_supervisord(config: KresConfig) -> None:
- supervisord = _create_supervisord_proxy(config)
- # pid = supervisord.getPID()
- try:
- # We might be trying to shut down supervisord at the very moment it is waiting
- # for us to stop. In that case, this shutdown request might fail, and that's
- # not a problem.
- supervisord.shutdown()
- except Fault as e:
- if e.faultCode == 6 and e.faultString == "SHUTDOWN_STATE":
- # supervisord is already stopping, so it's fine
- pass
- else:
- # something wrong happened, let's be loud about it
- raise
-
- # It is always better to clean up.
- # This way, we can be sure that we are starting with a newly generated configuration.
- supervisord_config_file(config).unlink()
-
-
-async def _is_supervisord_available() -> bool:
- # yes, it is! The code in this file wouldn't be running without it due to imports :)
-
- # so let's just check that we can find supervisord and supervisorctl binaries
- try:
- which.which("supervisord")
- which.which("supervisorctl")
- except RuntimeError:
- logger.error("Failed to find supervisord or supervisorctl executables in $PATH")
- return False
-
- return True
-
-
-async def _get_supervisord_pid(config: KresConfig) -> Optional[int]:
- if not Path(supervisord_pid_file(config)).exists():
- return None
-
- return int(await readfile(supervisord_pid_file(config)))
-
-
-def _is_process_running(pid: int) -> bool:
- try:
- # kill with signal 0 is a safe way to test that a process exists
- kill(pid, 0)
- except ProcessLookupError:
- return False
- else:
- return True
-
-
-async def _is_supervisord_running(config: KresConfig) -> bool:
- pid = await _get_supervisord_pid(config)
- if pid is None:
- return False
- if not _is_process_running(pid) or getppid() != pid:
- supervisord_pid_file(config).unlink()
- return False
- return True
-
-
-def _create_proxy(config: KresConfig) -> ServerProxy:
- return ServerProxy(
- "http://127.0.0.1",
- transport=supervisor.xmlrpc.SupervisorTransport(
- None, None, serverurl="unix://" + str(supervisord_sock_file(config))
- ),
- )
-
-
-def _create_supervisord_proxy(config: KresConfig) -> Any:
- proxy = _create_proxy(config)
- return getattr(proxy, "supervisor")
-
-
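-# The "fast" namespace is served by our fast_rpcinterface plugin (registered as
-# [rpcinterface:fast] in supervisord.conf.j2); its startProcess polls for completion without delay.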
-def _create_fast_proxy(config: KresConfig) -> Any:
- proxy = _create_proxy(config)
- return getattr(proxy, "fast")
-
-
-def _convert_subprocess_status(proc: Any) -> SubprocessStatus:
- conversion_tbl = {
- # "STOPPED": None, # filtered out elsewhere
- "STARTING": SubprocessStatus.RUNNING,
- "RUNNING": SubprocessStatus.RUNNING,
- "BACKOFF": SubprocessStatus.RUNNING,
- "STOPPING": SubprocessStatus.RUNNING,
- "EXITED": SubprocessStatus.EXITED,
- "FATAL": SubprocessStatus.FATAL,
- "UNKNOWN": SubprocessStatus.UNKNOWN,
- }
-
- if proc["statename"] in conversion_tbl:
- status = conversion_tbl[proc["statename"]]
- else:
- logger.warning(f"Unknown supervisord process state {proc['statename']}")
- status = SubprocessStatus.UNKNOWN
- return status
-
-
-def _list_running_subprocesses(config: KresConfig) -> Dict[SupervisordKresID, SubprocessStatus]:
- try:
- supervisord = _create_supervisord_proxy(config)
- processes: Any = supervisord.getAllProcessInfo()
- except Fault as e:
- raise KresSubprocessControllerError(f"failed to get info from all running processes: {e}") from e
-
- # there will be a manager process as well, but we don't want to report anything on ourselves
- processes = [pr for pr in processes if pr["name"] != "manager"]
-
- # convert all the names
- return {
- SupervisordKresID.from_string(f"{pr['group']}:{pr['name']}"): _convert_subprocess_status(pr)
- for pr in processes
- if pr["statename"] != "STOPPED"
- }
-
-
-class SupervisordSubprocess(Subprocess):
- def __init__(
- self,
- config: KresConfig,
- controller: "SupervisordSubprocessController",
- base_id: Union[SubprocessType, SupervisordKresID],
- ) -> None:
- if isinstance(base_id, SubprocessType):
- super().__init__(config, SupervisordKresID.alloc(base_id))
- else:
- super().__init__(config, base_id)
- self._controller: "SupervisordSubprocessController" = controller
-
- @property
- def name(self) -> str:
- return str(self.id)
-
- def status(self) -> SubprocessStatus:
- try:
- supervisord = _create_supervisord_proxy(self._config)
- status = supervisord.getProcessInfo(self.name)
- except Fault as e:
- raise KresSubprocessControllerError(f"failed to get status from '{self.id}' process: {e}") from e
- return _convert_subprocess_status(status)
-
- @async_in_a_thread
- def _start(self) -> None:
- # +1 for canary process (same as in config_file.py)
- assert int(self.id) <= int(workers_max_count()) + 1, "trying to spawn more than allowed limit of workers"
- try:
- supervisord = _create_fast_proxy(self._config)
- supervisord.startProcess(self.name)
- except Fault as e:
- raise KresSubprocessControllerError(f"failed to start '{self.id}'") from e
-
- @async_in_a_thread
- def _stop(self) -> None:
- supervisord = _create_supervisord_proxy(self._config)
- supervisord.stopProcess(self.name)
-
- @async_in_a_thread
- def _restart(self) -> None:
- supervisord = _create_supervisord_proxy(self._config)
- supervisord.stopProcess(self.name)
- fast = _create_fast_proxy(self._config)
- fast.startProcess(self.name)
-
- @async_in_a_thread
- def get_pid(self) -> int:
- if self._pid is None:
- supervisord = _create_supervisord_proxy(self._config)
- info = supervisord.getProcessInfo(self.name)
- self._pid = info["pid"]
- return self._pid
-
- def get_used_config(self) -> KresConfig:
- return self._config
-
-
-class SupervisordSubprocessController(SubprocessController):
- def __init__(self) -> None: # pylint: disable=super-init-not-called
- self._controller_config: Optional[KresConfig] = None
-
- def __str__(self) -> str:
- return "supervisord"
-
- async def is_controller_available(self, config: KresConfig) -> bool:
- res = await _is_supervisord_available()
- if not res:
- logger.info("Failed to find usable supervisord.")
- else:
- logger.debug("Detection - supervisord controller is available for use")
- return res
-
- async def get_all_running_instances(self) -> Iterable[Subprocess]:
- assert self._controller_config is not None
-
- if await _is_supervisord_running(self._controller_config):
- states = _list_running_subprocesses(self._controller_config)
- return [
- SupervisordSubprocess(self._controller_config, self, id_)
- for id_ in states
- if states[id_] == SubprocessStatus.RUNNING
- ]
- return []
-
- async def initialize_controller(self, config: KresConfig) -> None:
- self._controller_config = config
-
- if not await _is_supervisord_running(config):
- logger.info(
- "We want supervisord to restart us when needed, we will therefore exec() it and let it start us again."
- )
- await _exec_supervisord(config)
- else:
- logger.info("Supervisord is already running, we will just update its config...")
- await _reload_supervisord(config)
-
- async def shutdown_controller(self) -> None:
- assert self._controller_config is not None
- await _stop_supervisord(self._controller_config)
-
- async def create_subprocess(self, subprocess_config: KresConfig, subprocess_type: SubprocessType) -> Subprocess:
- return SupervisordSubprocess(subprocess_config, self, subprocess_type)
-
- @async_in_a_thread
- def get_subprocess_status(self) -> Dict[KresID, SubprocessStatus]:
- assert self._controller_config is not None
- return cast(Dict[KresID, SubprocessStatus], _list_running_subprocesses(self._controller_config))
+++ /dev/null
-import logging
-import os
-import sys
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Literal
-
-from jinja2 import Template
-
-from knot_resolver.constants import KRES_CACHE_GC_EXECUTABLE, KRESD_EXECUTABLE, LINUX_SYS, NOTIFY_SUPPORT
-from knot_resolver.controller.interface import KresID, SubprocessType
-from knot_resolver.datamodel.config_schema import KresConfig, workers_max_count
-from knot_resolver.datamodel.logging_schema import LogTargetEnum
-from knot_resolver.manager.constants import (
- kres_cache_dir,
- kresd_config_file_supervisord_pattern,
- policy_loader_config_file,
- supervisord_config_file,
- supervisord_config_file_tmp,
- supervisord_pid_file,
- supervisord_sock_file,
- supervisord_subprocess_log_dir,
- user_constants,
-)
-from knot_resolver.utils.async_utils import read_resource, writefile
-
-logger = logging.getLogger(__name__)
-
-
-class SupervisordKresID(KresID):
- # WARNING: be really careful with renaming. If the naming schema changes,
- # we should still be able to parse the old one as well, otherwise updating the
- # manager will cause weird behavior.
-
- @staticmethod
- def from_string(val: str) -> "SupervisordKresID":
- # the double name is checked because that's how we read it from supervisord
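- # e.g. "cache-gc:cache-gc" -> GC, "policy-loader" -> POLICY_LOADER, "kresd:kresd3" -> KRESD instance 3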
- if val in ("cache-gc", "cache-gc:cache-gc"):
- return SupervisordKresID.new(SubprocessType.GC, 0)
- if val in ("policy-loader", "policy-loader:policy-loader"):
- return SupervisordKresID.new(SubprocessType.POLICY_LOADER, 0)
- val = val.replace("kresd:kresd", "")
- return SupervisordKresID.new(SubprocessType.KRESD, int(val))
-
- def __str__(self) -> str:
- if self.subprocess_type is SubprocessType.GC:
- return "cache-gc"
- if self.subprocess_type is SubprocessType.POLICY_LOADER:
- return "policy-loader"
- if self.subprocess_type is SubprocessType.KRESD:
- return f"kresd:kresd{self._id}"
- raise RuntimeError(f"Unexpected subprocess type {self.subprocess_type}")
-
-
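-# Render the kres-cache-gc command-line options from the config. With the GarbageCollectorSchema
-# defaults this comes out roughly as " -d 1000 -u 80 -f 10 -l 100 -L 200 -t 0 -m 0 -w 0"
-# (an illustrative rendering, assuming the unit conversions used below).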
-def kres_cache_gc_args(config: KresConfig) -> str:
- args = ""
-
- if config.logging.level == "debug" or (config.logging.groups and "cache-gc" in config.logging.groups):
- args += " -v"
-
- gc_config = config.cache.garbage_collector
- args += (
- f" -d {gc_config.interval.millis()}"
- f" -u {gc_config.threshold}"
- f" -f {gc_config.release}"
- f" -l {gc_config.rw_deletes}"
- f" -L {gc_config.rw_reads}"
- f" -t {gc_config.temp_keys_space.mbytes()}"
- f" -m {gc_config.rw_duration.micros()}"
- f" -w {gc_config.rw_delay.micros()}"
- )
- if gc_config.dry_run:
- args += " -n"
- return args
-
-
-@dataclass
-class ProcessTypeConfig:
- """Data structure holding data for supervisord config template."""
-
- logfile: Path
- workdir: str
- command: str
- startsecs: int
- environment: str
- max_procs: int = 1
-
- @staticmethod
- def create_gc_config(config: KresConfig) -> "ProcessTypeConfig":
- cwd = str(os.getcwd())
- return ProcessTypeConfig( # type: ignore[call-arg]
- logfile=supervisord_subprocess_log_dir(config) / "gc.log",
- workdir=cwd,
- command=f"{KRES_CACHE_GC_EXECUTABLE} -c {kres_cache_dir(config)}{kres_cache_gc_args(config)}",
- startsecs=0,
- environment="",
- )
-
- @staticmethod
- def create_policy_loader_config(config: KresConfig) -> "ProcessTypeConfig":
- cwd = str(os.getcwd())
- return ProcessTypeConfig( # type: ignore[call-arg]
- logfile=supervisord_subprocess_log_dir(config) / "policy-loader.log",
- workdir=cwd,
- command=f"{KRESD_EXECUTABLE} -c {(policy_loader_config_file(config))} -c - -n",
- startsecs=0,
- environment="",
- )
-
- @staticmethod
- def create_kresd_config(config: KresConfig) -> "ProcessTypeConfig":
- cwd = str(os.getcwd())
- environment = 'SYSTEMD_INSTANCE="%(process_num)d"'
-
- # Default for non-Linux systems without support for the systemd NOTIFY message:
- # we need to give the kresd workers a few seconds to start properly.
- startsecs = 3
-
- if NOTIFY_SUPPORT:
- # There is support for systemd NOTIFY message.
- # Here, 'startsecs' serves as a timeout for waiting for NOTIFY message.
- startsecs = 60
- environment += ",X-SUPERVISORD-TYPE=notify"
-
- return ProcessTypeConfig( # type: ignore[call-arg]
- logfile=supervisord_subprocess_log_dir(config) / "kresd%(process_num)d.log",
- workdir=cwd,
- command=f"{KRESD_EXECUTABLE} -c {kresd_config_file_supervisord_pattern(config)} -n",
- startsecs=startsecs,
- environment=environment,
- max_procs=int(workers_max_count()) + 1, # +1 for the canary process
- )
-
- @staticmethod
- def create_manager_config(_config: KresConfig) -> "ProcessTypeConfig":
- if LINUX_SYS:
- # read original command from /proc
- with open("/proc/self/cmdline", "rb") as f:
- args = [s.decode("utf-8") for s in f.read()[:-1].split(b"\0")]
- else:
- # other systems
- args = [sys.executable] + sys.argv
-
- # insert debugger when asked
- if os.environ.get("KRES_DEBUG_MANAGER"):
- logger.warning("Injecting debugger into the supervisord config")
- # the args array looks like this:
- # [PYTHON_PATH, "-m", "knot_resolver", ...]
- args = args[:1] + ["-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client"] + args[2:]
-
- cmd = '"' + '" "'.join(args) + '"'
- environment = "KRES_SUPRESS_LOG_PREFIX=true"
- if NOTIFY_SUPPORT:
- environment += ",X-SUPERVISORD-TYPE=notify"
-
- return ProcessTypeConfig( # type: ignore[call-arg]
- workdir=user_constants().working_directory_on_startup,
- command=cmd,
- startsecs=600 if NOTIFY_SUPPORT else 0,
- environment=environment,
- logfile=Path(""), # this will be ignored
- )
-
-
-@dataclass
-class SupervisordConfig:
- unix_http_server: Path
- pid_file: Path
- workdir: str
- logfile: Path
- loglevel: Literal["critical", "error", "warn", "info", "debug", "trace", "blather"]
- target: LogTargetEnum
- notify_support: bool
-
- @staticmethod
- def create(config: KresConfig) -> "SupervisordConfig":
- # determine the correct logging level
- if config.logging.groups and "supervisord" in config.logging.groups:
- loglevel = "info"
- else:
- loglevel = {
- "crit": "critical",
- "err": "error",
- "warning": "warn",
- "notice": "warn",
- "info": "info",
- "debug": "debug",
- }[config.logging.level]
- cwd = str(os.getcwd())
- return SupervisordConfig( # type: ignore[call-arg]
- unix_http_server=supervisord_sock_file(config),
- pid_file=supervisord_pid_file(config),
- workdir=cwd,
- logfile=Path("syslog" if config.logging.target == "syslog" else "/dev/null"),
- loglevel=loglevel, # type: ignore[arg-type]
- target=config.logging.target,
- notify_support=NOTIFY_SUPPORT,
- )
-
-
-async def write_config_file(config: KresConfig) -> None:
- if not supervisord_subprocess_log_dir(config).exists():
- supervisord_subprocess_log_dir(config).mkdir(exist_ok=True)
-
- template = await read_resource(__package__, "supervisord.conf.j2")
- assert template is not None
- template = template.decode("utf8")
- config_string = Template(template).render(
- gc=ProcessTypeConfig.create_gc_config(config),
- loader=ProcessTypeConfig.create_policy_loader_config(config),
- kresd=ProcessTypeConfig.create_kresd_config(config),
- manager=ProcessTypeConfig.create_manager_config(config),
- config=SupervisordConfig.create(config),
- )
- await writefile(supervisord_config_file_tmp(config), config_string)
- # atomically replace (we don't technically need this right now, but better safe than sorry)
- os.rename(supervisord_config_file_tmp(config), supervisord_config_file(config))
+++ /dev/null
-# type: ignore
-# pylint: skip-file
-
-"""
-This file is modified version of supervisord's source code:
-https://github.com/Supervisor/supervisor/blob/5d9c39619e2e7e7fca33c890cb2a9f2d3d0ab762/supervisor/rpcinterface.py
-
-The changes made are:
-
- - removed everything that we do not need, reformatted to fit our code style (2022-06-24)
- - made startProcess faster by setting delay to 0 (2022-06-24)
-
-
-The original supervisord licence follows:
---------------------------------------------------------------------
-
-Supervisor is licensed under the following license:
-
- A copyright notice accompanies this license document that identifies
- the copyright holders.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions are
- met:
-
- 1. Redistributions in source code must retain the accompanying
- copyright notice, this list of conditions, and the following
- disclaimer.
-
- 2. Redistributions in binary form must reproduce the accompanying
- copyright notice, this list of conditions, and the following
- disclaimer in the documentation and/or other materials provided
- with the distribution.
-
- 3. Names of the copyright holders must not be used to endorse or
- promote products derived from this software without prior
- written permission from the copyright holders.
-
- 4. If any files are modified, you must cause the modified files to
- carry prominent notices stating that you changed the files and
- the date of any change.
-
- Disclaimer
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND
- ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
- TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
- PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
- EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
- TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
- ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
- TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
- THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
- SUCH DAMAGE.
-"""
-
-from supervisor.http import NOT_DONE_YET
-from supervisor.options import BadCommand, NoPermission, NotExecutable, NotFound, split_namespec
-from supervisor.states import RUNNING_STATES, ProcessStates, SupervisorStates
-from supervisor.xmlrpc import Faults, RPCError
-
-
-class SupervisorNamespaceRPCInterface:
- def __init__(self, supervisord):
- self.supervisord = supervisord
-
- def _update(self, text):
- self.update_text = text # for unit tests, mainly
- if isinstance(self.supervisord.options.mood, int) and self.supervisord.options.mood < SupervisorStates.RUNNING:
- raise RPCError(Faults.SHUTDOWN_STATE)
-
- # RPC API methods
-
- def _getGroupAndProcess(self, name): # noqa: N802
- # get process to start from name
- group_name, process_name = split_namespec(name)
-
- group = self.supervisord.process_groups.get(group_name)
- if group is None:
- raise RPCError(Faults.BAD_NAME, name)
-
- if process_name is None:
- return group, None
-
- process = group.processes.get(process_name)
- if process is None:
- raise RPCError(Faults.BAD_NAME, name)
-
- return group, process
-
- def startProcess(self, name, wait=True): # noqa: N802
- """Start a process
-
- @param string name Process name (or ``group:name``, or ``group:*``)
- @param boolean wait Wait for process to be fully started
- @return boolean result Always true unless error
-
- """
- self._update("startProcess")
- group, process = self._getGroupAndProcess(name)
- if process is None:
- group_name, process_name = split_namespec(name)
- return self.startProcessGroup(group_name, wait)
-
- # test filespec, don't bother trying to spawn if we know it will
- # eventually fail
- try:
- filename, argv = process.get_execv_args()
- except NotFound as e:
- raise RPCError(Faults.NO_FILE, e.args[0]) from e
- except (BadCommand, NotExecutable, NoPermission) as why:
- raise RPCError(Faults.NOT_EXECUTABLE, why.args[0]) from why
-
- if process.get_state() in RUNNING_STATES:
- raise RPCError(Faults.ALREADY_STARTED, name)
-
- if process.get_state() == ProcessStates.UNKNOWN:
- raise RPCError(Faults.FAILED, "%s is in an unknown process state" % name)
-
- process.spawn()
-
- # We call reap() in order to more quickly obtain the side effects of
- # process.finish(), which reap() eventually ends up calling. This
- # might be the case if the spawn() was successful but then the process
- # died before its startsecs elapsed or it exited with an unexpected
- # exit code. In particular, finish() may set spawnerr, which we can
- # check and immediately raise an RPCError, avoiding the need to
- # defer by returning a callback.
-
- self.supervisord.reap()
-
- if process.spawnerr:
- raise RPCError(Faults.SPAWN_ERROR, name)
-
- # We call process.transition() in order to more quickly obtain its
- # side effects. In particular, it might set the process' state from
- # STARTING->RUNNING if the process has a startsecs==0.
- process.transition()
-
- if wait and process.get_state() != ProcessStates.RUNNING:
- # by default, this branch will almost always be hit for processes
- # with default startsecs configurations, because the default number
- # of startsecs for a process is "1", and the process will not have
- # entered the RUNNING state yet even though we've called
- # transition() on it. This is because a process is not considered
- # RUNNING until it has stayed up > startsecs.
-
- def onwait():
- if process.spawnerr:
- raise RPCError(Faults.SPAWN_ERROR, name)
-
- state = process.get_state()
-
- if state not in (ProcessStates.STARTING, ProcessStates.RUNNING):
- raise RPCError(Faults.ABNORMAL_TERMINATION, name)
-
- if state == ProcessStates.RUNNING:
- return True
-
- return NOT_DONE_YET
-
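- # the only functional change from upstream: delay = 0 makes supervisord re-check
- # this deferred callback without waiting, so startProcess returns as soon as possible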
- onwait.delay = 0
- onwait.rpcinterface = self
- return onwait # deferred
-
- return True
-
-
-# this is not used in code but referenced via an entry point in the conf file
-def make_main_rpcinterface(supervisord):
- return SupervisorNamespaceRPCInterface(supervisord)
+++ /dev/null
-# type: ignore
-# pylint: disable=protected-access
-import atexit
-import os
-import signal
-from typing import Any, Optional
-
-from supervisor.compat import as_string
-from supervisor.events import ProcessStateFatalEvent, ProcessStateRunningEvent, ProcessStateStartingEvent, subscribe
-from supervisor.options import ServerOptions
-from supervisor.process import Subprocess
-from supervisor.states import SupervisorStates
-from supervisor.supervisord import Supervisor
-
-from knot_resolver.utils.systemd_notify import systemd_notify
-
-superd: Optional[Supervisor] = None
-
-
-def check_for_fatal_manager(event: ProcessStateFatalEvent) -> None:
- assert superd is not None
-
- proc: Subprocess = event.process
- processname = as_string(proc.config.name)
- if processname == "manager":
- # stop the whole supervisord gracefully
- superd.options.logger.critical("manager process entered FATAL state! Shutting down")
- superd.options.mood = SupervisorStates.SHUTDOWN
-
- # force the interpreter to exit with exit code 1
- atexit.register(lambda: os._exit(1))
-
-
-def check_for_starting_manager(event: ProcessStateStartingEvent) -> None:
- assert superd is not None
-
- proc: Subprocess = event.process
- processname = as_string(proc.config.name)
- if processname == "manager":
- # manager is starting, report it upstream
- systemd_notify(STATUS="Starting services...")
-
-
-def check_for_running_manager(event: ProcessStateRunningEvent) -> None:
- assert superd is not None
-
- proc: Subprocess = event.process
- processname = as_string(proc.config.name)
- if processname == "manager":
- # manager has successfully started, report it upstream
- systemd_notify(READY="1", STATUS="Ready")
-
-
-def get_server_options_signal(self):
- sig = self.signal_receiver.get_signal()
- if sig == signal.SIGHUP and superd is not None:
- superd.options.logger.info("received SIGHUP, forwarding to the process 'manager'")
- manager_pid = superd.process_groups["manager"].processes["manager"].pid
- os.kill(manager_pid, signal.SIGHUP)
- return None
-
- return sig
-
-
-def inject(supervisord: Supervisor, **_config: Any) -> Any: # pylint: disable=useless-return
- global superd
- superd = supervisord
-
- # This status notification here unsets the env variable $NOTIFY_SOCKET provided by systemd
- # and stores it locally. Therefore, it shouldn't clash with the $NOTIFY_SOCKET we are providing
- # downstream.
- systemd_notify(STATUS="Initializing supervisord...")
-
- # register events
- subscribe(ProcessStateFatalEvent, check_for_fatal_manager)
- subscribe(ProcessStateStartingEvent, check_for_starting_manager)
- subscribe(ProcessStateRunningEvent, check_for_running_manager)
-
- # forward SIGHUP to manager
- ServerOptions.get_signal = get_server_options_signal
-
- # this method is called by supervisord when loading the plugin;
- # it should return an XML-RPC object, which we don't care about.
- # That's why we are returning just None
- return None
+++ /dev/null
-#define PY_SSIZE_T_CLEAN
-#include <Python.h>
-
-#include <stdbool.h>
-#include <stdint.h>
-#include <stdio.h>
-#include <unistd.h>
-#include <sys/types.h>
-#include <stdlib.h>
-#include <string.h>
-#include <errno.h>
-#include <sys/socket.h>
-#include <fcntl.h>
-#include <stddef.h>
-#include <sys/socket.h>
-#include <sys/un.h>
-
-#define CONTROL_SOCKET_NAME "supervisor-notify-socket"
-#define NOTIFY_SOCKET_NAME "NOTIFY_SOCKET"
-#define MODULE_NAME "notify"
-#define RECEIVE_BUFFER_SIZE 2048
-
-#if __linux__
-
-static PyObject *NotifySocketError;
-
-static PyObject *init_control_socket(PyObject *self, PyObject *args)
-{
- /* create socket */
- int controlfd = socket(AF_UNIX, SOCK_DGRAM | SOCK_NONBLOCK, 0);
- if (controlfd == -1) goto fail_errno;
-
- /* construct the address; sd_notify() requires that the path is absolute */
- struct sockaddr_un server_addr = {0};
- server_addr.sun_family = AF_UNIX;
- const size_t cwd_max = sizeof(server_addr) - offsetof(struct sockaddr_un, sun_path)
- /* but we also need space for making the path longer: */
- - 1/*slash*/ - strlen(CONTROL_SOCKET_NAME);
- if (!getcwd(server_addr.sun_path, cwd_max))
- goto fail_errno;
- char *p = server_addr.sun_path + strlen(server_addr.sun_path);
- *p = '/';
- strcpy(p + 1, CONTROL_SOCKET_NAME);
-
- /* overwrite the (pseudo-)file if it exists */
- (void)unlink(CONTROL_SOCKET_NAME);
- int res = bind(controlfd, (struct sockaddr *)&server_addr, sizeof(server_addr));
- if (res < 0) goto fail_errno;
-
- /* make sure that we get credentials with messages */
- int data = (int)true;
- res = setsockopt(controlfd, SOL_SOCKET, SO_PASSCRED, &data, sizeof(data));
- if (res < 0) goto fail_errno;
- /* store the name of the socket in env to fake systemd */
- char *old_value = getenv(NOTIFY_SOCKET_NAME);
- if (old_value != NULL) {
- printf("[notify_socket] warning, running under systemd and overwriting $%s\n",
- NOTIFY_SOCKET_NAME);
- // fixme
- }
-
- res = setenv(NOTIFY_SOCKET_NAME, server_addr.sun_path, 1);
- if (res < 0) goto fail_errno;
-
- return PyLong_FromLong((long)controlfd);
-fail_errno:
- PyErr_SetFromErrno(NotifySocketError);
- return NULL;
-}
-
-static PyObject *handle_control_socket_connection_event(PyObject *self,
- PyObject *args)
-{
- int controlfd; /* "i" in PyArg_ParseTuple below fills an int */
- if (!PyArg_ParseTuple(args, "i", &controlfd))
- return NULL;
-
- /* read command assuming it fits and it was sent all at once */
- // prepare space to read filedescriptors
- struct msghdr msg;
- msg.msg_name = NULL;
- msg.msg_namelen = 0;
-
- // prepare a place to read the actual message
- char place_for_data[RECEIVE_BUFFER_SIZE];
- bzero(&place_for_data, sizeof(place_for_data));
- struct iovec iov = { .iov_base = &place_for_data,
- .iov_len = sizeof(place_for_data) };
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
-
- char cmsg[CMSG_SPACE(sizeof(struct ucred))];
- msg.msg_control = cmsg;
- msg.msg_controllen = sizeof(cmsg);
-
- /* Receive real plus ancillary data */
- int len = recvmsg(controlfd, &msg, 0);
- if (len == -1) {
- if (errno == EWOULDBLOCK || errno == EAGAIN) {
- Py_RETURN_NONE;
- } else {
- PyErr_SetFromErrno(NotifySocketError);
- return NULL;
- }
- }
-
- /* read the sender pid */
- struct cmsghdr *cmsgp = CMSG_FIRSTHDR(&msg);
- pid_t pid = -1;
- while (cmsgp != NULL) {
- if (cmsgp->cmsg_type == SCM_CREDENTIALS) {
- if (
- cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
- cmsgp->cmsg_level != SOL_SOCKET
- ) {
- printf("[notify_socket] invalid cmsg data, ignoring\n");
- Py_RETURN_NONE;
- }
-
- struct ucred cred;
- memcpy(&cred, CMSG_DATA(cmsgp), sizeof(cred));
- pid = cred.pid;
- }
- cmsgp = CMSG_NXTHDR(&msg, cmsgp);
- }
- if (pid == -1) {
- printf("[notify_socket] ignoring received data without credentials: %s\n",
- place_for_data);
- Py_RETURN_NONE;
- }
-
- /* return received data as a tuple (pid, data bytes) */
- return Py_BuildValue("iy", pid, place_for_data);
-}
-
-static PyMethodDef NotifyMethods[] = {
- { "init_socket", init_control_socket, METH_VARARGS,
- "Init notify socket. Returns it's file descriptor." },
- { "read_message", handle_control_socket_connection_event, METH_VARARGS,
- "Reads datagram from notify socket. Returns tuple of PID and received bytes." },
- { NULL, NULL, 0, NULL } /* Sentinel */
-};
-
-static struct PyModuleDef notifymodule = {
- PyModuleDef_HEAD_INIT, MODULE_NAME, /* name of module */
- NULL, /* module documentation, may be NULL */
- -1, /* size of per-interpreter state of the module,
- or -1 if the module keeps state in global variables. */
- NotifyMethods
-};
-
-PyMODINIT_FUNC PyInit_notify(void)
-{
- PyObject *m;
-
- m = PyModule_Create(¬ifymodule);
- if (m == NULL)
- return NULL;
-
- NotifySocketError =
- PyErr_NewException(MODULE_NAME ".error", NULL, NULL);
- Py_XINCREF(NotifySocketError);
- if (PyModule_AddObject(m, "error", NotifySocketError) < 0) {
- Py_XDECREF(NotifySocketError);
- Py_CLEAR(NotifySocketError);
- Py_DECREF(m);
- return NULL;
- }
-
- return m;
-}
-
-#endif
+++ /dev/null
-# type: ignore
-# pylint: disable=protected-access
-
-import os
-import sys
-import traceback
-from typing import Any, Literal
-
-from supervisor.dispatchers import POutputDispatcher
-from supervisor.loggers import LevelsByName, StreamHandler, SyslogHandler
-from supervisor.supervisord import Supervisor
-
-FORWARD_LOG_LEVEL = LevelsByName.CRIT # to make sure it's always printed
-
-
-def empty_function(*args, **kwargs):
- pass
-
-
-FORWARD_MSG_FORMAT: str = "%(name)s[%(pid)d]%(stream)s: %(data)s"
-
-
-def p_output_dispatcher_log(self: POutputDispatcher, data: bytearray):
- if data:
- # parse the input
- if not isinstance(data, bytes):
- text = data
- else:
- try:
- text = data.decode("utf-8")
- except UnicodeDecodeError:
- text = "Undecodable: %r" % data
-
- # print line by line prepending correct prefix to match the style
- config = self.process.config
- config.options.logger.handlers = forward_handlers
- for line in text.splitlines():
- stream = ""
- if self.channel == "stderr":
- stream = " (stderr)"
- config.options.logger.log(
- FORWARD_LOG_LEVEL, FORWARD_MSG_FORMAT, name=config.name, stream=stream, data=line, pid=self.process.pid
- )
- config.options.logger.handlers = supervisord_handlers
-
-
-def _create_handler(fmt, level, target: Literal["stdout", "stderr", "syslog"]) -> StreamHandler:
- if target == "syslog":
- handler = SyslogHandler()
- else:
- handler = StreamHandler(sys.stdout if target == "stdout" else sys.stderr)
- handler.setFormat(fmt)
- handler.setLevel(level)
- return handler
-
-
-supervisord_handlers = []
-forward_handlers = []
-
-
-def inject(supervisord: Supervisor, **config: Any) -> Any: # pylint: disable=useless-return
- try:
- # reconfigure log handlers
- supervisord.options.logger.info("reconfiguring log handlers")
- supervisord_handlers.append(
- _create_handler(
- f"%(asctime)s supervisor[{os.getpid()}]: [%(levelname)s] %(message)s\n",
- supervisord.options.loglevel,
- config["target"],
- )
- )
- forward_handlers.append(
- _create_handler("%(asctime)s %(message)s\n", supervisord.options.loglevel, config["target"])
- )
- supervisord.options.logger.handlers = supervisord_handlers
-
- # replace output handler for subprocesses
- POutputDispatcher._log = p_output_dispatcher_log # noqa: SLF001
-
- # we forward stdio in all cases, even when logging to syslog. This should prevent the unfortunate
- # case of swallowing an error message and leaving the users confused. To make the forwarded lines obvious,
- # we just prepend an explanatory string at the beginning of all messages
- if config["target"] == "syslog":
- global FORWARD_MSG_FORMAT
- FORWARD_MSG_FORMAT = "captured stdio output from " + FORWARD_MSG_FORMAT
-
- # this method is called by supervisord when loading the plugin;
- # it should return an XML-RPC object, which we don't care about.
- # That's why we are returning just None
- return None
-
- # if we fail to load the module, print some explanation
- # should not happen when run by endusers
- except BaseException:
- traceback.print_exc()
- raise
+++ /dev/null
-# type: ignore
-# ruff: noqa: SLF001
-# pylint: disable=c-extension-no-member
-
-from knot_resolver.constants import NOTIFY_SUPPORT
-
-if NOTIFY_SUPPORT:
- import os
- import signal
- import time
- from functools import partial
- from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
-
- from supervisor.events import ProcessStateEvent, ProcessStateStartingEvent, subscribe
- from supervisor.medusa.asyncore_25 import compact_traceback
- from supervisor.process import Subprocess
- from supervisor.states import ProcessStates
- from supervisor.supervisord import Supervisor
-
- from knot_resolver.controller.supervisord.plugin import notify
-
- starting_processes: List[Subprocess] = []
-
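- # Processes that opt into sd_notify-style readiness reporting are marked with
- # X-SUPERVISORD-TYPE=notify in their environment (set in config_file.py).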
- def is_type_notify(proc: Subprocess) -> bool:
- return (
- proc.config.environment is not None and proc.config.environment.get("X-SUPERVISORD-TYPE", None) == "notify"
- )
-
- class NotifySocketDispatcher:
- """
- See supervisor.dispatcher
- """
-
- def __init__(self, supervisor: Supervisor, fd: int):
- self._supervisor = supervisor
- self.fd = fd
- self.closed = False # True if close() has been called
-
- def __repr__(self):
- return f"<{self.__class__.__name__} with fd={self.fd}>"
-
- def readable(self):
- return True
-
- def writable(self):
- return False
-
- def handle_read_event(self):
- logger: Any = self._supervisor.options.logger
-
- res: Optional[Tuple[int, bytes]] = notify.read_message(self.fd)
- if res is None:
- return # there was some junk
- pid, data = res
-
- # pylint: disable=undefined-loop-variable
- for proc in starting_processes:
- if proc.pid == pid:
- break
- else:
- logger.warn(f"ignoring ready notification from unregistered PID={pid}")
- return
-
- if data.startswith(b"READY=1"):
- # handle case, when some process is really ready
-
- if is_type_notify(proc):
- proc._assertInState(ProcessStates.STARTING)
- proc.change_state(ProcessStates.RUNNING)
- logger.info(
- f"success: {proc.config.name} entered RUNNING state, process sent notification via $NOTIFY_SOCKET"
- )
- else:
- logger.warn(
- f"ignoring READY notification from {proc.config.name}, which is not configured to send it"
- )
-
- elif data.startswith(b"STOPPING=1"):
- # just accept the message, filter unwanted notifications and do nothing else
-
- if is_type_notify(proc):
- logger.info(
- f"success: {proc.config.name} entered STOPPING state, process sent notification via $NOTIFY_SOCKET"
- )
- else:
- logger.warn(
- f"ignoring STOPPING notification from {proc.config.name}, which is not configured to send it"
- )
-
- else:
- # handle case, when we got something unexpected
- logger.warn(f"ignoring unrecognized data on $NOTIFY_SOCKET sent from PID={pid}, data='{data!r}'")
- return
-
- def handle_write_event(self):
- raise ValueError("this dispatcher is not writable")
-
- def handle_error(self):
- _nil, t, v, tbinfo = compact_traceback()
-
- self._supervisor.options.logger.error(
- f"uncaptured python exception, closing notify socket {repr(self)} ({t}:{v} {tbinfo})"
- )
- self.close()
-
- def close(self):
- if not self.closed:
- os.close(self.fd)
- self.closed = True
-
- def flush(self):
- pass
-
- def keep_track_of_starting_processes(event: ProcessStateEvent) -> None:
- global starting_processes
-
- proc: Subprocess = event.process
-
- if isinstance(event, ProcessStateStartingEvent):
- # process is starting
- # if proc not in starting_processes:
- starting_processes.append(proc)
-
- else:
- # not starting
- starting_processes = [p for p in starting_processes if p.pid != proc.pid]
-
- notify_dispatcher: Optional[NotifySocketDispatcher] = None
-
- def process_transition(slf: Subprocess) -> Subprocess:
- if not is_type_notify(slf):
- return slf
-
- # modified version of upstream process transition code
- if slf.state == ProcessStates.STARTING:
- if time.time() - slf.laststart > slf.config.startsecs:
- # STARTING -> STOPPING if the process has not sent ready notification
- # within proc.config.startsecs
- slf.config.options.logger.warn(
- f"process '{slf.config.name}' did not send ready notification within {slf.config.startsecs} secs, killing"
- )
- slf.kill(signal.SIGKILL)
- slf.x_notifykilled = True # used in finish() function to set to FATAL state
- slf.laststart = time.time() + 1 # prevent immediate state transition to RUNNING from happening
-
- # return self for chaining
- return slf
-
- def subprocess_finish_tail(slf, pid, sts) -> Tuple[Any, Any, Any]:
- if getattr(slf, "x_notifykilled", False):
- # we want FATAL, not STOPPED state after timeout waiting for startup notification
- # why? because it's likely not gonna help to try starting the process up again if
- # it failed so early
- slf.change_state(ProcessStates.FATAL)
-
- # clear the marker value
- del slf.x_notifykilled
-
- # return for chaining
- return slf, pid, sts
-
- def supervisord_get_process_map(supervisord: Any, mp: Dict[Any, Any]) -> Dict[Any, Any]:
- global notify_dispatcher
- if notify_dispatcher is None:
- notify_dispatcher = NotifySocketDispatcher(supervisord, notify.init_socket())
- supervisord.options.logger.info("notify: injected $NOTIFY_SOCKET into event loop")
-
- # add our dispatcher to the result
- assert notify_dispatcher.fd not in mp
- mp[notify_dispatcher.fd] = notify_dispatcher
-
- return mp
-
- def process_spawn_as_child_add_env(slf: Subprocess, *args: Any) -> Tuple[Any, ...]:
- if is_type_notify(slf):
- slf.config.environment["NOTIFY_SOCKET"] = os.getcwd() + "/supervisor-notify-socket"
- return (slf, *args)
-
- T = TypeVar("T")
- U = TypeVar("U")
-
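- # chain(f, g)(*args) calls f and feeds its result into g (a tuple result is unpacked into
- # positional arguments); append(f, g)(*args) calls f, then g with the same arguments, and
- # returns f's result. Both are used below to wrap supervisord internals.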
- def chain(first: Callable[..., U], second: Callable[[U], T]) -> Callable[..., T]:
- def wrapper(*args: Any, **kwargs: Any) -> T:
- res = first(*args, **kwargs)
- if isinstance(res, tuple):
- return second(*res)
- return second(res)
-
- return wrapper
-
- def append(first: Callable[..., T], second: Callable[..., None]) -> Callable[..., T]:
- def wrapper(*args: Any, **kwargs: Any) -> T:
- res = first(*args, **kwargs)
- second(*args, **kwargs)
- return res
-
- return wrapper
-
- def monkeypatch(supervisord: Supervisor) -> None:
- """Inject ourselves into supervisord code"""
-
- # append notify socket handler to event loop
- supervisord.get_process_map = chain(
- supervisord.get_process_map, partial(supervisord_get_process_map, supervisord)
- )
-
- # prepend timeout handler to transition method
- Subprocess.transition = chain(process_transition, Subprocess.transition)
- Subprocess.finish = append(Subprocess.finish, subprocess_finish_tail)
-
- # add environment variable $NOTIFY_SOCKET to starting processes
- Subprocess._spawn_as_child = chain(process_spawn_as_child_add_env, Subprocess._spawn_as_child)
-
- # keep references to starting subprocesses
- subscribe(ProcessStateEvent, keep_track_of_starting_processes)
-
- def inject(supervisord: Supervisor, **_config: Any) -> Any: # pylint: disable=useless-return
- monkeypatch(supervisord)
-
- # this method is called by supervisord when loading the plugin;
- # it should return an XML-RPC object, which we don't care about.
- # That's why we are returning just None
- return None
+++ /dev/null
-[supervisord]
-pidfile = {{ config.pid_file }}
-directory = {{ config.workdir }}
-nodaemon = true
-
-{# disable initial logging until patch_logger.py takes over #}
-logfile = /dev/null
-logfile_maxbytes = 0
-silent = true
-
-{# config for patch_logger.py #}
-loglevel = {{ config.loglevel }}
-{# there are more options in the plugin section #}
-
-[unix_http_server]
-file = {{ config.unix_http_server }}
-
-[supervisorctl]
-serverurl = unix://{{ config.unix_http_server }}
-
-{# Extensions to changing the supervisord behavior #}
-[rpcinterface:patch_logger]
-supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.patch_logger:inject
-target = {{ config.target }}
-
-[rpcinterface:manager_integration]
-supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.manager_integration:inject
-
-{# sd_notify is supported only on Linux based systems #}
-{% if config.notify_support -%}
-[rpcinterface:sd_notify]
-supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.sd_notify:inject
-{%- endif %}
-
-{# Extensions for actual API control #}
-[rpcinterface:supervisor]
-supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
-
-[rpcinterface:fast]
-supervisor.rpcinterface_factory = knot_resolver.controller.supervisord.plugin.fast_rpcinterface:make_main_rpcinterface
-
-[program:manager]
-redirect_stderr=false
-directory={{ manager.workdir }}
-command={{ manager.command }}
-stopsignal=SIGINT
-killasgroup=true
-autorestart=true
-autostart=true
-{# Note that during startup,
- the manager will signal being ready only after the sequential startup of all kresd workers,
- i.e. it might currently take a long time if the user configured very large rulesets (e.g. a huge RPZ).
- Let's permit it plenty of time, assuming that useful work is being done.
-#}
-startsecs={{ manager.startsecs }}
-environment={{ manager.environment }}
-stdout_logfile=NONE
-stderr_logfile=NONE
-
-[program:kresd]
-process_name=%(program_name)s%(process_num)d
-numprocs={{ kresd.max_procs }}
-directory={{ kresd.workdir }}
-command={{ kresd.command }}
-autostart=false
-autorestart=true
-stopsignal=TERM
-killasgroup=true
-startsecs={{ kresd.startsecs }}
-environment={{ kresd.environment }}
-stdout_logfile=NONE
-stderr_logfile=NONE
-
-[program:policy-loader]
-directory={{ loader.workdir }}
-command={{ loader.command }}
-autostart=false
-stopsignal=TERM
-killasgroup=true
-exitcodes=0
-startsecs={{ loader.startsecs }}
-environment={{ loader.environment }}
-stdout_logfile=NONE
-stderr_logfile=NONE
-
-[program:cache-gc]
-redirect_stderr=false
-directory={{ gc.workdir }}
-command={{ gc.command }}
-autostart=false
-autorestart=true
-stopsignal=TERM
-killasgroup=true
-startsecs={{ gc.startsecs }}
-environment={{ gc.environment }}
-stdout_logfile=NONE
-stderr_logfile=NONE
+++ /dev/null
-from .config_schema import KresConfig, kres_config_json_schema
-
-__all__ = ["KresConfig", "kres_config_json_schema"]
+++ /dev/null
-# ruff: noqa: E501
-from typing import List, Optional
-
-from knot_resolver.constants import CACHE_DIR
-from knot_resolver.datamodel.templates import template_from_str
-from knot_resolver.datamodel.types import (
- DNSRecordTypeEnum,
- DomainName,
- EscapedStr,
- IntNonNegative,
- IntPositive,
- Percent,
- ReadableFile,
- SizeUnit,
- TimeUnit,
- WritableDir,
-)
-from knot_resolver.utils.modeling import ConfigSchema
-from knot_resolver.utils.modeling.base_schema import lazy_default
-
-_CACHE_CLEAR_TEMPLATE = template_from_str(
- "{% from 'macros/cache_macros.lua.j2' import cache_clear %} {{ cache_clear(params) }}"
-)
-
-
-class CacheClearRPCSchema(ConfigSchema):
- name: Optional[DomainName] = None
- exact_name: bool = False
- rr_type: Optional[DNSRecordTypeEnum] = None
- chunk_size: IntPositive = IntPositive(100)
-
- def _validate(self) -> None:
- if self.rr_type and not self.exact_name:
- raise ValueError("'rr-type' is only supported with 'exact-name: true'")
-
- def render_lua(self) -> str:
- return _CACHE_CLEAR_TEMPLATE.render(params=self) # pyright: reportUnknownMemberType=false
-
-
-class PrefillSchema(ConfigSchema):
- """
- Prefill the cache periodically by importing zone data obtained over HTTP.
-
- ---
- origin: Origin for the imported data. Cache prefilling is only supported for the root zone ('.').
- url: URL of the zone data to be imported.
- refresh_interval: Time interval between consecutive refreshes of the imported zone data.
- ca_file: Path to the file containing a CA certificate bundle that is used to authenticate the HTTPS connection.
- """
-
- origin: DomainName
- url: EscapedStr
- refresh_interval: TimeUnit = TimeUnit("1d")
- ca_file: Optional[ReadableFile] = None
-
- def _validate(self) -> None:
- if str(self.origin) != ".":
- raise ValueError("cache prefilling is not yet supported for non-root zones")
-
-
-class GarbageCollectorSchema(ConfigSchema):
- """
- Configuration options of the cache garbage collector (kres-cache-gc).
-
- ---
- enable: Enable/disable cache garbage collector.
- interval: Time interval how often the garbage collector will be run.
- threshold: Cache usage in percent that triggers the garbage collector.
- release: Percent of used cache to be freed by the garbage collector.
- temp_keys_space: Maximum amount of temporary memory for copied keys (0 = unlimited).
- rw_deletes: Maximum number of deleted records per read-write transaction (0 = unlimited).
- rw_reads: Maximum number of read records per read-write transaction (0 = unlimited).
- rw_duration: Maximum duration of read-write transaction (0 = unlimited).
- rw_delay: Wait time between two read-write transactions.
- dry_run: Run the garbage collector in dry-run mode.
- """
-
- enable: bool = True
- interval: TimeUnit = TimeUnit("1s")
- threshold: Percent = Percent(80)
- release: Percent = Percent(10)
- temp_keys_space: SizeUnit = SizeUnit("0M")
- rw_deletes: IntNonNegative = IntNonNegative(100)
- rw_reads: IntNonNegative = IntNonNegative(200)
- rw_duration: TimeUnit = TimeUnit("0us")
- rw_delay: TimeUnit = TimeUnit("0us")
- dry_run: bool = False
-
-
-class PredictionSchema(ConfigSchema):
- """
- Helps keep the cache hot by prefetching expiring records and learning usage patterns and repetitive queries.
-
- ---
- enable: Enable/disable prediction.
- window: Sampling window length.
- period: Number of windows that can be kept in memory.
- """
-
- enable: bool = False
- window: TimeUnit = TimeUnit("15m")
- period: IntPositive = IntPositive(24)
-
-
-class PrefetchSchema(ConfigSchema):
- """
- These options help keep the cache hot by prefetching expiring records or learning usage patterns and repetitive queries.
-
- ---
- expiring: Prefetch expiring records.
- prediction: Prefetch record by predicting based on usage patterns and repetitive queries.
- """
-
- expiring: bool = False
- prediction: PredictionSchema = PredictionSchema()
-
-
-class CacheSchema(ConfigSchema):
- """
- DNS resolver cache configuration.
-
- ---
- storage: Cache storage of the DNS resolver.
- size_max: Maximum size of the cache.
- garbage_collector: Use the garbage collector (kres-cache-gc) to periodically clear cache.
- ttl_min: Minimum time-to-live for the cache entries.
- ttl_max: Maximum time-to-live for the cache entries.
- ns_timeout: Time interval for which a nameserver address will be ignored after determining that it does not return (useful) answers.
- prefill: Prefill the cache periodically by importing zone data obtained over HTTP.
- prefetch: These options help keep the cache hot by prefetching expiring records or learning usage patterns and repetitive queries.
- """
-
- storage: WritableDir = lazy_default(WritableDir, str(CACHE_DIR))
- size_max: SizeUnit = SizeUnit("100M")
- garbage_collector: GarbageCollectorSchema = GarbageCollectorSchema()
- ttl_min: TimeUnit = TimeUnit("5s")
- ttl_max: TimeUnit = TimeUnit("1d")
- ns_timeout: TimeUnit = TimeUnit("1000ms")
- prefill: Optional[List[PrefillSchema]] = None
- prefetch: PrefetchSchema = PrefetchSchema()
-
- def _validate(self) -> None:
- if self.ttl_min.seconds() > self.ttl_max.seconds():
- raise ValueError("'ttl-max' can't be smaller than 'ttl-min'")
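-
-
-# An illustrative cache section of the declarative config as CacheSchema would accept it
-# (keys shown in their external kebab-case form, values taken from the defaults above;
-# this is a sketch, not an exhaustive reference):
-#
-#   cache:
-#     size-max: 100M
-#     ttl-min: 5s
-#     ttl-max: 1d
-#     garbage-collector:
-#       interval: 1s
-#       threshold: 80
-#       release: 10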
+++ /dev/null
-# ruff: noqa: E501
-import logging
-import os
-import socket
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
-
-from knot_resolver.constants import API_SOCK_FILE, RUN_DIR, VERSION, WORKERS_SUPPORT
-from knot_resolver.datamodel.cache_schema import CacheSchema
-from knot_resolver.datamodel.defer_schema import DeferSchema
-from knot_resolver.datamodel.dns64_schema import Dns64Schema
-from knot_resolver.datamodel.dnssec_schema import DnssecSchema
-from knot_resolver.datamodel.forward_schema import FallbackSchema, ForwardSchema
-from knot_resolver.datamodel.globals import Context, get_global_validation_context, set_global_validation_context
-from knot_resolver.datamodel.local_data_schema import LocalDataSchema, RPZSchema, RuleSchema
-from knot_resolver.datamodel.logging_schema import LoggingSchema
-from knot_resolver.datamodel.lua_schema import LuaSchema
-from knot_resolver.datamodel.management_schema import ManagementSchema
-from knot_resolver.datamodel.monitoring_schema import MonitoringSchema
-from knot_resolver.datamodel.network_schema import NetworkSchema
-from knot_resolver.datamodel.options_schema import OptionsSchema
-from knot_resolver.datamodel.rate_limiting_schema import RateLimitingSchema
-from knot_resolver.datamodel.templates import KRESD_CONFIG_TEMPLATE, POLICY_LOADER_CONFIG_TEMPLATE
-from knot_resolver.datamodel.types import EscapedStr, IntPositive, WritableDir
-from knot_resolver.datamodel.view_schema import ViewSchema
-from knot_resolver.utils.modeling import ConfigSchema
-from knot_resolver.utils.modeling.base_schema import lazy_default
-from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataValidationError
-
-WORKERS_MAX = 256
-
-logger = logging.getLogger(__name__)
-
-
-def _cpu_count() -> Optional[int]:
- try:
- return len(os.sched_getaffinity(0)) # type: ignore[attr-defined]
- except (NotImplementedError, AttributeError) as e:
- cpus = os.cpu_count()
- if cpus is None:
- logger.warning(
- "The number of usable CPUs could not be determined using"
- f" 'os.sched_getaffinity()' or 'os.cpu_count()':\n{e}"
- )
- return cpus
-
-
-def workers_max_count() -> int:
- c = _cpu_count()
- if c:
- return c * 10
- return WORKERS_MAX
-
-
-def _get_views_tags(views: List[ViewSchema]) -> List[str]:
- tags = []
- for view in views:
- if view.tags:
- tags += [str(tag) for tag in view.tags if tag not in tags]
- return tags
-
-
-def _check_local_data_tags(
- views_tags: List[str], rules_or_rpz: Union[List[RuleSchema], List[RPZSchema]]
-) -> Tuple[List[str], List[DataValidationError]]:
- tags = []
- errs = []
-
- i = 0
- for rule in rules_or_rpz:
- tags_not_in = []
- if rule.tags:
- for tag in rule.tags:
- tag_str = str(tag)
- if tag_str not in tags:
- tags.append(tag_str)
- if tag_str not in views_tags:
- tags_not_in.append(tag_str)
- if len(tags_not_in) > 0:
- errs.append(
- DataValidationError(
- f"some tags {tags_not_in} not found in '/views' tags", f"/local-data/rules[{i}]/tags"
- )
- )
- i += 1
- return tags, errs
-
-
-class KresConfig(ConfigSchema):
- class Raw(ConfigSchema):
- """
- Knot Resolver declarative configuration.
-
- ---
- version: Version of the configuration schema. By default it is the latest supported by the resolver, but a couple of versions back are supported as well.
- nsid: Name Server Identifier (RFC 5001) which allows DNS clients to request resolver to send back its NSID along with the reply to a DNS request.
- hostname: Internal DNS resolver hostname. Default is machine hostname.
- rundir: Directory where the resolver can create files and which will be its cwd.
- workers: The number of running kresd (Knot Resolver daemon) workers. If set to 'auto', it is equal to the number of available CPUs.
- management: Configuration of management HTTP API.
- options: Fine-tuning global parameters of DNS resolver operation.
- network: Network connections and protocols configuration.
- views: List of views and their configuration.
- local_data: Local data for forward records (A/AAAA) and reverse records (PTR).
- forward: List of forward zones and their configuration.
- fallback: Configuration for fallback after resolution failure.
- cache: DNS resolver cache configuration.
- dnssec: DNSSEC configuration.
- dns64: DNS64 (RFC 6147) configuration.
- logging: Logging and debugging configuration.
- monitoring: Metrics exposition configuration (Prometheus, Graphite)
- lua: Custom Lua configuration.
- rate_limiting: Configuration of rate limiting.
- defer: Configuration of request prioritization (defer).
- """
-
- version: int = 1
- nsid: Optional[EscapedStr] = None
- hostname: Optional[EscapedStr] = None
- rundir: WritableDir = lazy_default(WritableDir, str(RUN_DIR))
- workers: Union[Literal["auto"], IntPositive] = IntPositive(1)
- management: ManagementSchema = lazy_default(ManagementSchema, {"unix-socket": str(API_SOCK_FILE)})
- options: OptionsSchema = OptionsSchema()
- network: NetworkSchema = NetworkSchema()
- views: Optional[List[ViewSchema]] = None
- local_data: LocalDataSchema = LocalDataSchema()
- forward: Optional[List[ForwardSchema]] = None
- fallback: FallbackSchema = FallbackSchema()
- cache: CacheSchema = lazy_default(CacheSchema, {})
- dnssec: DnssecSchema = DnssecSchema()
- dns64: Dns64Schema = Dns64Schema()
- logging: LoggingSchema = LoggingSchema()
- monitoring: MonitoringSchema = MonitoringSchema()
- rate_limiting: RateLimitingSchema = RateLimitingSchema()
- defer: DeferSchema = DeferSchema()
- lua: LuaSchema = LuaSchema()
-
- _LAYER = Raw
-
- #### When ADDING options, please also update config_nodes() in ../manager/manager.py
- nsid: Optional[EscapedStr]
- hostname: EscapedStr
- rundir: WritableDir
- workers: IntPositive
- management: ManagementSchema
- options: OptionsSchema
- network: NetworkSchema
- views: Optional[List[ViewSchema]]
- local_data: LocalDataSchema
- forward: Optional[List[ForwardSchema]]
- fallback: FallbackSchema
- cache: CacheSchema
- dnssec: DnssecSchema
- dns64: Dns64Schema
- logging: LoggingSchema
- monitoring: MonitoringSchema
- rate_limiting: RateLimitingSchema
- defer: DeferSchema
- lua: LuaSchema
-
- def _hostname(self, obj: Raw) -> Any:
- if obj.hostname is None:
- return socket.gethostname()
- return obj.hostname
-
- def _workers(self, obj: Raw) -> Any:
- no_workers_support_msg = (
- "On this system, you cannot run more than one worker because "
- "SO_REUSEPORT (Linux) or SO_REUSEPORT_LB (FreeBSD) socket option is not supported."
- )
- if not WORKERS_SUPPORT and obj.workers != "auto" and int(obj.workers) > 1:
- raise ValueError(no_workers_support_msg)
-
- if obj.workers == "auto":
- if not WORKERS_SUPPORT:
- logger.info(
- "Running on a system without support for multiple workers,"
- f" 'workers' configuration automatically set to 1. {no_workers_support_msg}"
- )
- return IntPositive(1)
-
- count = _cpu_count()
- if count:
- return IntPositive(count)
- raise ValueError(
- "The number of available CPUs could not be determined to automatically set the number of running 'kresd' workers."
- " The number of workers can be configured manually using the 'workers' option."
- )
-
- return obj.workers
-
- def _validate(self) -> None: # noqa: C901
- # warn about '/management/unix-socket' not located in '/rundir'
- if self.management.unix_socket and self.management.unix_socket.to_path().parent != self.rundir.to_path():
- logger.warning(
- f"The management API unix-socket '{self.management.unix_socket}'"
- f" is not located in the resolver's rundir '{self.rundir}'."
- " This can lead to permissions issues."
- )
-
- # enforce max-workers config
- workers_max = workers_max_count()
- if int(self.workers) > workers_max:
- raise ValueError(
- f"can't run with more workers than the recommended maximum {workers_max} or hardcoded {WORKERS_MAX}"
- )
-
- # sanity check
- cpu_count = _cpu_count()
- if cpu_count and int(self.workers) > 10 * cpu_count:
- raise ValueError(
- "refusing to run with more than 10 workers per CPU core, the system wouldn't behave nicely"
- )
-
- # get all tags from views
- views_tags = []
- if self.views:
- views_tags = _get_views_tags(self.views)
-
- # get local-data tags and check its existence in views
- errs = []
- local_data_tags = []
- if self.local_data.rules:
- rules_tags, rules_errs = _check_local_data_tags(views_tags, self.local_data.rules)
- errs += rules_errs
- local_data_tags += rules_tags
- if self.local_data.rpz:
- rpz_tags, rpz_errs = _check_local_data_tags(views_tags, self.local_data.rpz)
- errs += rpz_errs
- local_data_tags += rpz_tags
-
- # look for unused tags in /views
- unused_tags = views_tags.copy()
- for tag in local_data_tags:
- if tag in unused_tags:
- unused_tags.remove(tag)
- if len(unused_tags) > 0:
- errs.append(DataValidationError(f"unused tags {unused_tags} found", "/views"))
-
- # raise all validation errors
- if len(errs) == 1:
- raise errs[0]
- if len(errs) > 1:
- raise AggregateDataValidationError("/", errs)
-
- def render_kresd_lua(self) -> str:
- # FIXME the `cwd` argument is used only for configuring control socket path
- # it should be removed and relative path used instead as soon as issue
- # https://gitlab.nic.cz/knot/knot-resolver/-/issues/720 is fixed
- return KRESD_CONFIG_TEMPLATE.render(cfg=self, cwd=os.getcwd())
-
- def render_policy_loader_lua(self) -> str:
- return POLICY_LOADER_CONFIG_TEMPLATE.render(cfg=self, cwd=os.getcwd())
-
-
-def get_rundir_without_validation(data: Dict[str, Any]) -> WritableDir:
- """
- Without fully parsing, try to get a rundir from raw config data, otherwise use the default.
-
- Attempts a dir validation to produce a good error message.
- Used for initial manager startup.
- """
- return WritableDir(data["rundir"] if "rundir" in data else str(RUN_DIR), object_path="/rundir")
-
-
-def kres_config_json_schema() -> Dict[str, Any]:
- """
- At the moment, creating any instance of 'ConfigSchema', even with default values, requires the global validation context to be set.
-
- In the case of generating a JSON schema, strict validation must be turned off;
- otherwise, creating the schema may fail because a referenced directory/file does not exist or has the wrong permissions.
- This should be fixed in the future. For more info, see 'datamodel.globals.py' module.
- """
- context = get_global_validation_context()
- set_global_validation_context(Context(None, False))
-
- schema = KresConfig.json_schema(
- schema_id=f"https://www.knot-resolver.cz/documentation/v{VERSION}/_static/config.schema.json",
- title="Knot Resolver configuration JSON schema",
- description=f"Version Knot Resolver {VERSION}",
- )
- # setting back to previous values
- set_global_validation_context(context)
-
- return schema
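
As a small usage sketch of the helper above (assuming the module path knot_resolver.datamodel.config_schema), the generated schema can be dumped straight to a JSON file; the output filename is arbitrary:

import json

from knot_resolver.datamodel.config_schema import kres_config_json_schema

with open("config.schema.json", "w", encoding="UTF-8") as f:
    json.dump(kres_config_json_schema(), f, indent=2)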
+++ /dev/null
-from knot_resolver.datamodel.types import TimeUnit
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class DeferSchema(ConfigSchema):
- """
- Configuration of request prioritization (defer).
-
- ---
- enable: Use request prioritization.
- log_period: Minimal time between two log messages, or '0s' to disable.
- """
-
- enable: bool = True
- log_period: TimeUnit = TimeUnit("0s")
+++ /dev/null
-###### Working notes about configuration schema
-
-
-## TODO nit: nest one level deeper inside `dnssec`, probably
-dnssec:
- keep-removed: 0
- refresh-time: 10s
- hold-down-time: 30d
-
-## TODO nit: I don't like this name, at least not for the experimental thing we have there
-network:
- tls:
- auto_discovery: boolean
-
-#### General questions
-Plurals: do we name attributes in plural if they're a list;
- some of them even allow a non-list if using a single element.
-
-
-#### New-policy brainstorming
-
-dnssec:
- # Convert to key: style instead of list?
- # - easier to handle in API/CLI (which might be a common action on names with broken DNSSEC)
- # - allows to supply a value - stamp for expiration of that NTA
- # (absolute time, but I can imagine API/CLI converting from duration when executed)
- # - syntax isn't really more difficult, mainly it forces one entry per line (seems OK)
- negative-trust-anchors:
- example.org:
- my.example.net:
-
-
-view:
- # When a client request arrives, based on the `view` class of rules we may either
- # decide for a direct answer or for marking the request with a set of tags.
- # The concepts of matching and actions are a very good fit for this,
- # and that matches our old policy approach. Matching here should avoid QNAME+QTYPE;
- # instead it's e.g. suitable for access control.
- # RPZ files also support rules that fall into this `view` class.
- #
- # Selecting a single rule: the most specific client-IP prefix
- # that also matches additional conditions.
- - subnet: [ 0.0.0.0/0, ::/0 ]
- answer: refused
- # some might prefer `allow: refused` ?
- # Also, RCODEs are customary in CAPITALS though maybe not in configs.
-
- - subnet: [ 10.0.0.0/8, 192.168.0.0/16 ]
- # Adding `tags` implies allowing the query.
- tags: [ t1, t2, t3 ] # theoretically we could use space-separated string
- options: # only some of the global options can be overridden in view
- minimize: true
- dns64: true
- rate-limit: # future option, probably (optionally?) structured
- # LATER: rulesets are a relatively unclear feature for now.
- # Their main point is to allow prioritization and avoid
- # intermixing rules that come from different sources.
- # Also some properties might be specifiable per ruleset.
- ruleset: tt
-
- - subnet: [ 10.0.10.0/24 ] # maybe allow a single value instead of a list?
- # LATER: special addresses?
- # - for kresd-internal requests
- # - shorthands for all private IPv4 and/or IPv6;
- # though yaml's repeated nodes could mostly cover that
- # or just copy&paste from docs
- answer: allow
-
-# Or perhaps a more complex approach? Probably not.
-# We might have multiple conditions at once and multiple actions at once,
-# but I don't expect these to be common, so the complication is probably not worth it.
-# An advantage would be that the separation of the two parts would be more visible.
-view:
- - match:
- subnet: [ 10.0.0.0/8, 192.168.0.0/16 ]
- do:
- tags: [ t1, t2, t3 ]
- options: # ...
-
-
-local-data: # TODO: name
- #FIXME: tags - allow assigning them to (groups of) addresses/records.
-
- addresses: # automatically adds PTR records and NODATA (LATER: overridable NODATA?)
- foo.bar: [ 127.0.0.1, ::1 ]
- my.pc.corp: 192.168.12.95
- addresses-files: # files in /etc/hosts format (and semantics like `addresses`)
- - /etc/hosts
-
- # Zonefile format seems quite handy here. Details:
- # - probably use `local-data.ttl` from model as the default
- # - and . root to avoid confusion if someone misses a final dot.
- records: |
- example.net. TXT "foo bar"
- A 192.168.2.3
- A 192.168.2.4
- local.example.org AAAA ::1
-
- subtrees:
- nodata: true # impl ATM: defaults to false, set (only) for each rule/name separately
- # impl: options like `ttl` and `nodata` might make sense to be settable (only?) per ruleset
-
- subtrees: # TODO: perhaps just allow in the -tagged style, if we can't avoid lists anyway?
- - type: empty
- roots: [ sub2.example.org ] # TODO: name it the same as for forwarding
- tags: [ t2 ]
- - type: nxdomain
- # Will we need to support multiple file formats in future and choose here?
- roots-file: /path/to/file.txt
- - type: empty
- roots-url: https://example.org/blocklist.txt
- refresh: 1d
- # Is it a separate ruleset? Optionally? Persistence?
- # (probably the same questions for local files as well)
-
- - type: redirect
- roots: [ sub4.example.org ]
- addresses: [ 127.0.0.1, ::1 ]
-
-local-data-tagged: # TODO: name (view?); and even structure seems unclear.
- # TODO: allow only one "type" per list entry? (addresses / addresses-files / subtrees / ...)
- - tags: [ t1, t2 ]
- addresses: #... otherwise the same as local-data
- - tags: [ t2 ]
- records: # ...
- - tags: [ t3 ]
- subtrees: empty
- roots: [ sub2.example.org ]
-
-local-data-tagged: # this avoids lists, so it's relatively easy to amend through API
- "t1 t2": # perhaps it's not nice that tags don't form a proper list?
- addresses:
- foo.bar: [ 127.0.0.1, ::1 ]
- t4:
- addresses:
- foo.bar: [ 127.0.0.1, ::1 ]
-local-data: # avoids lists and merges into the untagged `local-data` config subtree
- tagged: # (getting quite deep, though)
- t1 t2:
- addresses:
- foo.bar: [ 127.0.0.1, ::1 ]
-# or even this ugly thing:
-local-data-tagged t1 t2:
- addresses:
- foo.bar: [ 127.0.0.1, ::1 ]
-
-forward: # TODO: "name" is from Unbound, but @vcunat would prefer "subtree" or something.
- - name: '.' # Root is the default so could be omitted?
- servers: [2001:148f:fffe::1, 2001:148f:ffff::1, 185.43.135.1, 193.14.47.1]
- # TLS forward, server authenticated using hostname and system-wide CA certificates
- # https://www.knot-resolver.cz/documentation/latest/modules-policy.html?highlight=forward#tls-examples
- - name: '.'
- servers:
- - address: [ 192.0.2.1, 192.0.2.2@5353 ]
- transport: tls
- pin-sha256: Wg==
- - address: 2001:DB8::d0c
- transport: tls
- hostname: res.example.com
- ca-file: /etc/knot-resolver/tlsca.crt
- options:
- # LATER: allow a subset of options here, per sub-tree?
- # Though that's not necessarily related to forwarding (e.g. TTL limits),
- # especially implementation-wise it probably won't matter.
-
-
-# Too confusing approach, I suppose? Different from usual way of thinking but closer to internal model.
-# Down-sides:
-# - multiple rules for the same name won't be possible (future, with different tags)
-# - loading names from a file won't be possible (or URL, etc.)
-rules:
- example.org: &fwd_odvr
- type: forward
- servers: [2001:148f:fffe::1, 2001:148f:ffff::1, 185.43.135.1, 193.14.47.1]
- sub2.example.org:
- type: empty
- tags: [ t3, t5 ]
- sub3.example.org:
- type: forward-auth
- dnssec: no
-
-
-# @amrazek: current valid config
-
-views:
- - subnets: [ 0.0.0.0/0, "::/0" ]
- answer: refused
- - subnets: [ 0.0.0.0/0, "::/0" ]
- tags: [t01, t02, t03]
- options:
- minimize: true # default
- dns64: true # default
- - subnets: 10.0.10.0/24 # can be single value
- answer: allow
-
-local-data:
- ttl: 1d
- nodata: true
- addresses:
- foo.bar: [ 127.0.0.1, "::1" ]
- my.pc.corp: 192.168.12.95
- addresses-files:
- - /etc/hosts
- records: |
- example.net. TXT "foo bar"
- A 192.168.2.3
- A 192.168.2.4
- local.example.org AAAA ::1
- subtrees:
- - type: empty
- roots: [ sub2.example.org ]
- tags: [ t2 ]
- - type: nxdomain
- roots-file: /path/to/file.txt
- - type: empty
- roots-url: https://example.org/blocklist.txt
- refresh: 1d
- - type: redirect
- roots: [ sub4.example.org ]
- addresses: [ 127.0.0.1, "::1" ]
-
-forward:
- - subtree: '.'
- servers:
- - address: [ 192.0.2.1, 192.0.2.2@5353 ]
- transport: tls
- pin-sha256: Wg==
- - address: 2001:DB8::d0c
- transport: tls
- hostname: res.example.com
- ca-file: /etc/knot-resolver/tlsca.crt
- options:
- dnssec: true # default
- - subtree: 1.168.192.in-addr.arpa
- servers: [ 192.0.2.1@5353 ]
- options:
- dnssec: false # policy.STUB?
+++ /dev/null
-from typing import List, Optional
-
-from knot_resolver.datamodel.types import IPv6Network, IPv6Network96, TimeUnit
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class Dns64Schema(ConfigSchema):
- """
- DNS64 (RFC 6147) configuration.
-
- ---
- enable: Enable/disable DNS64.
- prefix: IPv6 prefix to be used for synthesizing AAAA records.
- reverse_ttl: TTL in CNAME generated in the reverse 'ip6.arpa.' subtree.
- exclude_subnets: IPv6 subnets that are disallowed in answers.
- """
-
- enable: bool = False
- prefix: IPv6Network96 = IPv6Network96("64:ff9b::/96")
- reverse_ttl: Optional[TimeUnit] = None
- exclude_subnets: Optional[List[IPv6Network]] = None
+++ /dev/null
-# ruff: noqa: E501
-from typing import List, Optional
-
-from knot_resolver.datamodel.types import DomainName, EscapedStr, ReadableFile
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class TrustAnchorFileSchema(ConfigSchema):
- """
- Trust-anchor zonefile configuration.
-
- ---
- file: Path to the zonefile that stores trust-anchors.
- read_only: Blocks zonefile updates according to RFC 5011.
-
- """
-
- file: ReadableFile
- read_only: bool = False
-
-
-class DnssecSchema(ConfigSchema):
- """
- DNSSEC configuration.
-
- ---
- enable: Enable/disable DNSSEC.
- log_bogus: Enable logging for each DNSSEC validation failure if '/logging/level' is set to at least 'notice'.
- sentinel: Allows users of a DNSSEC-validating resolver to detect which root keys are configured in the resolver's chain of trust (RFC 8509).
- signal_query: Signaling Trust Anchor Knowledge in DNSSEC Using Key Tag Query (RFC 8145, section 5).
- trust_anchors: List of trust-anchors in DS/DNSKEY record format.
- trust_anchors_files: List of zone-files where trust-anchors are stored.
- negative_trust_anchors: List of domain names representing negative trust-anchors. (RFC 7646)
- """
-
- enable: bool = True
- log_bogus: bool = False
- sentinel: bool = True
- signal_query: bool = True
- trust_anchors: Optional[List[EscapedStr]] = None
- trust_anchors_files: Optional[List[TrustAnchorFileSchema]] = None
- negative_trust_anchors: Optional[List[DomainName]] = None
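
A hedged sketch of configuring a negative trust anchor (RFC 7646) through this schema; dict-based construction with kebab-case keys is assumed to match the rest of the datamodel, and 'example.org' is a placeholder name:

from knot_resolver.datamodel.dnssec_schema import DnssecSchema

# Disable DNSSEC validation under a (hypothetical) broken subtree.
dnssec = DnssecSchema({"negative-trust-anchors": ["example.org"]})
print(dnssec.negative_trust_anchors)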
+++ /dev/null
-from typing import Any, List, Literal, Optional, Union
-
-from knot_resolver.datamodel.types import DomainName, IPAddressOptionalPort, ListOrItem, PinSha256, ReadableFile
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class ForwardServerSchema(ConfigSchema):
- """
- Forward server configuration.
-
- ---
- address: IP address(es) of a forward server.
- transport: Transport protocol for a forward server.
- pin_sha256: Hash of accepted CA certificate.
- hostname: Hostname of the Forward server.
- ca_file: Path to CA certificate file.
- """
-
- address: ListOrItem[IPAddressOptionalPort]
- transport: Optional[Literal["tls"]] = None
- pin_sha256: Optional[ListOrItem[PinSha256]] = None
- hostname: Optional[DomainName] = None
- ca_file: Optional[ReadableFile] = None
-
- def _validate(self) -> None:
- if self.pin_sha256 and (self.hostname or self.ca_file):
- raise ValueError("'pin-sha256' cannot be configured together with 'hostname' or 'ca-file'")
-
-
-class ForwardOptionsSchema(ConfigSchema):
- """
- Subtree(s) forward options.
-
- ---
- authoritative: The forwarding target is an authoritative server.
- dnssec: Enable/disable DNSSEC.
- """
-
- authoritative: bool = False
- dnssec: bool = True
-
-
-class ForwardSchema(ConfigSchema):
- """
- Configuration of forward subtree.
-
- ---
- subtree: Subtree(s) to forward.
- servers: Forward servers configuration.
- options: Subtree(s) forward options.
- """
-
- subtree: ListOrItem[DomainName]
- servers: List[Union[IPAddressOptionalPort, ForwardServerSchema]]
- options: ForwardOptionsSchema = ForwardOptionsSchema()
-
- def _validate(self) -> None:
- def is_port_custom(servers: List[Any]) -> bool:
- for server in servers:
- if isinstance(server, IPAddressOptionalPort) and server.port and int(server.port) != 53:
- return True
- if isinstance(server, ForwardServerSchema) and is_port_custom(server.address.to_std()):
- return True
- return False
-
- def is_transport_tls(servers: List[Any]) -> bool:
- for server in servers:
- if isinstance(server, ForwardServerSchema) and server.transport == "tls":
- return True
- return False
-
- if self.options.authoritative and is_port_custom(self.servers):
- raise ValueError("Forwarding to authoritative servers on a custom port is currently not supported.")
-
- if self.options.authoritative and is_transport_tls(self.servers):
- raise ValueError("Forwarding to authoritative servers using TLS protocol is not supported.")
-
-
-class FallbackSchema(ConfigSchema):
- """
- Configuration for fallback after resolution failure.
-
- ---
- enable: Enable/disable the fallback.
- servers: Forward servers configuration for fallback.
- """
-
- enable: bool = False
- servers: Optional[List[Union[IPAddressOptionalPort, ForwardServerSchema]]] = None
-
- def _validate(self) -> None:
- if self.enable and self.servers is None:
- raise ValueError("Fallback enabled without configuring servers.")
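
A sketch of the constraint enforced by ForwardSchema._validate(): forwarding to an authoritative target over TLS is rejected. Construction from a plain dict and wrapping of the error by the modeling layer are assumptions here; the addresses are documentation prefixes:

from knot_resolver.datamodel.forward_schema import ForwardSchema

try:
    ForwardSchema({
        "subtree": "example.net",
        "servers": [{"address": "192.0.2.1", "transport": "tls"}],
        "options": {"authoritative": True},
    })
except Exception as exc:  # ValueError, possibly wrapped
    print(exc)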
+++ /dev/null
-"""
-The parsing and validation of the datamodel is dependent on a global state:
-- a file system path used for resolving relative paths.
-
-Commentary from @vsraier:
-=========================
-
-While this is not ideal, it is the best we can do at the moment. When I created this module,
-the datamodel was dependent on the global state implicitly. The validation procedures just read
-the current working directory. This module is the first step in removing the global dependency.
-
-At some point in the future, it might be interesting to add something like a "validation context"
-to the modelling tools. It is not technically complicated, but it requires
-massive model changes I am not willing to make at the moment. Ideally, when implementing this,
-the BaseSchema would turn into an empty class without any logic. Not even a constructor. All logic
-would be in the ObjectMapper class. Similar to how Gson works in Java or AutoMapper in C#.
-""" # noqa: D205
-
-from pathlib import Path
-from typing import Optional
-
-
-class Context:
- resolve_root: Optional[Path]
- strict_validation: bool
- permissions_default: bool
-
- def __init__(
- self, resolve_root: Optional[Path], strict_validation: bool = True, permissions_default: bool = True
- ) -> None:
- self.resolve_root = resolve_root
- self.strict_validation = strict_validation
- self.permissions_default = permissions_default
-
-
-_global_context: Context = Context(None)
-
-
-def set_global_validation_context(context: Context) -> None:
- global _global_context
- _global_context = context
-
-
-def get_global_validation_context() -> Context:
- return _global_context
-
-
-def reset_global_validation_context() -> None:
- global _global_context
- _global_context = Context(None)
-
-
-def get_resolve_root() -> Path:
- if _global_context.resolve_root is None:
- raise RuntimeError(
- "Global validation context 'resolve_root' is not set!"
- " Before validation, you have to set it using `set_global_validation_context()` function!"
- )
-
- return _global_context.resolve_root
-
-
-def get_strict_validation() -> bool:
- return _global_context.strict_validation
-
-
-def get_permissions_default() -> bool:
- return _global_context.permissions_default
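
A minimal sketch of working with the global context defined above: set it before parsing/validating and reset it afterwards so later code does not inherit stale state.

from pathlib import Path

from knot_resolver.datamodel.globals import (
    Context,
    reset_global_validation_context,
    set_global_validation_context,
)

# Resolve relative paths against the current directory, without strict checks.
set_global_validation_context(Context(Path("."), strict_validation=False))
try:
    ...  # parse and validate configuration here
finally:
    reset_global_validation_context()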
+++ /dev/null
-from typing import Any, Dict, List, Literal, Optional, Union
-
-from knot_resolver.constants import WATCHDOG_LIB
-from knot_resolver.datamodel.types import (
- DomainName,
- EscapedStr,
- IDPattern,
- IPAddress,
- ListOrItem,
- ReadableFile,
- TimeUnit,
-)
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class RuleSchema(ConfigSchema):
- """
- Local data advanced rule configuration.
-
- ---
- name: Hostname(s).
- subtree: Type of subtree.
- address: Address(es) to pair with hostname(s).
- file: Path to file(s) with hostname and IP address(es) pairs in '/etc/hosts' like format.
- records: Direct addition of records in DNS zone file format.
- tags: Tags to link with other policy rules.
- ttl: Optional, TTL value used for these answers.
- nodata: Optional, use NODATA synthesis. NODATA will be synthesized for a matching name but mismatching type (e.g. an AAAA query when only an A record exists).
- """ # noqa: E501
-
- name: Optional[ListOrItem[DomainName]] = None
- subtree: Optional[Literal["empty", "nxdomain", "redirect"]] = None
- address: Optional[ListOrItem[IPAddress]] = None
- file: Optional[ListOrItem[ReadableFile]] = None
- records: Optional[EscapedStr] = None
- tags: Optional[List[IDPattern]] = None
- ttl: Optional[TimeUnit] = None
- nodata: Optional[bool] = None
- # TODO: probably also implement the rule options from RPZSchema (.log + .dry_run)
-
- def _validate(self) -> None:
- options_sum = sum([bool(self.address), bool(self.subtree), bool(self.file), bool(self.records)])
- if options_sum == 2 and bool(self.address) and self.subtree in {"empty", "redirect"}:
- pass # these combinations still make sense
- elif options_sum > 1:
- raise ValueError("only one of 'address', 'subtree', 'file' or 'records' can be configured")
- elif options_sum < 1:
- raise ValueError("one of 'address', 'subtree', 'file' or 'records' must be configured")
-
- options_sum2 = sum([bool(self.name), bool(self.file), bool(self.records)])
- if options_sum2 != 1:
- raise ValueError("one of 'name', 'file' or 'records' must be configured")
-
- if bool(self.nodata) and bool(self.subtree) and not bool(self.address):
- raise ValueError("'nodata' defined but unused with 'subtree'")
-
-
-class RPZSchema(ConfigSchema):
- class Raw(ConfigSchema):
- """
- Configuration of a Response Policy Zone (RPZ).
-
- ---
- file: Path to the RPZ zone file.
- watchdog: Enables a file watchdog for the configured RPZ file. Requires the optional 'watchdog' dependency.
- tags: Tags to link with other policy rules.
- log: Enables logging information whenever this RPZ matches.
- """
-
- file: ReadableFile
- watchdog: Union[Literal["auto"], bool] = "auto"
- tags: Optional[List[IDPattern]] = None
- log: Optional[List[Literal["ip", "name"]]] = None
- # dry_run: bool = False
-
- _LAYER = Raw
-
- file: ReadableFile
- watchdog: bool
- tags: Optional[List[IDPattern]]
- log: Optional[List[Literal["ip", "name"]]]
- # dry_run: bool
-
- def _watchdog(self, obj: Raw) -> Any:
- if obj.watchdog == "auto":
- return WATCHDOG_LIB
- return obj.watchdog
-
- def _validate(self) -> None:
- if self.watchdog and not WATCHDOG_LIB:
- raise ValueError(
- "'watchdog' is enabled, but the required 'watchdog' dependency (optional) is not installed"
- )
-
-
-class LocalDataSchema(ConfigSchema):
- """
- Local data for forward records (A/AAAA) and reverse records (PTR).
-
- ---
- ttl: Default TTL value used for added local data/records.
- nodata: Use NODATA synthesis. NODATA will be synthesized for a matching name but mismatching type (e.g. an AAAA query when only an A record exists).
- addresses: Direct addition of hostname and IP addresses pairs.
- addresses_files: Direct addition of hostname and IP addresses pairs from files in '/etc/hosts' like format.
- records: Direct addition of records in DNS zone file format.
- rules: Local data rules.
- rpz: List of Response Policy Zones and their configuration.
- """ # noqa: E501
-
- ttl: Optional[TimeUnit] = None
- nodata: bool = True
- addresses: Optional[Dict[DomainName, ListOrItem[IPAddress]]] = None
- addresses_files: Optional[List[ReadableFile]] = None
- records: Optional[EscapedStr] = None
- rules: Optional[List[RuleSchema]] = None
- rpz: Optional[List[RPZSchema]] = None
- # root_fallback_addresses*: removed, rarely useful
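
A sketch of the combination rules enforced by RuleSchema._validate() above; dict-based construction is assumed, the name and address are placeholders, and the raised ValueError may be wrapped by the modeling layer:

from knot_resolver.datamodel.local_data_schema import RuleSchema

# Accepted: a name paired with an address.
RuleSchema({"name": "my.pc.corp", "address": "192.168.12.95"})

# Rejected: 'file' and 'records' configured together.
try:
    RuleSchema({"file": "/etc/hosts", "records": "example.net. TXT 'x'"})
except Exception as exc:
    print(exc)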
+++ /dev/null
-import os
-from typing import Any, List, Literal, Optional, Set, Type, Union, cast
-
-from knot_resolver.datamodel.types import WritableFilePath
-from knot_resolver.utils.modeling import ConfigSchema
-from knot_resolver.utils.modeling.base_schema import is_obj_type_valid
-
-LogLevelEnum = Literal["crit", "err", "warning", "notice", "info", "debug"]
-LogTargetEnum = Literal["syslog", "stderr", "stdout"]
-
-LogGroupsProcessesEnum = Literal[
- "manager",
- "supervisord",
- "policy-loader",
- "kresd",
- "cache-gc",
-]
-
-LogGroupsManagerEnum = Literal[
- "files",
- "metrics",
- "server",
-]
-
-LogGroupsKresdEnum = Literal[
- ## Now the LOG_GRP_*_TAG defines, exactly from ../../../lib/log.h
- "system",
- "cache",
- "io",
- "net",
- "ta",
- "tasent",
- "tasign",
- "taupd",
- "tls",
- "gnutls",
- "tls_cl",
- "xdp",
- "doh",
- "dnssec",
- "hint",
- "plan",
- "iterat",
- "valdtr",
- "resolv",
- "select",
- "zoncut",
- "cookie",
- "statis",
- "rebind",
- "worker",
- "policy",
- "daf",
- "timejm",
- "timesk",
- "graphi",
- "prefil",
- "primin",
- "srvstl",
- "wtchdg",
- "nsid",
- "dnstap",
- "tests",
- "dotaut",
- "http",
- "contrl",
- "module",
- "devel",
- "renum",
- "exterr",
- "rules",
- "prlayr",
- "defer",
- "doq",
- "ngtcp2",
- # "reqdbg",... (non-displayed section of the enum)
-]
-
-LogGroupsEnum = Literal[LogGroupsProcessesEnum, LogGroupsManagerEnum, LogGroupsKresdEnum]
-
-
-class DnstapSchema(ConfigSchema):
- """
- Logging DNS queries and responses to a unix socket.
-
- ---
- enable: Enable/disable DNS queries logging.
- unix_socket: Path to unix domain socket where dnstap messages will be sent.
- log_queries: Log queries from downstream in wire format.
- log_responses: Log responses to downstream in wire format.
- log_tcp_rtt: Log TCP RTT (Round-trip time).
- """
-
- enable: bool = False
- unix_socket: Optional[WritableFilePath] = None
- log_queries: bool = False
- log_responses: bool = False
- log_tcp_rtt: bool = False
-
- def _validate(self) -> None:
- if self.enable and self.unix_socket is None:
- raise ValueError("DNS queries logging enabled, but 'unix-socket' not specified")
-
-
-class LoggingSchema(ConfigSchema):
- class Raw(ConfigSchema):
- """
- Logging and debugging configuration.
-
- ---
- level: Global logging level.
- target: Global logging stream target. "from-env" uses $KRES_LOGGING_TARGET and defaults to "stdout".
- groups: List of groups for which 'debug' logging level is set.
- dnstap: Logging DNS requests and responses to a unix socket.
- """
-
- level: LogLevelEnum = "notice"
- target: Union[LogTargetEnum, Literal["from-env"]] = "from-env"
- groups: Optional[List[LogGroupsEnum]] = None
- dnstap: DnstapSchema = DnstapSchema()
-
- _LAYER = Raw
-
- level: LogLevelEnum
- target: LogTargetEnum
- groups: Optional[List[LogGroupsEnum]]
- dnstap: DnstapSchema
-
- def _target(self, raw: Raw) -> LogTargetEnum:
- if raw.target == "from-env":
- target = os.environ.get("KRES_LOGGING_TARGET") or "stdout"
- if not is_obj_type_valid(target, cast(Type[Any], LogTargetEnum)):
- raise ValueError(f"logging target '{target}' read from $KRES_LOGGING_TARGET is invalid")
- return cast(LogTargetEnum, target)
- return raw.target
-
- def _validate(self) -> None:
- if self.groups is None:
- return
-
- checked: Set[str] = set()
- for i, g in enumerate(self.groups):
- if g in checked:
- raise ValueError(f"duplicate logging group '{g}' on index {i}")
- checked.add(g)
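
A sketch of the "from-env" target resolution in _target() above; it assumes construction from a plain dict and that the transformation runs when the schema is instantiated:

import os

from knot_resolver.datamodel.logging_schema import LoggingSchema

os.environ["KRES_LOGGING_TARGET"] = "syslog"
logging_cfg = LoggingSchema({"level": "info"})  # target left at its "from-env" default
print(logging_cfg.target)  # expected: "syslog"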
+++ /dev/null
-from typing import Optional
-
-from knot_resolver.datamodel.types import ReadableFile
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class LuaSchema(ConfigSchema):
- """
- Custom Lua configuration.
-
- ---
- script_only: Ignore declarative configuration intended for workers and use only the Lua script or script file configured in this section.
- script: Custom Lua configuration script intended for workers.
- script_file: Path to a file that contains the Lua configuration script for workers.
- policy_script_only: Ignore declarative configuration intended for policy-loader and use only the Lua script or script file configured in this section.
- policy_script: Custom Lua configuration script intended for policy-loader.
- policy_script_file: Path to a file that contains the Lua configuration script for policy-loader.
- """ # noqa: E501
-
- script_only: bool = False
- script: Optional[str] = None
- script_file: Optional[ReadableFile] = None
- policy_script_only: bool = False
- policy_script: Optional[str] = None
- policy_script_file: Optional[ReadableFile] = None
-
- def _validate(self) -> None:
- if self.script and self.script_file:
- raise ValueError("'lua.script' and 'lua.script-file' are both defined, only one can be used")
- if self.policy_script and self.policy_script_file:
- raise ValueError("'lua.policy-script' and 'lua.policy-script-file' are both defined, only one can be used")
+++ /dev/null
-from typing import Optional
-
-from knot_resolver.datamodel.types import IPAddressPort, WritableFilePath
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class ManagementSchema(ConfigSchema):
- """
- Configuration of management HTTP API.
-
- ---
- unix_socket: Path to unix domain socket to listen on.
- interface: IP address and port number to listen on.
- """
-
- unix_socket: Optional[WritableFilePath] = None
- interface: Optional[IPAddressPort] = None
-
- def _validate(self) -> None:
- if bool(self.unix_socket) == bool(self.interface):
- raise ValueError("One of 'interface' or 'unix-socket' must be configured.")
+++ /dev/null
-from typing import Literal, Union
-
-from knot_resolver.datamodel.types import DomainName, EscapedStr, IPAddress, PortNumber, TimeUnit
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class GraphiteSchema(ConfigSchema):
- enable: bool = False
- host: Union[None, IPAddress, DomainName] = None
- port: PortNumber = PortNumber(2003)
- prefix: EscapedStr = EscapedStr("")
- interval: TimeUnit = TimeUnit("5s")
- tcp: bool = False
-
- def _validate(self) -> None:
- if self.enable and not self.host:
- raise ValueError("'host' option must be configured to enable graphite bridge")
-
-
-class MonitoringSchema(ConfigSchema):
- """
- ---
- metrics: Configures whether metrics/statistics are collected by the resolver.
- graphite: Optionally configures where Graphite metrics should be sent.
- """ # noqa: D205, D400, D415
-
- metrics: Literal["manager-only", "lazy", "always"] = "lazy"
- graphite: GraphiteSchema = GraphiteSchema()
+++ /dev/null
-from typing import Any, List, Literal, Optional, Union
-
-from knot_resolver.constants import WATCHDOG_LIB
-from knot_resolver.datamodel.types import (
- EscapedStr32B,
- Int0_512,
- Int0_65535,
- Int1_4096,
- InterfaceOptionalPort,
- IPAddress,
- IPAddressEM,
- IPNetwork,
- IPv4Address,
- IPv6Address,
- ListOrItem,
- PortNumber,
- ReadableFile,
- SizeUnit,
- WritableFilePath,
-)
-from knot_resolver.utils.modeling import ConfigSchema
-
-KindEnum = Literal["dns", "xdp", "dot", "doh-legacy", "doh2", "doq"]
-
-
-class EdnsBufferSizeSchema(ConfigSchema):
- """
- EDNS payload size advertised in DNS packets.
-
- ---
- upstream: Maximum EDNS upstream (towards other DNS servers) payload size.
- downstream: Maximum EDNS downstream (towards clients) payload size.
- """
-
- upstream: SizeUnit = SizeUnit("1232B")
- downstream: SizeUnit = SizeUnit("1232B")
-
-
-class AddressRenumberingSchema(ConfigSchema):
- """
- Renumbers addresses in answers to different address space.
-
- ---
- source: Source subnet.
- destination: Destination address prefix.
- """
-
- source: IPNetwork
- destination: Union[IPAddressEM, IPAddress]
-
-
-class QUICSchema(ConfigSchema):
- """
- Optional DoQ configuration.
-
- ---
- max_conns: Maximum number of active connections a single worker is allowed to accept.
- max_streams: Maximum number of concurrent streams a connection is allowed to open.
- require_retry: Require address validation for unknown source addresses.
- This adds a 1-RTT delay to connection establishment.
- """
-
- max_conns: Int1_4096 = Int1_4096(1024)
- max_streams: Int1_4096 = Int1_4096(1024)
- require_retry: bool = False
-
-
-class TLSSchema(ConfigSchema):
- class Raw(ConfigSchema):
- """
- TLS configuration, also affects DNS over TLS and DNS over HTTPS.
-
- ---
- watchdog: Enables a watchdog for changes in TLS certificate files. Requires the optional 'watchdog' dependency.
- cert_file: Path to certificate file.
- key_file: Path to certificate key file.
- sticket_secret: Secret for TLS session resumption via tickets. (RFC 5077).
- sticket_secret_file: Path to file with secret for TLS session resumption via tickets. (RFC 5077).
- padding: EDNS(0) padding of queries and answers sent over an encrypted channel.
- """
-
- watchdog: Union[Literal["auto"], bool] = "auto"
- cert_file: Optional[ReadableFile] = None
- key_file: Optional[ReadableFile] = None
- sticket_secret: Optional[EscapedStr32B] = None
- sticket_secret_file: Optional[ReadableFile] = None
- padding: Union[bool, Int0_512] = True
-
- _LAYER = Raw
-
- watchdog: bool
- cert_file: Optional[ReadableFile] = None
- key_file: Optional[ReadableFile] = None
- sticket_secret: Optional[EscapedStr32B] = None
- sticket_secret_file: Optional[ReadableFile] = None
- padding: Union[bool, Int0_512] = True
-
- def _watchdog(self, obj: Raw) -> Any:
- if obj.watchdog == "auto":
- return WATCHDOG_LIB
- return obj.watchdog
-
- def _validate(self) -> None:
- if self.sticket_secret and self.sticket_secret_file:
- raise ValueError("'sticket_secret' and 'sticket_secret_file' are both defined, only one can be used")
- if bool(self.cert_file) != bool(self.key_file):
- raise ValueError("'cert-file' and 'key-file' must be configured together")
- if self.cert_file and self.key_file and self.watchdog and not WATCHDOG_LIB:
- raise ValueError(
- "'watchdog' is enabled, but the required 'watchdog' dependency (optional) is not installed"
- )
-
-
-class ListenSchema(ConfigSchema):
- class Raw(ConfigSchema):
- """
- Configuration of listening interface.
-
- ---
- unix_socket: Path to unix domain socket to listen on.
- interface: IP address or interface name with optional port number to listen on.
- port: Port number to listen on.
- kind: Specifies DNS query transport protocol.
- freebind: Used for binding to non-local address.
- """
-
- interface: Optional[ListOrItem[InterfaceOptionalPort]] = None
- unix_socket: Optional[ListOrItem[WritableFilePath]] = None
- port: Optional[PortNumber] = None
- kind: KindEnum = "dns"
- freebind: bool = False
-
- _LAYER = Raw
-
- interface: Optional[ListOrItem[InterfaceOptionalPort]]
- unix_socket: Optional[ListOrItem[WritableFilePath]]
- port: Optional[PortNumber]
- kind: KindEnum
- freebind: bool
-
- def _interface(self, origin: Raw) -> Optional[ListOrItem[InterfaceOptionalPort]]:
- if origin.interface:
- port_set: Optional[bool] = None
- for intrfc in origin.interface: # type: ignore[attr-defined]
- if origin.port and intrfc.port:
- raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).")
- if port_set is not None and (bool(intrfc.port) != port_set):
- raise ValueError(
- "The '@<port>' syntax must be used either for all or none of the interfaces in the list."
- )
- port_set = bool(intrfc.port)
- return origin.interface
-
- def _port(self, origin: Raw) -> Optional[PortNumber]:
- if origin.port:
- return origin.port
- # default port number based on kind
- if origin.interface:
- if origin.kind in ["dot", "doq"]:
- return PortNumber(853)
- if origin.kind in ["doh-legacy", "doh2"]:
- return PortNumber(443)
- return PortNumber(53)
- return None
-
- def _validate(self) -> None:
- if bool(self.unix_socket) == bool(self.interface):
- raise ValueError("One of 'interface' or 'unix-socket' must be configured.")
- if self.port and self.unix_socket:
- raise ValueError(
- "'unix-socket' and 'port' are not compatible options."
- " Port configuration can only be used with 'interface' option."
- )
-
-
-class ProxyProtocolSchema(ConfigSchema):
- """
- PROXYv2 protocol configuration.
-
- ---
- enable: Enable/disable PROXYv2 protocol.
- allow: Allow usage of the PROXYv2 protocol headers by clients on the specified addresses.
- """
-
- enable: bool = False
- allow: Optional[List[Union[IPAddress, IPNetwork]]] = None
-
-
-class NetworkSchema(ConfigSchema):
- """
- Network connections and protocols configuration.
-
- ---
- do_ipv4: Enable/disable using IPv4 for contacting upstream nameservers.
- do_ipv6: Enable/disable using IPv6 for contacting upstream nameservers.
- out_interface_v4: IPv4 address used to perform queries. Not set by default, which lets the OS choose any address.
- out_interface_v6: IPv6 address used to perform queries. Not set by default, which lets the OS choose any address.
- tcp_pipeline: TCP pipeline limit. The number of outstanding queries that a single client connection can make in parallel.
- edns_tcp_keepalive: Allows clients to discover the connection timeout. (RFC 7828)
- edns_buffer_size: Maximum EDNS payload size advertised in DNS packets. Different values can be configured for communication downstream (towards clients) and upstream (towards other DNS servers).
- address_renumbering: Renumbers addresses in answers to different address space.
- tls: TLS configuration, also affects DNS over TLS, DNS over HTTPS and DNS over QUIC.
- quic: DNS over QUIC configuration.
- proxy_protocol: PROXYv2 protocol configuration.
- listen: List of interfaces to listen to and its configuration.
- """ # noqa: E501
-
- do_ipv4: bool = True
- do_ipv6: bool = True
- out_interface_v4: Optional[IPv4Address] = None
- out_interface_v6: Optional[IPv6Address] = None
- tcp_pipeline: Int0_65535 = Int0_65535(100)
- edns_tcp_keepalive: bool = True
- edns_buffer_size: EdnsBufferSizeSchema = EdnsBufferSizeSchema()
- address_renumbering: Optional[List[AddressRenumberingSchema]] = None
- tls: TLSSchema = TLSSchema()
- proxy_protocol: ProxyProtocolSchema = ProxyProtocolSchema()
- quic: QUICSchema = QUICSchema()
- listen: List[ListenSchema] = [
- ListenSchema({"interface": "127.0.0.1"}),
- ListenSchema({"interface": "::1", "freebind": True}),
- ]
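
A sketch of the default-port derivation in ListenSchema._port(): when no port is given, it follows the transport kind. Dict-based construction mirrors the listen defaults above:

from knot_resolver.datamodel.network_schema import ListenSchema

dns = ListenSchema({"interface": "127.0.0.1"})            # kind "dns"           -> 53
dot = ListenSchema({"interface": "::1", "kind": "dot"})   # "dot" / "doq"        -> 853
doh = ListenSchema({"interface": "::1", "kind": "doh2"})  # "doh-legacy"/"doh2"  -> 443
print(dns.port, dot.port, doh.port)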
+++ /dev/null
-from typing import Literal
-
-from knot_resolver.utils.modeling import ConfigSchema
-
-GlueCheckingEnum = Literal["normal", "strict", "permissive"]
-
-
-class OptionsSchema(ConfigSchema):
- """
- Fine-tuning global parameters of DNS resolver operation.
-
- ---
- glue_checking: Glue records strictness checking level.
- minimize: Send minimum amount of information in recursive queries to enhance privacy.
- query_loopback: Permits queries to loopback addresses.
- reorder_rrset: Controls whether resource records within an RRSet are reordered each time it is served from the cache.
- query_case_randomization: Randomize query character case.
- priming: Initialize the DNS resolver cache with priming queries (RFC 8109).
- rebinding_protection: Protection against DNS rebinding attacks.
- refuse_no_rd: Queries without the RD (recursion desired) bit set are answered with REFUSED.
- time_jump_detection: Detection of difference between local system time and expiration time bounds in DNSSEC signatures for '. NS' records.
- violators_workarounds: Workarounds for known DNS protocol violators.
- serve_stale: Allows using timed-out records in case the DNS resolver is unable to contact upstream servers.
- """ # noqa: E501
-
- glue_checking: GlueCheckingEnum = "normal"
- minimize: bool = True
- query_loopback: bool = False
- reorder_rrset: bool = True
- query_case_randomization: bool = True
- priming: bool = True
- rebinding_protection: bool = False
- refuse_no_rd: bool = True
- time_jump_detection: bool = True
- violators_workarounds: bool = False
- serve_stale: bool = False
+++ /dev/null
-from typing import Optional
-
-from knot_resolver.datamodel.types import (
- Int0_32,
- IntPositive,
- TimeUnit,
-)
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class RateLimitingSchema(ConfigSchema):
- """
- Configuration of rate limiting.
-
- ---
- enable: Enable/disable rate limiting.
- rate_limit: Maximal number of allowed queries per second from a single host.
- instant_limit: Maximal number of allowed queries at a single point in time from a single host.
- capacity: Expected maximal number of blocked networks/hosts at the same time.
- slip: Number of restricted responses out of which one is sent as truncated; the others are dropped.
- log_period: Minimal time between two log messages, or '0s' to disable.
- dry_run: Perform only classification and logging but no restrictions.
- """
-
- enable: bool = False
- rate_limit: Optional[IntPositive] = None
- instant_limit: IntPositive = IntPositive(50)
- capacity: IntPositive = IntPositive(524288)
- slip: Int0_32 = Int0_32(2)
- log_period: TimeUnit = TimeUnit("0s")
- dry_run: bool = False
-
- def _validate(self) -> None:
- if self.enable and not self.rate_limit:
- raise ValueError("'rate-limit' has to be configured to enable rate limiting")
-
- max_instant_limit = int(2**32 // 768 - 1)
- if not int(self.instant_limit) <= max_instant_limit:
- raise ValueError(f"'instant-limit' has to be in range 1..{max_instant_limit}")
- if self.rate_limit and not int(self.rate_limit) <= 1000 * int(self.instant_limit):
- raise ValueError("'rate-limit' has to be in range 1..(1000 * instant-limit)")
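
The upper bound checked in _validate() above comes from 2**32 // 768 - 1 = 5592404. A short sketch of the two checks follows; dict construction and error wrapping by the modeling layer are assumptions:

from knot_resolver.datamodel.rate_limiting_schema import RateLimitingSchema

print(2**32 // 768 - 1)  # 5592404, the maximum allowed 'instant-limit'

try:
    RateLimitingSchema({"enable": True})  # missing 'rate-limit'
except Exception as exc:
    print(exc)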
+++ /dev/null
-import os
-import sys
-
-from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
-
-
-def _get_templates_dir() -> str:
- module = sys.modules["knot_resolver.datamodel"].__file__
- if module:
- templates_dir = os.path.join(os.path.dirname(module), "templates")
- if os.path.isdir(templates_dir):
- return templates_dir
- raise NotADirectoryError(f"the templates dir '{templates_dir}' is not a directory or does not exist")
- raise OSError("package 'knot_resolver.datamodel' cannot be located or loaded")
-
-
-_TEMPLATES_DIR = _get_templates_dir()
-
-
-def _import_kresd_config_template() -> Template:
- path = os.path.join(_TEMPLATES_DIR, "kresd.lua.j2")
- with open(path, "r", encoding="UTF-8") as file:
- template = file.read()
- return template_from_str(template)
-
-
-def _import_policy_loader_config_template() -> Template:
- path = os.path.join(_TEMPLATES_DIR, "policy-loader.lua.j2")
- with open(path, "r", encoding="UTF-8") as file:
- template = file.read()
- return template_from_str(template)
-
-
-def template_from_str(template: str) -> Template:
- ldr = FileSystemLoader(_TEMPLATES_DIR)
- env = Environment(trim_blocks=True, lstrip_blocks=True, loader=ldr, undefined=StrictUndefined)
- return env.from_string(template)
-
-
-KRESD_CONFIG_TEMPLATE = _import_kresd_config_template()
-
-
-POLICY_LOADER_CONFIG_TEMPLATE = _import_policy_loader_config_template()
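
A small sketch of template_from_str(): any string is compiled against the shared environment (StrictUndefined, trim/lstrip blocks, loader rooted at the templates directory), so ad-hoc snippets render the same way as the bundled templates; the snippet and name below are placeholders:

from knot_resolver.datamodel.templates import template_from_str

tmpl = template_from_str("hostname('{{ name }}')")
print(tmpl.render(name="resolver.example"))  # -> hostname('resolver.example')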
+++ /dev/null
-cache.open({{ cfg.cache.size_max.bytes() }}, 'lmdb://{{ cfg.cache.storage }}')
-cache.min_ttl({{ cfg.cache.ttl_min.seconds() }})
-cache.max_ttl({{ cfg.cache.ttl_max.seconds() }})
-cache.ns_tout({{ cfg.cache.ns_timeout.millis() }})
-
-{% if cfg.cache.prefill %}
--- cache.prefill
-modules.load('prefill')
-prefill.config({
-{% for item in cfg.cache.prefill %}
- ['{{ item.origin.punycode() }}'] = {
- url = '{{ item.url }}',
- interval = {{ item.refresh_interval.seconds() }},
- {{ "ca_file = '" + item.ca_file|string + "'," if item.ca_file }}
- }
-{% endfor %}
-})
-{% endif %}
-
-{% if cfg.cache.prefetch.expiring %}
--- cache.prefetch.expiring
-modules.load('prefetch')
-{% endif %}
-
-{% if cfg.cache.prefetch.prediction.enable %}
--- cache.prefetch.prediction
-modules.load('predict')
-predict.config({
- window = {{ cfg.cache.prefetch.prediction.window.minutes() }},
- period = {{ cfg.cache.prefetch.prediction.period }},
-})
-{% endif %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean %}
-
-{% if cfg.defer.enable and disable_defer is not defined -%}
-assert(C.defer_init(
- '{{ cfg.rundir }}/defer',
- {{ cfg.defer.log_period.millis() }},
- {{ cfg.workers }}) == 0)
-{% else %}
-assert(C.defer_init(nil, 0, 0) == 0)
-{%- endif %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import string_table %}
-
-{% if cfg.dns64.enable %}
-
--- Enable DNS64 by loading module
-modules.load('dns64')
-
--- Configure DNS64 module
-dns64.config({
- prefix = '{{ cfg.dns64.prefix.to_std().network_address|string }}',
-{% if cfg.dns64.reverse_ttl %}
- rev_ttl = {{ cfg.dns64.reverse_ttl.seconds() }},
-{% endif %}
-{% if cfg.dns64.exclude_subnets %}
- exclude_subnets = {{ string_table(cfg.dns64.exclude_subnets) }},
-{% endif %}
-})
-
-{% else %}
-
--- Disable DNS64 by unloading module
--- modules.unload('dns64')
-
-{% endif %}
\ No newline at end of file
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean %}
-
-{% if cfg.dnssec.enable %}
-
--- dnssec.logging-bogus
-{% if cfg.dnssec.log_bogus %}
-modules.load('bogus_log')
-{% else %}
--- modules.unload('bogus_log')
-{% endif %}
-
--- dnssec.sentinel
-{% if cfg.dnssec.sentinel %}
-modules.load('ta_sentinel')
-{% else %}
-modules.unload('ta_sentinel')
-{% endif %}
-
--- dnssec.signal-query
-{% if cfg.dnssec.signal_query %}
-modules.load('ta_signal_query')
-{% else %}
-modules.unload('ta_signal_query')
-{% endif %}
-
-{% if cfg.dnssec.trust_anchors %}
--- dnssec.trust-anchors
-{% for ta in cfg.dnssec.trust_anchors %}
-trust_anchors.add('{{ ta }}')
-{% endfor %}
-{% endif %}
-
-{% if cfg.dnssec.negative_trust_anchors %}
--- dnssec.negative-trust-anchors
-trust_anchors.set_insecure({
-{% for nta in cfg.dnssec.negative_trust_anchors %}
- '{{ nta }}',
-{% endfor %}
-})
-{% endif %}
-
-{% if cfg.dnssec.trust_anchors_files %}
--- dnssec.trust-anchors-files
-{% for taf in cfg.dnssec.trust_anchors_files %}
-trust_anchors.add_file('{{ taf.file }}', {{ boolean(taf.read_only) }})
-{% endfor %}
-{% endif %}
-
-{% else %}
-
--- Disable DNSSEC
-trust_anchors.remove('.')
-
-{% endif %}
+++ /dev/null
-{% from 'macros/forward_macros.lua.j2' import policy_rule_forward_add, forward_servers %}
-
-{% if cfg.forward %}
-{% for fwd in cfg.forward %}
-{% for subtree in fwd.subtree %}
-{{ policy_rule_forward_add(subtree,fwd.options,fwd.servers) }}
-{% endfor %}
-{% endfor %}
-{% endif %}
-
-
-{% if cfg.fallback and cfg.fallback.enable %}
-modules.load('fallback')
-fallback.config({
- targets = {{ forward_servers(cfg.fallback.servers) }},
- options = {},
-})
-{% endif %}
+++ /dev/null
-{% if not cfg.lua.script_only %}
-
--- FFI library
-ffi = require('ffi')
-local C = ffi.C
-
--- Do not clear the DB with rules; we had it prepared by a different process.
-assert(C.kr_rules_init(nil, 0, false) == 0)
-
--- hostname
-hostname('{{ cfg.hostname }}')
-
-{% if cfg.nsid %}
--- nsid
-modules.load('nsid')
-nsid.name('{{ cfg.nsid }}' .. worker.id)
-{% endif %}
-
--- LOGGING section ----------------------------------
-{% include "logging.lua.j2" %}
-
--- MONITORING section -------------------------------
-{% include "monitoring.lua.j2" %}
-
--- OPTIONS section ----------------------------------
-{% include "options.lua.j2" %}
-
--- NETWORK section ----------------------------------
-{% include "network.lua.j2" %}
-
--- DNSSEC section -----------------------------------
-{% include "dnssec.lua.j2" %}
-
--- FORWARD and FALLBACK section ----------------------------------
-{% include "forward.lua.j2" %}
-
--- CACHE section ------------------------------------
-{% include "cache.lua.j2" %}
-
--- DNS64 section ------------------------------------
-{% include "dns64.lua.j2" %}
-
--- RATE-LIMITING section ------------------------------------
-{% include "rate_limiting.lua.j2" %}
-
--- DEFER section ------------------------------------
-{% include "defer.lua.j2" %}
-
-{% endif %}
-
--- LUA section --------------------------------------
--- Custom Lua code cannot be validated
-
-{% if cfg.lua.script_file %}
-{% import cfg.lua.script_file as script_file %}
-{{ script_file }}
-{% endif %}
-
-{% if cfg.lua.script %}
-{{ cfg.lua.script }}
-{% endif %}
+++ /dev/null
-{% from 'macros/local_data_macros.lua.j2' import local_data_rules, local_data_records, local_data_addresses, local_data_addresses_files %}
-{% from 'macros/common_macros.lua.j2' import boolean %}
-
-modules = { 'hints > iterate' }
-
-{# addresses #}
-{% if cfg.local_data.addresses -%}
-{{ local_data_addresses(cfg.local_data.addresses, cfg.local_data.nodata, cfg.local_data.ttl) }}
-{%- endif %}
-
-{# addresses-files #}
-{% if cfg.local_data.addresses_files -%}
-{{ local_data_addresses_files(cfg.local_data.addresses_files, cfg.local_data.nodata, cfg.local_data.ttl) }}
-{%- endif %}
-
-{# records #}
-{% if cfg.local_data.records -%}
-{{ local_data_records(cfg.local_data.records, false, cfg.local_data.nodata, cfg.local_data.ttl, none) }}
-{%- endif %}
-
-{# rules #}
-{% if cfg.local_data.rules -%}
-{{ local_data_rules(cfg.local_data.rules, cfg.local_data.nodata, cfg.local_data.ttl) }}
-{%- endif %}
-
-{# rpz #}
-{% if cfg.local_data.rpz -%}
-{% for rpz in cfg.local_data.rpz %}
-{{ local_data_records(rpz.file, true, cfg.local_data.nodata, cfg.local_data.ttl, rpz) }}
-{% endfor %}
-{%- endif %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean %}
-
--- logging.level
-{% if cfg.logging.groups and "kresd" in cfg.logging.groups %}
-log_level('debug')
-{% else %}
-log_level('{{ cfg.logging.level }}')
-{% endif %}
-
-{% if cfg.logging.target -%}
--- logging.target
-log_target('{{ cfg.logging.target }}')
-{%- endif %}
-
-{% if cfg.logging.groups %}
--- logging.groups
-log_groups({
-{% for g in cfg.logging.groups %}
-{% if g not in [
- "manager", "supervisord", "policy-loader", "kresd", "cache-gc",
- "files", "metrics", "server",
-] %}
- '{{ g }}',
-{% endif %}
-{% endfor %}
-})
-{% endif %}
-
-{% if cfg.logging.dnstap.enable -%}
--- logging.dnstap
-modules.load('dnstap')
-dnstap.config({
- socket_path = '{{ cfg.logging.dnstap.unix_socket }}',
- client = {
- log_queries = {{ boolean(cfg.logging.dnstap.log_queries) }},
- log_responses = {{ boolean(cfg.logging.dnstap.log_responses) }},
- log_tcp_rtt = {{ boolean(cfg.logging.dnstap.log_tcp_rtt) }}
- }
-})
-{%- endif %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean, quotes, qtype_table %}
-
-
-{% macro cache_clear(params) -%}
-cache.clear(
-{{- quotes(params.name) if params.name else 'nil' -}},
-{{- boolean(params.exact_name) -}},
-{{- qtype_table(params.rr_type) if params.rr_type else 'nil' -}},
-{{- params.chunk_size if not params.exact_name else 'nil' -}}
-)
-{%- endmacro %}
+++ /dev/null
-{% macro quotes(string) -%}
-'{{ string }}'
-{%- endmacro %}
-
-{% macro boolean(val, negation=false) -%}
-{%- if negation -%}
-{{ 'false' if val else 'true' }}
-{%- else -%}
-{{ 'true' if val else 'false' }}
-{%- endif -%}
-{%- endmacro %}
-
-{# Return string or table of strings #}
-{% macro string_table(table) -%}
-{%- if table is string -%}
-'{{ table|string }}'
-{%- else -%}
-{
-{%- for item in table -%}
-'{{ item|string }}',
-{%- endfor -%}
-}
-{%- endif -%}
-{%- endmacro %}
-
-{# Return str2ip or table of str2ip #}
-{% macro str2ip_table(table) -%}
-{%- if table is string -%}
-kres.str2ip('{{ table|string }}')
-{%- else -%}
-{
-{%- for item in table -%}
-kres.str2ip('{{ item|string }}'),
-{%- endfor -%}
-}
-{%- endif -%}
-{%- endmacro %}
-
-{# Return qtype or table of qtype #}
-{% macro qtype_table(table) -%}
-{%- if table is string -%}
-kres.type.{{ table|string }}
-{%- else -%}
-{
-{%- for item in table -%}
-kres.type.{{ item|string }},
-{%- endfor -%}
-}
-{%- endif -%}
-{%- endmacro %}
-
-{# Return server address or table of server addresses #}
-{% macro servers_table(servers) -%}
-{%- if servers is string -%}
-'{{ servers|string }}'
-{%- else -%}
-{
-{%- for item in servers -%}
-{%- if item.address is defined and item.address -%}
-'{{ item.address|string }}',
-{%- else -%}
-'{{ item|string }}',
-{%- endif -%}
-{%- endfor -%}
-}
-{%- endif -%}
-{%- endmacro %}
-
-{# Return server address or table of server addresses #}
-{% macro tls_servers_table(servers) -%}
-{
-{%- for item in servers -%}
-{%- if item.address is defined and item.address -%}
-{'{{ item.address|string }}',{{ tls_server_auth(item) }}},
-{%- else -%}
-'{{ item|string }}',
-{%- endif -%}
-{%- endfor -%}
-}
-{%- endmacro %}
-
-{% macro tls_server_auth(server) -%}
-{%- if server.hostname -%}
-hostname='{{ server.hostname|string }}',
-{%- endif -%}
-{%- if server.ca_file -%}
-ca_file='{{ server.ca_file|string }}',
-{%- endif -%}
-{%- if server.pin_sha256 -%}
-pin_sha256=
-{%- if server.pin_sha256 is string -%}
-'{{ server.pin_sha256|string }}',
-{%- else -%}
-{
-{%- for pin in server.pin_sha256 -%}
-'{{ pin|string }}',
-{%- endfor -%}
-}
-{%- endif -%}
-{%- endif -%}
-{%- endmacro %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean, string_table %}
-
-{% macro forward_options(options) -%}
-{dnssec={{ boolean(options.dnssec) }},auth={{ boolean(options.authoritative) }}}
-{%- endmacro %}
-
-{% macro forward_server(server) -%}
-{%- if server.address is defined and server.address-%}
-{%- for addr in server.address -%}
-{'{{ addr }}',
-{%- if server.transport == 'tls' -%}
-tls=true,
-{%- else -%}
-tls=false,
-{%- endif -%}
-{%- if server.hostname -%}
-hostname='{{ server.hostname }}',
-{%- endif -%}
-{%- if server.pin_sha256 -%}
-pin_sha256={{ string_table(server.pin_sha256) }},
-{%- endif -%}
-{%- if server.ca_file -%}
-ca_file='{{ server.ca_file }}',
-{%- endif -%}
-},
-{%- endfor -%}
-{% else %}
-{'{{ server }}'},
-{%- endif -%}
-{%- endmacro %}
-
-{% macro forward_servers(servers) -%}
-{
-{%- for server in servers -%}
-{{ forward_server(server) }}
-{%- endfor -%}
-}
-{%- endmacro %}
-
-{% macro policy_rule_forward_add(subtree,options,servers) -%}
-policy.rule_forward_add('{{ subtree }}',{{ forward_options(options) }},{{ forward_servers(servers) }})
-{%- endmacro %}
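-
-{# Illustrative rendering (hypothetical values, modulo whitespace): for subtree '.',
-   options {dnssec: true, authoritative: false} and a single plain server '192.0.2.1' this emits
-   policy.rule_forward_add('.',{dnssec=true,auth=false},{{'192.0.2.1'},}) #}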
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import string_table, boolean %}
-{% from 'macros/policy_macros.lua.j2' import policy_get_tagset, policy_todname %}
-
-
-{%- macro local_data_ttl(ttl) -%}
-{%- if ttl -%}
-{{ ttl.seconds() }}
-{%- else -%}
-{{ 'C.KR_RULE_TTL_DEFAULT' }}
-{%- endif -%}
-{%- endmacro -%}
-
-
-{% macro kr_rule_local_address(name, address, nodata, ttl, tags=none) -%}
-assert(C.kr_rule_local_address('{{ name }}', '{{ address }}',
- {{ boolean(nodata) }}, {{ local_data_ttl(ttl)}}, {{ policy_get_tagset(tags) }},
- C.KR_RULE_OPTS_DEFAULT) == 0)
-{%- endmacro -%}
-
-
-{% macro local_data_addresses(pairs, nodata, ttl) -%}
-{% for name, addresses in pairs.items() %}
-{% for address in addresses %}
-{{ kr_rule_local_address(name, address, nodata, ttl) }}
-{% endfor %}
-{% endfor%}
-{%- endmacro %}
-
-
-{% macro kr_rule_local_hosts(file, nodata, ttl, tags=none) -%}
-assert(C.kr_rule_local_hosts('{{ file }}', {{ boolean(nodata) }},
- {{ local_data_ttl(ttl)}}, {{ policy_get_tagset(tags) }}, C.KR_RULE_OPTS_DEFAULT) == 0)
-{%- endmacro %}
-
-
-{% macro local_data_addresses_files(files, nodata, ttl, tags) -%}
-{% for file in files %}
-{{ kr_rule_local_hosts(file, nodata, ttl, tags) }}
-{% endfor %}
-{%- endmacro %}
-
-
-{% macro local_data_records(input_str, is_rpz, nodata, ttl, extra, id='rrs') -%}
-{{ id }} = ffi.new('struct kr_rule_zonefile_config')
-{{ id }}.ttl = {{ local_data_ttl(ttl) }}
-{{ id }}.tags = {{ policy_get_tagset(extra.tags) }}
-{{ id }}.nodata = {{ boolean(nodata) }}
-{{ id }}.is_rpz = {{ boolean(is_rpz) }}
-{% if is_rpz -%}
-{{ id }}.filename = '{{ input_str }}'
-{% else %}
-{{ id }}.input_str = [[
-{{ input_str.multiline() }}
-]]
-{% endif %}
-{# .opts are complicated #}
-{{ id }}.opts = C.KR_RULE_OPTS_DEFAULT
-{% if extra is not none -%}
-{% if false and extra.dry_run is not none and extra.dry_run -%}
-{{ id }}.opts.score = 4
-{% else %}
-{{ id }}.opts.score = 9
-{% endif %}
-{% if 'log' in extra and extra.log is not none -%}
-{{ id }}.opts.log_level = 3 -- notice
-{% if 'ip' in extra.log -%}
-{{ id }}.opts.log_ip = true
-{% endif %}
-{% if 'name' in extra.log -%}
-{{ id }}.opts.log_name = true
-{% endif %}
-{% endif %}
-{% endif %}
-assert(C.kr_rule_zonefile({{ id }})==0)
-{%- endmacro %}
-
-
-{% macro kr_rule_local_subtree(name, type, ttl, tags=none) -%}
-assert(C.kr_rule_local_subtree(todname('{{ name }}'),
- C.KR_RULE_SUB_{{ type.upper() }}, {{ local_data_ttl(ttl) }}, {{ policy_get_tagset(tags) }},
- C.KR_RULE_OPTS_DEFAULT) == 0)
-{%- endmacro %}
-
-
-{% macro local_data_rules(items, nodata, ttl) -%}
-{% for item in items %}
-{% if item.name %}
-{% for name in item.name %}
-{% if item.address %}
-{% for address in item.address %}
-{{ kr_rule_local_address(name, address, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }}
-{% endfor %}
-{% endif %}
-{% if item.subtree %}
-{{ kr_rule_local_subtree(name, item.subtree, item.ttl or ttl, item.tags) }}
-{% endif %}
-{% endfor %}
-{% elif item.file %}
-{% for file in item.file %}
-{{ kr_rule_local_hosts(file, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item.tags) }}
-{% endfor %}
-{% elif item.records %}
-{{ local_data_records(item.records, false, nodata if item.nodata is none else item.nodata, item.ttl or ttl, item) }}
-{% endif %}
-{% endfor %}
-{%- endmacro %}
+++ /dev/null
-{% macro http_config(http_cfg, kind, tls=true) -%}
-http.config({tls={{ 'true' if tls else 'false'}},
-{%- if http_cfg.cert_file -%}
- cert='{{ http_cfg.cert_file }}',
-{%- endif -%}
-{%- if http_cfg.key_file -%}
- key='{{ http_cfg.key_file }}',
-{%- endif -%}
-},'{{ kind }}')
-{%- endmacro %}
-
-
-{% macro listen_kind(kind) -%}
-{%- if kind == "dot" -%}
-'tls'
-{%- elif kind == "doh-legacy" -%}
-'doh_legacy'
-{%- else -%}
-'{{ kind }}'
-{%- endif -%}
-{%- endmacro %}
-
-
-{% macro net_listen_unix_socket(path, kind, freebind) -%}
-net.listen('{{ path }}',nil,{kind={{ listen_kind(kind) }},freebind={{ 'true' if freebind else 'false'}}})
-{%- endmacro %}
-
-
-{% macro net_listen_interface(interface, kind, freebind, port) -%}
-net.listen(
-{%- if interface.addr -%}
-'{{ interface.addr }}',
-{%- elif interface.if_name -%}
-net['{{ interface.if_name }}'],
-{%- endif -%}
-{%- if interface.port -%}
-{{ interface.port }},
-{%- else -%}
-{{ port }},
-{%- endif -%}
-{kind={{ listen_kind(kind) }},freebind={{ 'true' if freebind else 'false'}}})
-{%- endmacro %}
-
-
-{% macro network_listen(listen) -%}
-{%- if listen.unix_socket -%}
-{% for path in listen.unix_socket %}
-{{ net_listen_unix_socket(path, listen.kind, listen.freebind) }}
-{% endfor %}
-{%- elif listen.interface -%}
-{% for interface in listen.interface %}
-{{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }}
-{% endfor %}
-{%- endif -%}
-{%- endmacro %}
\ No newline at end of file
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import string_table, str2ip_table, qtype_table, servers_table, tls_servers_table %}
-
-
-{# Add policy #}
-
-{% macro policy_add(rule, postrule=false) -%}
-{%- if postrule -%}
-policy.add({{ rule }},true)
-{%- else -%}
-policy.add({{ rule }})
-{%- endif -%}
-{%- endmacro %}
-
-
-{# Slice #}
-
-{% macro policy_slice_randomize_psl(seed='') -%}
-{%- if seed == '' -%}
-policy.slice_randomize_psl()
-{%- else -%}
-policy.slice_randomize_psl(seed={{ seed }})
-{%- endif -%}
-{%- endmacro %}
-
-{% macro policy_slice(func, actions) -%}
-policy.slice(
-{%- if func == 'randomize-psl' -%}
-policy.slice_randomize_psl()
-{%- else -%}
-policy.slice_randomize_psl()
-{%- endif -%}
-,{{ actions }})
-{%- endmacro %}
-
-
-{# Flags #}
-
-{% macro policy_flags(flags) -%}
-policy.FLAGS({
-{{- flags -}}
-})
-{%- endmacro %}
-
-
-{# Price factor #}
-
-{% macro policy_price_factor(factor) -%}
-policy.PRICE_FACTOR16({{ (factor|float * 2**16)|round|int }})
-{%- endmacro %}
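-
-{# Worked example (illustrative): policy_price_factor(0.5) renders policy.PRICE_FACTOR16(32768),
-   i.e. 0.5 * 2^16 rounded to an integer. #}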
-
-
-{# Tags assign #}
-
-{% macro policy_tags_assign(tags) -%}
-policy.TAGS_ASSIGN({{ string_table(tags) }})
-{%- endmacro %}
-
-{% macro policy_get_tagset(tags) -%}
-{%- if tags is defined and tags-%}
-policy.get_tagset({{ string_table(tags) }})
-{%- else -%}
-0
-{%- endif -%}
-{%- endmacro %}
-
-
-{# Filters #}
-
-{% macro policy_all(action) -%}
-policy.all({{ action }})
-{%- endmacro %}
-
-{% macro policy_suffix(action, suffix_table) -%}
-policy.suffix({{ action }},{{ suffix_table }})
-{%- endmacro %}
-
-{% macro policy_suffix_common(action, suffix_table, common_suffix=none) -%}
-policy.suffix_common({{ action }},{{ suffix_table }}
-{%- if common_suffix -%}
-,{{ common_suffix }}
-{%- endif -%}
-)
-{%- endmacro %}
-
-{% macro policy_pattern(action, pattern) -%}
-policy.pattern({{ action }},'{{ pattern }}')
-{%- endmacro %}
-
-{% macro policy_rpz(action, path, watch=true) -%}
-policy.rpz({{ action|string }},'{{ path|string }}',{{ 'true' if watch else 'false' }})
-{%- endmacro %}
-
-
-{# Custom filters #}
-
-{% macro declare_policy_qtype_custom_filter() -%}
-function policy_qtype(action, qtype)
-
- local function has_value (tab, val)
- for index, value in ipairs(tab) do
- if value == val then
- return true
- end
- end
-
- return false
- end
-
- return function (state, query)
- if query.stype == qtype then
- return action
- elseif has_value(qtype, query.stype) then
- return action
- else
- return nil
- end
- end
-end
-{%- endmacro %}
-
-{% macro policy_qtype_custom_filter(action, qtype) -%}
-policy_qtype({{ action }}, {{ qtype }})
-{%- endmacro %}
-
-
-{# Auto Filter #}
-
-{% macro policy_auto_filter(action, filter=none) -%}
-{%- if filter.suffix -%}
-{{ policy_suffix(action, policy_todname(filter.suffix)) }}
-{%- elif filter.pattern -%}
-{{ policy_pattern(action, filter.pattern) }}
-{%- elif filter.qtype -%}
-{{ policy_qtype_custom_filter(action, qtype_table(filter.qtype)) }}
-{%- else -%}
-{{ policy_all(action) }}
-{%- endif %}
-{%- endmacro %}
-
-
-{# Non-chain actions #}
-
-{% macro policy_pass() -%}
-policy.PASS
-{%- endmacro %}
-
-{% macro policy_deny() -%}
-policy.DENY
-{%- endmacro %}
-
-{% macro policy_deny_msg(message) -%}
-policy.DENY_MSG('{{ message|string }}')
-{%- endmacro %}
-
-{% macro policy_drop() -%}
-policy.DROP
-{%- endmacro %}
-
-{% macro policy_refuse() -%}
-policy.REFUSE
-{%- endmacro %}
-
-{% macro policy_tc() -%}
-policy.TC
-{%- endmacro %}
-
-{% macro policy_reroute(reroute) -%}
-policy.REROUTE(
-{%- for item in reroute -%}
-{['{{ item.source }}']='{{ item.destination }}'},
-{%- endfor -%}
-)
-{%- endmacro %}
-
-{% macro policy_answer(answer) -%}
-policy.ANSWER({[kres.type.{{ answer.rtype }}]={rdata=
-{%- if answer.rtype in ['A','AAAA'] -%}
-{{ str2ip_table(answer.rdata) }},
-{%- elif answer.rtype == '' -%}
-{# TODO: Do the same for other record types that require a special rdata type in Lua.
-By default, the raw string from config is used. #}
-{%- else -%}
-{{ string_table(answer.rdata) }},
-{%- endif -%}
-ttl={{ answer.ttl.seconds()|int }}}},{{ 'true' if answer.nodata else 'false' }})
-{%- endmacro %}
-
-{# policy.ANSWER( { [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 }}) #}
-
-{# Chain actions #}
-
-{% macro policy_mirror(mirror) -%}
-policy.MIRROR(
-{% if mirror is string %}
-'{{ mirror }}'
-{% else %}
-{
-{%- for addr in mirror -%}
-'{{ addr }}',
-{%- endfor -%}
-}
-{%- endif -%}
-)
-{%- endmacro %}
-
-{% macro policy_debug_always() -%}
-policy.DEBUG_ALWAYS
-{%- endmacro %}
-
-{% macro policy_debug_cache_miss() -%}
-policy.DEBUG_CACHE_MISS
-{%- endmacro %}
-
-{% macro policy_qtrace() -%}
-policy.QTRACE
-{%- endmacro %}
-
-{% macro policy_reqtrace() -%}
-policy.REQTRACE
-{%- endmacro %}
-
-{% macro policy_stub(servers) -%}
-policy.STUB({{ servers_table(servers) }})
-{%- endmacro %}
-
-{% macro policy_forward(servers) -%}
-policy.FORWARD({{ servers_table(servers) }})
-{%- endmacro %}
-
-{% macro policy_tls_forward(servers) -%}
-policy.TLS_FORWARD({{ tls_servers_table(servers) }})
-{%- endmacro %}
-
-
-{# Auto action #}
-
-{% macro policy_auto_action(rule) -%}
-{%- if rule.action == 'pass' -%}
-{{ policy_pass() }}
-{%- elif rule.action == 'deny' -%}
-{%- if rule.message -%}
-{{ policy_deny_msg(rule.message) }}
-{%- else -%}
-{{ policy_deny() }}
-{%- endif -%}
-{%- elif rule.action == 'drop' -%}
-{{ policy_drop() }}
-{%- elif rule.action == 'refuse' -%}
-{{ policy_refuse() }}
-{%- elif rule.action == 'tc' -%}
-{{ policy_tc() }}
-{%- elif rule.action == 'reroute' -%}
-{{ policy_reroute(rule.reroute) }}
-{%- elif rule.action == 'answer' -%}
-{{ policy_answer(rule.answer) }}
-{%- elif rule.action == 'mirror' -%}
-{{ policy_mirror(rule.mirror) }}
-{%- elif rule.action == 'debug-always' -%}
-{{ policy_debug_always() }}
-{%- elif rule.action == 'debug-cache-miss' -%}
-{{ policy_debug_cache_miss() }}
-{%- elif rule.action == 'qtrace' -%}
-{{ policy_qtrace() }}
-{%- elif rule.action == 'reqtrace' -%}
-{{ policy_reqtrace() }}
-{%- endif -%}
-{%- endmacro %}
-
-
-{# Other #}
-
-{% macro policy_todname(name) -%}
-todname('{{ name.punycode()|string }}')
-{%- endmacro %}
-
-{% macro policy_todnames(names) -%}
-policy.todnames({
-{%- if names is string -%}
-'{{ names.punycode()|string }}'
-{%- else -%}
-{%- for name in names -%}
-'{{ name.punycode()|string }}',
-{%- endfor -%}
-{%- endif -%}
-})
-{%- endmacro %}
+++ /dev/null
-{%- macro get_proto_set(protocols) -%}
-0
-{%- for p in protocols or [] -%}
- + 2^C.KR_PROTO_{{ p.upper() }}
-{%- endfor -%}
-{%- endmacro -%}
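-
-{# Illustrative rendering: for protocols ['dot', 'doh'] this emits the Lua expression
-   0 + 2^C.KR_PROTO_DOT + 2^C.KR_PROTO_DOH, i.e. a bitmask of the allowed transports. #}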
-
-{% macro view_flags(options) -%}
-{% if not options.minimize -%}
-"NO_MINIMIZE",
-{%- endif %}
-{% if not options.dns64 -%}
-"DNS64_DISABLE",
-{%- endif %}
-{% if not options.fallback -%}
-"FALLBACK_DISABLE",
-{%- endif %}
-{%- endmacro %}
-
-{% macro view_answer(answer) -%}
-{%- if answer == 'allow' -%}
-policy.TAGS_ASSIGN({})
-{%- elif answer == 'refused' -%}
-'policy.REFUSE'
-{%- elif answer == 'noanswer' -%}
-'policy.NO_ANSWER'
-{%- endif -%}
-{%- endmacro %}
+++ /dev/null
---- control socket location
-local ffi = require('ffi')
-local id = os.getenv('SYSTEMD_INSTANCE')
-if not id then
- log_error(ffi.C.LOG_GRP_SYSTEM, 'environment variable $SYSTEMD_INSTANCE not set, which should not have been possible due to running under manager')
-else
- -- Bind to control socket in CWD (= rundir in config)
- -- FIXME replace with relative path after fixing https://gitlab.nic.cz/knot/knot-resolver/-/issues/720
- local path = '{{ cwd }}/control/'..id
- log_warn(ffi.C.LOG_GRP_SYSTEM, 'path = ' .. path)
- local ok, err = pcall(net.listen, path, nil, { kind = 'control' })
- if not ok then
- log_warn(ffi.C.LOG_GRP_NETWORK, 'bind to '..path..' failed '..err)
- end
-end
-
-{% if cfg.monitoring.metrics == "always" %}
-modules.load('stats')
-{% endif %}
-
---- function used for statistics collection
-function collect_lazy_statistics()
- if stats == nil then
- modules.load('stats')
- end
-
- return stats.list()
-end
-
---- function used for statistics collection
-function collect_statistics()
- return stats.list()
-end
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean %}
-{% from 'macros/network_macros.lua.j2' import network_listen, http_config %}
-
--- network.do-ipv4/6
-net.ipv4 = {{ boolean(cfg.network.do_ipv4) }}
-net.ipv6 = {{ boolean(cfg.network.do_ipv6) }}
-
-{% if cfg.network.out_interface_v4 %}
--- network.out-interface-v4
-net.outgoing_v4('{{ cfg.network.out_interface_v4 }}')
-{% endif %}
-
-{% if cfg.network.out_interface_v6 %}
--- network.out-interface-v6
-net.outgoing_v6('{{ cfg.network.out_interface_v6 }}')
-{% endif %}
-
--- network.tcp-pipeline
-net.tcp_pipeline({{ cfg.network.tcp_pipeline }})
-
--- network.edns-keep-alive
-{% if cfg.network.edns_tcp_keepalive %}
-modules.load('edns_keepalive')
-{% else %}
-modules.unload('edns_keepalive')
-{% endif %}
-
--- network.edns-buffer-size
-net.bufsize(
- {{ cfg.network.edns_buffer_size.downstream.bytes() }},
- {{ cfg.network.edns_buffer_size.upstream.bytes() }}
-)
-
-{% if cfg.network.tls.cert_file and cfg.network.tls.key_file %}
--- network.tls
-net.tls('{{ cfg.network.tls.cert_file }}', '{{ cfg.network.tls.key_file }}')
-{% endif %}
-
-{% if cfg.network.tls.sticket_secret %}
--- network.tls.sticket-secret
-net.tls_sticket_secret('{{ cfg.network.tls.sticket_secret }}')
-{% endif %}
-
-{% if cfg.network.tls.sticket_secret_file %}
--- network.tls.sticket-secret-file
-net.tls_sticket_secret_file('{{ cfg.network.tls.sticket_secret_file }}')
-{% endif %}
-
--- network.tls.padding
-net.tls_padding(
-{%- if cfg.network.tls.padding == true -%}
-true
-{%- elif cfg.network.tls.padding == false -%}
-false
-{%- else -%}
-{{ cfg.network.tls.padding }}
-{%- endif -%}
-)
-
--- network.quic.max_conns
-net.quic_max_conns({{ cfg.network.quic.max_conns }})
-
--- network.quic.max_streams
-net.quic_max_streams({{ cfg.network.quic.max_streams }})
-
--- network.quic.require_retry
-net.quic_require_retry({{ boolean(cfg.network.quic.require_retry) }})
-
-{% if cfg.network.address_renumbering %}
--- network.address-renumbering
-modules.load('renumber')
-renumber.config({
-{% for item in cfg.network.address_renumbering %}
- {'{{ item.source }}', '{{ item.destination }}'},
-{% endfor %}
-})
-{% endif %}
-
-{%- set vars = {'doh_legacy': False} -%}
-{% for listen in cfg.network.listen if listen.kind == "doh-legacy" -%}
-{%- if vars.update({'doh_legacy': True}) -%}{%- endif -%}
-{%- endfor %}
-
-{% if vars.doh_legacy %}
--- doh_legacy http config
-modules.load('http')
-{{ http_config(cfg.network.tls,"doh_legacy") }}
-{% endif %}
-
-{% if cfg.network.proxy_protocol.enable %}
--- network.proxy-protocol
-net.proxy_allowed({
-{% for item in cfg.network.proxy_protocol.allow %}
-'{{ item }}',
-{% endfor %}
-})
-{% else %}
-net.proxy_allowed({})
-{% endif %}
-
--- network.listen
-{% for listen in cfg.network.listen %}
-{{ network_listen(listen) }}
-{% endfor %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean %}
-
--- options.glue-checking
-mode('{{ cfg.options.glue_checking }}')
-
-{% if cfg.options.rebinding_protection %}
--- options.rebinding-protection
-modules.load('rebinding < iterate')
-{% endif %}
-
-{% if cfg.options.violators_workarounds %}
--- options.violators-workarounds
-modules.load('workarounds < iterate')
-{% endif %}
-
-{% if cfg.options.serve_stale %}
--- options.serve-stale
-modules.load('serve_stale < cache')
-{% endif %}
-
--- options.query-priming
-{% if cfg.options.priming %}
-modules.load('priming')
-{% else %}
-modules.unload('priming')
-{% endif %}
-
--- options.time-jump-detection
-{% if cfg.options.time_jump_detection %}
-modules.load('detect_time_jump')
-{% else %}
-modules.unload('detect_time_jump')
-{% endif %}
-
--- options.refuse-no-rd
-{% if cfg.options.refuse_no_rd %}
-modules.load('refuse_nord')
-{% else %}
-modules.unload('refuse_nord')
-{% endif %}
-
--- options.qname-minimisation
-option('NO_MINIMIZE', {{ boolean(cfg.options.minimize,true) }})
-
--- options.query-loopback
-option('ALLOW_LOCAL', {{ boolean(cfg.options.query_loopback) }})
-
--- options.reorder-rrset
-option('REORDER_RR', {{ boolean(cfg.options.reorder_rrset) }})
-
--- options.query-case-randomization
-option('NO_0X20', {{ boolean(cfg.options.query_case_randomization,true) }})
\ No newline at end of file
+++ /dev/null
-{% if not cfg.lua.policy_script_only %}
-
--- FFI library
-ffi = require('ffi')
-local C = ffi.C
-
--- logging.level
-{% if cfg.logging.groups and "policy-loader" in cfg.logging.groups %}
-log_level('debug')
-{% else %}
-log_level('{{ cfg.logging.level }}')
-{% endif %}
-
-{% if cfg.logging.target -%}
--- logging.target
-log_target('{{ cfg.logging.target }}')
-{%- endif %}
-
-{% if cfg.logging.groups %}
--- logging.groups
-log_groups({
-{% for g in cfg.logging.groups %}
-{% if g not in [
- "manager", "supervisord", "policy-loader", "kresd", "cache-gc",
- "files", "metrics", "server",
-] %}
- '{{ g }}',
-{% endif %}
-{% endfor %}
-})
-{% endif %}
-
--- Config required for the cache opening
-cache.open({{ cfg.cache.size_max.bytes() }}, 'lmdb://{{ cfg.cache.storage }}')
-
--- VIEWS section ------------------------------------
-{% include "views.lua.j2" %}
-
--- LOCAL-DATA section -------------------------------
-{% include "local_data.lua.j2" %}
-
--- FORWARD section ----------------------------------
-{% include "forward.lua.j2" %}
-
--- DEFER section ------------------------------------
--- Force-disable defer to avoid the default defer config.
-{% set disable_defer = true %}
-{% include "defer.lua.j2" %}
-
-{% endif %}
-
--- LUA section --------------------------------------
--- Custom Lua code cannot be validated
-
-{% if cfg.lua.policy_script_file %}
-{% import cfg.lua.policy_script_file as policy_script_file %}
-{{ policy_script_file }}
-{% endif %}
-
-{% if cfg.lua.policy_script %}
-{{ cfg.lua.policy_script }}
-{% endif %}
-
-
--- exit policy-loader properly
-quit()
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import boolean %}
-
-{% if cfg.rate_limiting.enable %}
-assert(
- C.ratelimiting_init(
- '{{ cfg.rundir }}/ratelimiting',
- {{ cfg.rate_limiting.capacity }},
- {{ cfg.rate_limiting.instant_limit }},
- {{ cfg.rate_limiting.rate_limit }},
- {{ cfg.rate_limiting.slip }},
- {{ cfg.rate_limiting.log_period.millis() }},
- {{ boolean(cfg.rate_limiting.dry_run) }}
- ) == 0
-)
-{% endif %}
+++ /dev/null
-{% from 'macros/common_macros.lua.j2' import quotes %}
-{% from 'macros/view_macros.lua.j2' import get_proto_set, view_flags, view_answer %}
-{% from 'macros/policy_macros.lua.j2' import policy_flags, policy_tags_assign, policy_price_factor %}
-
-{% if cfg.views %}
-{% for view in cfg.views %}
-{% for subnet in view.subnets %}
-
-assert(C.kr_view_insert_action('{{ subnet }}', '{{ view.dst_subnet or '' }}',
- {{ get_proto_set(view.protocols) }}, policy.COMBINE({
-{%- set flags = view_flags(view.options) -%}
-{% if flags %}
- {{ quotes(policy_flags(flags)) }},
-{%- endif %}
-{% if view.options.price_factor|float != 1.0 %}
- {{ quotes(policy_price_factor(view.options.price_factor)) }},
-{%- endif %}
-{% if view.tags %}
- {{ policy_tags_assign(view.tags) }},
-{% elif view.answer %}
- {{ view_answer(view.answer) }},
-{%- endif %}
- })) == 0)
-
-{% endfor %}
-{% endfor %}
-{% endif %}
+++ /dev/null
-from .enums import DNSRecordTypeEnum, PolicyActionEnum, PolicyFlagEnum
-from .files import AbsoluteDir, Dir, File, FilePath, ReadableFile, WritableDir, WritableFilePath
-from .generic_types import ListOrItem
-from .types import (
- DomainName,
- EscapedStr,
- EscapedStr32B,
- FloatNonNegative,
- IDPattern,
- Int0_32,
- Int0_512,
- Int0_65535,
- Int1_4096,
- InterfaceName,
- InterfaceOptionalPort,
- InterfacePort,
- IntNonNegative,
- IntPositive,
- IPAddress,
- IPAddressEM,
- IPAddressOptionalPort,
- IPAddressPort,
- IPNetwork,
- IPv4Address,
- IPv6Address,
- IPv6Network,
- IPv6Network96,
- Percent,
- PinSha256,
- PortNumber,
- SizeUnit,
- TimeUnit,
-)
-
-__all__ = [
- "PolicyActionEnum",
- "PolicyFlagEnum",
- "DNSRecordTypeEnum",
- "DomainName",
- "EscapedStr",
- "EscapedStr32B",
- "FloatNonNegative",
- "IDPattern",
- "Int0_32",
- "Int0_512",
- "Int1_4096",
- "Int0_65535",
- "InterfaceName",
- "InterfaceOptionalPort",
- "InterfacePort",
- "IntNonNegative",
- "IntPositive",
- "IPAddress",
- "IPAddressEM",
- "IPAddressOptionalPort",
- "IPAddressPort",
- "IPNetwork",
- "IPv4Address",
- "IPv6Address",
- "IPv6Network",
- "IPv6Network96",
- "ListOrItem",
- "Percent",
- "PinSha256",
- "PortNumber",
- "SizeUnit",
- "TimeUnit",
- "AbsoluteDir",
- "ReadableFile",
- "WritableDir",
- "WritableFilePath",
- "File",
- "FilePath",
- "Dir",
-]
+++ /dev/null
-# ruff: noqa: SLF001
-
-import re
-from typing import Any, Dict, Type, Union
-
-from knot_resolver.utils.compat.typing import Pattern
-from knot_resolver.utils.modeling import BaseValueType
-
-
-class IntBase(BaseValueType):
- """Base class to work with integer value."""
-
- _orig_value: int
- _value: int
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, int) and not isinstance(source_value, bool):
- self._orig_value = source_value
- self._value = source_value
- else:
- raise ValueError(
- f"Unexpected value for '{type(self)}'."
- f" Expected integer, got '{source_value}' with type '{type(source_value)}'",
- object_path,
- )
-
- def __int__(self) -> int:
- return self._value
-
- def __str__(self) -> str:
- return str(self._value)
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, IntBase) and o._value == self._value
-
- def serialize(self) -> Any:
- return self._orig_value
-
- @classmethod
- def json_schema(cls: Type["IntBase"]) -> Dict[Any, Any]:
- return {"type": "integer"}
-
-
-class FloatBase(BaseValueType):
- """Base class to work with float value."""
-
- _orig_value: Union[float, int]
- _value: float
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, (float, int)) and not isinstance(source_value, bool):
- self._orig_value = source_value
- self._value = float(source_value)
- else:
- raise ValueError(
- f"Unexpected value for '{type(self)}'."
- f" Expected float, got '{source_value}' with type '{type(source_value)}'",
- object_path,
- )
-
- def __int__(self) -> int:
- return int(self._value)
-
- def __float__(self) -> float:
- return self._value
-
- def __str__(self) -> str:
- return str(self._value)
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, FloatBase) and o._value == self._value
-
- def serialize(self) -> Any:
- return self._orig_value
-
- @classmethod
- def json_schema(cls: Type["FloatBase"]) -> Dict[Any, Any]:
- return {"type": "number"}
-
-
-class StrBase(BaseValueType):
- """Base class to work with string value."""
-
- _orig_value: str
- _value: str
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, (str, int)) and not isinstance(source_value, bool):
- self._orig_value = str(source_value)
- self._value = str(source_value)
- else:
- raise ValueError(
- f"Unexpected value for '{type(self)}'."
- f" Expected string, got '{source_value}' with type '{type(source_value)}'",
- object_path,
- )
-
- def __int__(self) -> int:
- raise ValueError("Can't convert string to an integer.")
-
- def __str__(self) -> str:
- return self._value
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __hash__(self) -> int:
- return hash(self._value)
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, StrBase) and o._value == self._value
-
- def serialize(self) -> Any:
- return self._orig_value
-
- @classmethod
- def json_schema(cls: Type["StrBase"]) -> Dict[Any, Any]:
- return {"type": "string"}
-
-
-class StringLengthBase(StrBase):
- """
-    Base class to work with a string value whose byte length is restricted.
-
- Just inherit the class and set the values for '_min_bytes' and '_max_bytes'.
-
- class String32B(StringLengthBase):
- _min_bytes: int = 32
- """
-
- _min_bytes: int = 1
- _max_bytes: int
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
- value_bytes = len(self._value.encode("utf-8"))
- if hasattr(self, "_min_bytes") and (value_bytes < self._min_bytes):
- raise ValueError(
- f"the string value {source_value} is shorter than the minimum {self._min_bytes} bytes.", object_path
- )
- if hasattr(self, "_max_bytes") and (value_bytes > self._max_bytes):
- raise ValueError(
- f"the string value {source_value} is longer than the maximum {self._max_bytes} bytes.", object_path
- )
-
- @classmethod
- def json_schema(cls: Type["StringLengthBase"]) -> Dict[Any, Any]:
- typ: Dict[str, Any] = {"type": "string"}
- if hasattr(cls, "_min_bytes"):
- typ["minLength"] = cls._min_bytes
- if hasattr(cls, "_max_bytes"):
- typ["maxLength"] = cls._max_bytes
- return typ
-
-
-class IntRangeBase(IntBase):
- """
-    Base class to work with an integer value within a range.
-
- Just inherit the class and set the values for '_min' and '_max'.
-
- class IntNonNegative(IntRangeBase):
- _min: int = 0
- """
-
- _min: int
- _max: int
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
- if hasattr(self, "_min") and (self._value < self._min):
- raise ValueError(f"value {self._value} is lower than the minimum {self._min}.", object_path)
- if hasattr(self, "_max") and (self._value > self._max):
- raise ValueError(f"value {self._value} is higher than the maximum {self._max}", object_path)
-
- @classmethod
- def json_schema(cls: Type["IntRangeBase"]) -> Dict[Any, Any]:
- typ: Dict[str, Any] = {"type": "integer"}
- if hasattr(cls, "_min"):
- typ["minimum"] = cls._min
- if hasattr(cls, "_max"):
- typ["maximum"] = cls._max
- return typ
-
-
-class FloatRangeBase(FloatBase):
- """
-    Base class to work with a float value within a range.
-
- Just inherit the class and set the values for '_min' and '_max'.
-
-    class FloatNonNegative(FloatRangeBase):
- _min: float = 0.0
- """
-
- _min: float
- _max: float
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
- if hasattr(self, "_min") and (self._value < self._min):
- raise ValueError(f"value {self._value} is lower than the minimum {self._min}.", object_path)
- if hasattr(self, "_max") and (self._value > self._max):
- raise ValueError(f"value {self._value} is higher than the maximum {self._max}", object_path)
-
- @classmethod
- def json_schema(cls: Type["FloatRangeBase"]) -> Dict[Any, Any]:
- typ: Dict[str, Any] = {"type": "number"}
- if hasattr(cls, "_min"):
- typ["minimum"] = cls._min
- if hasattr(cls, "_max"):
- typ["maximum"] = cls._max
- return typ
-
-
-class PatternBase(StrBase):
- """
-    Base class to work with a string value that matches a regex pattern.
-
- Just inherit the class and set regex pattern for '_re'.
-
- class ABPattern(PatternBase):
- _re: Pattern[str] = re.compile(r"ab*")
- """
-
- _re: Pattern[str]
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
- if not type(self)._re.match(self._value):
- raise ValueError(f"'{self._value}' does not match '{self._re.pattern}' pattern", object_path)
-
- @classmethod
- def json_schema(cls: Type["PatternBase"]) -> Dict[Any, Any]:
- return {"type": "string", "pattern": rf"{cls._re.pattern}"}
-
-
-class UnitBase(StrBase):
- """
-    Base class to work with a string value that carries a unit suffix.
-
- Just inherit the class and set '_units'.
-
-    class CustomUnit(UnitBase):
- _units = {"b": 1, "kb": 1000}
- """
-
- _re: Pattern[str]
- _units: Dict[str, int]
- _base_value: int
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
-
- type(self)._re = re.compile(rf"^(\d+)({r'|'.join(type(self)._units.keys())})$")
- grouped = self._re.search(self._value)
- if grouped:
- val, unit = grouped.groups()
- if unit is None:
- raise ValueError(f"Missing units. Accepted units are {list(type(self)._units.keys())}", object_path)
- if unit not in type(self)._units:
- raise ValueError(
- f"Used unexpected unit '{unit}' for {type(self).__name__}."
- f" Accepted units are {list(type(self)._units.keys())}",
- object_path,
- )
- self._base_value = int(val) * type(self)._units[unit]
- else:
- raise ValueError(
- f"Unexpected value for '{type(self)}'."
- " Expected string that matches pattern "
- rf"'{type(self)._re.pattern}'."
- f" Positive integer and one of the units {list(type(self)._units.keys())}, got '{source_value}'.",
- object_path,
- )
-
- def __int__(self) -> int:
- return self._base_value
-
- def __repr__(self) -> str:
- return f"Unit[{type(self).__name__},{self._value}]"
-
- def __eq__(self, o: object) -> bool:
- """Two instances are equal when they represent the same size regardless of their string representation."""
-        return isinstance(o, UnitBase) and o._base_value == self._base_value
-
- def serialize(self) -> Any:
- return self._orig_value
-
- @classmethod
- def json_schema(cls: Type["UnitBase"]) -> Dict[Any, Any]:
- return {"type": "string", "pattern": rf"{cls._re.pattern}"}
+++ /dev/null
-from typing import Literal
-
-# Policy actions
-PolicyActionEnum = Literal[
- # Nonchain actions
- "pass",
- "deny",
- "drop",
- "refuse",
- "tc",
- "reroute",
- "answer",
- # Chain actions
- "mirror",
- "forward",
- "stub",
- "debug-always",
- "debug-cache-miss",
- "qtrace",
- "reqtrace",
-]
-
-# FLAGS from https://www.knot-resolver.cz/documentation/latest/lib.html?highlight=options#c.kr_qflags
-PolicyFlagEnum = Literal[
- "no-minimize",
- "no-ipv4",
- "no-ipv6",
- "tcp",
- "resolved",
- "await-ipv4",
- "await-ipv6",
- "await-cut",
- "no-edns",
- "cached",
- "no-cache",
- "expiring",
- "allow_local",
- "dnssec-want",
- "dnssec-bogus",
- "dnssec-insecure",
- "dnssec-cd",
- "stub",
- "always-cut",
- "dnssec-wexpand",
- "permissive",
- "strict",
- "badcookie-again",
- "cname",
- "reorder-rr",
- "trace",
- "no-0x20",
- "dnssec-nods",
- "dnssec-optout",
- "nonauth",
- "forward",
- "dns64-mark",
- "cache-tried",
- "no-ns-found",
- "pkt-is-sane",
- "dns64-disable",
-]
-
-# DNS records from 'kres.type' table
-DNSRecordTypeEnum = Literal[
- "A",
- "A6",
- "AAAA",
- "AFSDB",
- "ANY",
- "APL",
- "ATMA",
- "AVC",
- "AXFR",
- "CAA",
- "CDNSKEY",
- "CDS",
- "CERT",
- "CNAME",
- "CSYNC",
- "DHCID",
- "DLV",
- "DNAME",
- "DNSKEY",
- "DOA",
- "DS",
- "EID",
- "EUI48",
- "EUI64",
- "GID",
- "GPOS",
- "HINFO",
- "HIP",
- "HTTPS",
- "IPSECKEY",
- "ISDN",
- "IXFR",
- "KEY",
- "KX",
- "L32",
- "L64",
- "LOC",
- "LP",
- "MAILA",
- "MAILB",
- "MB",
- "MD",
- "MF",
- "MG",
- "MINFO",
- "MR",
- "MX",
- "NAPTR",
- "NID",
- "NIMLOC",
- "NINFO",
- "NS",
- "NSAP",
- "NSAP-PTR",
- "NSEC",
- "NSEC3",
- "NSEC3PARAM",
- "NULL",
- "NXT",
- "OPENPGPKEY",
- "OPT",
- "PTR",
- "PX",
- "RKEY",
- "RP",
- "RRSIG",
- "RT",
- "SIG",
- "SINK",
- "SMIMEA",
- "SOA",
- "SPF",
- "SRV",
- "SSHFP",
- "SVCB",
- "TA",
- "TALINK",
- "TKEY",
- "TLSA",
- "TSIG",
- "TXT",
- "UID",
- "UINFO",
- "UNSPEC",
- "URI",
- "WKS",
- "X25",
- "ZONEMD",
-]
+++ /dev/null
-# ruff: noqa: D205, D400, D415
-import logging
-import os
-import stat
-from enum import Flag, auto
-from grp import getgrnam
-from pathlib import Path
-from pwd import getpwnam, getpwuid
-from typing import Any, Dict, Tuple, Type, TypeVar
-
-from knot_resolver.constants import GROUP, USER
-from knot_resolver.datamodel.globals import get_permissions_default, get_resolve_root, get_strict_validation
-from knot_resolver.utils.modeling.base_value_type import BaseValueType
-
-logger = logging.getLogger(__name__)
-
-
-class UncheckedPath(BaseValueType):
- """
- Wrapper around pathlib.Path object.
-
- Can represent pretty much any Path. No checks are performed on the value. The value is taken as is.
- """
-
- _value: Path
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- self._object_path: str = object_path
- self._parents: Tuple[UncheckedPath, ...] = parents
- self.strict_validation: bool = get_strict_validation()
-
- if isinstance(source_value, str):
- # we do not load global validation context if the path is absolute
- # this prevents errors when constructing defaults in the schema
- resolve_root = Path("/") if source_value.startswith("/") else get_resolve_root()
-
- self._raw_value: str = source_value
- if self._parents:
- pp = map(lambda p: p.to_path(), self._parents)
- self._value: Path = Path(resolve_root, *pp, source_value)
- else:
- self._value: Path = Path(resolve_root, source_value)
- else:
- raise ValueError(f"expected file path in a string, got '{source_value}' with type '{type(source_value)}'.")
-
- def __str__(self) -> str:
- return str(self._value)
-
- def __eq__(self, o: object) -> bool:
- if not isinstance(o, UncheckedPath):
- return False
-
- return o._value == self._value
-
- def __int__(self) -> int:
- raise RuntimeError("Path cannot be converted to type <int>")
-
- def to_path(self) -> Path:
- return self._value
-
- def serialize(self) -> Any:
- return self._raw_value
-
- def relative_to(self, parent: "UncheckedPath") -> "UncheckedPath":
- """Return a path with an added parent part."""
- return UncheckedPath(self._raw_value, parents=(parent, *self._parents), object_path=self._object_path)
-
- UPT = TypeVar("UPT", bound="UncheckedPath")
-
- def reconstruct(self, cls: Type[UPT]) -> UPT:
- """Rebuild this object as an instance of its subclass. Practically, allows for conversions from."""
- return cls(self._raw_value, parents=self._parents, object_path=self._object_path)
-
- @classmethod
- def json_schema(cls: Type["UncheckedPath"]) -> Dict[Any, Any]:
- return {
- "type": "string",
- }
-
-
-class Dir(UncheckedPath):
- """
-    Path that is enforced to be:
- - an existing directory
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- try:
- super().__init__(source_value, parents=parents, object_path=object_path)
- if self.strict_validation and not self._value.is_dir():
- raise ValueError(f"path '{self._value}' does not point to an existing directory")
- except PermissionError as e:
- raise ValueError(str(e)) from e
-
-
-class AbsoluteDir(Dir):
- """
-    Path that is enforced to be:
- - absolute
- - an existing directory
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- super().__init__(source_value, parents=parents, object_path=object_path)
- if self.strict_validation and not self._value.is_absolute():
- raise ValueError(f"path '{self._value}' is not absolute")
-
-
-class File(UncheckedPath):
- """
-    Path that is enforced to be:
- - an existing file
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- try:
- super().__init__(source_value, parents=parents, object_path=object_path)
- if self.strict_validation and not self._value.exists():
- raise ValueError(f"file '{self._value}' does not exist")
- if self.strict_validation and not self._value.is_file():
- raise ValueError(f"path '{self._value}' is not a file")
- except PermissionError as e:
- raise ValueError(str(e)) from e
-
-
-class FilePath(UncheckedPath):
- """
-    Path that is enforced to be:
- - parent of the last path segment is an existing directory
- - it does not point to a dir
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- try:
- super().__init__(source_value, parents=parents, object_path=object_path)
- p = self._value.parent
- if self.strict_validation and (not p.exists() or not p.is_dir()):
- raise ValueError(f"path '{self._value}' does not point inside an existing directory")
- if self.strict_validation and self._value.is_dir():
- raise ValueError(f"path '{self._value}' points to a directory when we expected a file")
- except PermissionError as e:
- raise ValueError(str(e)) from e
-
-
-class _PermissionMode(Flag):
- READ = auto()
- WRITE = auto()
- EXECUTE = auto()
-
-
-def _check_permission(dest_path: Path, perm_mode: _PermissionMode) -> bool:
- chflags = {
- _PermissionMode.READ: [stat.S_IRUSR, stat.S_IRGRP, stat.S_IROTH],
- _PermissionMode.WRITE: [stat.S_IWUSR, stat.S_IWGRP, stat.S_IWOTH],
- _PermissionMode.EXECUTE: [stat.S_IXUSR, stat.S_IXGRP, stat.S_IXOTH],
- }
-
- # running outside the manager (client, ...)
- if get_permissions_default():
- user_uid = getpwnam(USER).pw_uid
- user_gid = getgrnam(GROUP).gr_gid
- username = USER
- # running under root privileges
- elif os.geteuid() == 0:
- return True
- # running normally under an unprivileged user
- else:
- user_uid = os.getuid()
- user_gid = os.getgid()
- username = getpwuid(user_uid).pw_name
-
- dest_stat = os.stat(dest_path)
- dest_uid = dest_stat.st_uid
- dest_gid = dest_stat.st_gid
- dest_mode = dest_stat.st_mode
-
- def accessible(perm: _PermissionMode) -> bool:
- if user_uid == dest_uid:
- return bool(dest_mode & chflags[perm][0])
- b_groups = os.getgrouplist(username, user_gid)
- if user_gid == dest_gid or dest_gid in b_groups:
- return bool(dest_mode & chflags[perm][1])
- return bool(dest_mode & chflags[perm][2])
-
-    # __iter__ for enum.Flag was added in Python 3.11,
-    # so 'for perm in perm_mode:' fails for Python < 3.11
- return all(not (perm in perm_mode and not accessible(perm)) for perm in _PermissionMode)
-
-
-class ReadableFile(File):
- """
-    Path that is enforced to be:
-
- - an existing file
- - readable by knot-resolver processes
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- super().__init__(source_value, parents=parents, object_path=object_path)
-
- if self.strict_validation and not _check_permission(self._value, _PermissionMode.READ):
- msg = f"{USER}:{GROUP} has insufficient permissions to read '{self._value}'"
- if not os.access(self._value, os.R_OK):
- raise ValueError(msg)
- logger.info(f"{msg}, but the resolver can somehow (ACLs, ...) read the file")
-
-
-class WritableDir(Dir):
- """
-    Path that is enforced to be:
- - an existing directory
- - writable/executable by knot-resolver processes
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- super().__init__(source_value, parents=parents, object_path=object_path)
-
- if self.strict_validation and not _check_permission(
- self._value, _PermissionMode.WRITE | _PermissionMode.EXECUTE
- ):
- msg = f"{USER}:{GROUP} has insufficient permissions to write/execute '{self._value}'"
- if not os.access(self._value, os.W_OK | os.X_OK):
- raise ValueError(msg)
- logger.info(f"{msg}, but the resolver can somehow (ACLs, ...) write to the directory")
-
-
-class WritableFilePath(FilePath):
- """
-    Path that is enforced to be:
-    - parent of the last path segment is an existing directory
-    - it does not point to a dir
-    - its parent directory is writable/executable by knot-resolver processes
- """
-
- def __init__(
- self, source_value: Any, parents: Tuple["UncheckedPath", ...] = tuple(), object_path: str = "/"
- ) -> None:
- super().__init__(source_value, parents=parents, object_path=object_path)
-
- if self.strict_validation:
- # check that parent dir is writable
- if not _check_permission(self._value.parent, _PermissionMode.WRITE | _PermissionMode.EXECUTE):
- msg = f"{USER}:{GROUP} has insufficient permissions to write/execute '{self._value.parent}'"
- # os.access() on the dir just provides a more precise message,
- # as the os.access() on the file below check everything in one go
- if not os.access(self._value.parent, os.W_OK | os.X_OK):
- raise ValueError(msg)
- logger.info(f"{msg}, but the resolver can somehow (ACLs, ...) write to the directory")
-
- # check that existing file is writable
- if self._value.exists() and not _check_permission(self._value, _PermissionMode.WRITE):
- msg = f"{USER}:{GROUP} has insufficient permissions to write/execute '{self._value}'"
- if not os.access(self._value, os.W_OK):
- raise ValueError(msg)
- logger.info(f"{msg}, but the resolver can somehow (ACLs, ...) write to the file")
+++ /dev/null
-from typing import Any, List, TypeVar, Union
-
-from knot_resolver.utils.modeling import BaseGenericTypeWrapper
-
-T = TypeVar("T")
-
-
-class ListOrItem(BaseGenericTypeWrapper[Union[List[T], T]]):
- _value_orig: Union[List[T], T]
- _list: List[T]
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None: # pylint: disable=unused-argument
- self._value_orig: Union[List[T], T] = source_value
-
- self._list: List[T] = source_value if isinstance(source_value, list) else [source_value]
- if len(self) == 0:
- raise ValueError("empty list is not allowed")
-
- def __getitem__(self, index: Any) -> T:
- return self._list[index]
-
- def __int__(self) -> int:
- raise ValueError(f"Can't convert '{type(self).__name__}' to an integer.")
-
- def __str__(self) -> str:
- return str(self._value_orig)
-
- def to_std(self) -> List[T]:
- return self._list
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, ListOrItem) and o._value_orig == self._value_orig
-
- def __len__(self) -> int:
- return len(self._list)
-
- def serialize(self) -> Union[List[T], T]:
- return self._value_orig
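-
-
-# Usage sketch (added for illustration): ListOrItem([1, 2]).to_std() == [1, 2] and
-# ListOrItem(1).to_std() == [1]; an empty list raises ValueError during validation.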
+++ /dev/null
-import ipaddress
-import re
-from typing import Any, Dict, Optional, Type, Union
-
-from knot_resolver.datamodel.types.base_types import (
- FloatRangeBase,
- IntRangeBase,
- PatternBase,
- StrBase,
- StringLengthBase,
- UnitBase,
-)
-from knot_resolver.utils.modeling import BaseValueType
-
-
-class IntNonNegative(IntRangeBase):
- _min: int = 0
-
-
-class IntPositive(IntRangeBase):
- _min: int = 1
-
-
-class Int0_32(IntRangeBase): # noqa: N801
- _min: int = 0
- _max: int = 32
-
-
-class Int0_512(IntRangeBase): # noqa: N801
- _min: int = 0
- _max: int = 512
-
-
-class Int1_4096(IntRangeBase): # noqa: N801
- _min: int = 1
- _max: int = 4096
-
-
-class Int0_65535(IntRangeBase): # noqa: N801
- _min: int = 0
- _max: int = 65_535
-
-
-class Percent(IntRangeBase):
- _min: int = 0
- _max: int = 100
-
-
-class PortNumber(IntRangeBase):
- _min: int = 1
- _max: int = 65_535
-
- @classmethod
- def from_str(cls: Type["PortNumber"], port: str, object_path: str = "/") -> "PortNumber":
- try:
- return cls(int(port), object_path)
- except ValueError as e:
- raise ValueError(f"invalid port number {port}") from e
-
-
-class FloatNonNegative(FloatRangeBase):
- _min: float = 0.0
-
-
-class SizeUnit(UnitBase):
- _units = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3}
-
- def bytes(self) -> int:
- return self._base_value
-
- def mbytes(self) -> int:
- return self._base_value // 1024**2
-
-
-class TimeUnit(UnitBase):
- _units = {"us": 1, "ms": 10**3, "s": 10**6, "m": 60 * 10**6, "h": 3600 * 10**6, "d": 24 * 3600 * 10**6}
-
- def minutes(self) -> int:
- return self._base_value // 1000**2 // 60
-
- def seconds(self) -> int:
- return self._base_value // 1000**2
-
- def millis(self) -> int:
- return self._base_value // 1000
-
- def micros(self) -> int:
- return self._base_value
-
-
-class EscapedStr(StrBase):
- """A string where escape sequences are ignored and quotes are escaped."""
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
-
- escape = {
- "'": r"\'",
- '"': r"\"",
- "\a": r"\a",
- "\n": r"\n",
- "\r": r"\r",
- "\t": r"\t",
- "\b": r"\b",
- "\f": r"\f",
- "\v": r"\v",
- "\0": r"\0",
- }
-
- s = list(self._value)
- for i, c in enumerate(self._value):
- if c in escape:
- s[i] = escape[c]
- elif not c.isalnum():
- s[i] = repr(c)[1:-1]
- self._value = "".join(s)
-
- def multiline(self) -> str:
- """
- Lua multiline string is enclosed in double square brackets '[[ ]]'.
-
- This method makes sure that double square brackets are escaped.
- """
- replace = {
- "[[": r"\[\[",
- "]]": r"\]\]",
- }
-
- ml = self._orig_value
- for s, r in replace.items():
- ml = ml.replace(s, r)
- return ml
-
-
-class EscapedStr32B(EscapedStr, StringLengthBase):
- """Same as 'EscapedStr', but minimal length is 32 bytes."""
-
- _min_bytes: int = 32
-
-
-class DomainName(StrBase):
- """Fully or partially qualified domain name."""
-
- _punycode: str
- # fmt: off
- _re = re.compile(
- r"(?=^.{,253}\.?$)" # max 253 chars
- r"(^"
- # do not allow hyphen at the start and at the end of label
- r"(?!-)[^.]{,62}[^.-]" # max 63 chars in label; except dot
- r"(\.(?!-)[^.]{,62}[^.-])*" # start with dot; max 63 chars in label except dot
- r"\.?" # end with or without dot
- r"$)"
- r"|^\.$" # allow root-zone
- )
- # fmt: on
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
- try:
- punycode = self._value.encode("idna").decode("utf-8") if self._value != "." else "."
- except ValueError as e:
- raise ValueError(
- f"conversion of '{self._value}' to IDN punycode representation failed",
- object_path,
- ) from e
-
- if type(self)._re.match(punycode): # noqa: SLF001
- self._punycode = punycode
- else:
- raise ValueError(
- f"'{source_value}' represented in punycode '{punycode}' does not match '{self._re.pattern}' pattern",
- object_path,
- )
-
- def __hash__(self) -> int:
- if self._value.endswith("."):
- return hash(self._value)
- return hash(f"{self._value}.")
-
- def punycode(self) -> str:
- return self._punycode
-
- @classmethod
- def json_schema(cls: Type["DomainName"]) -> Dict[Any, Any]:
- return {"type": "string", "pattern": rf"{cls._re.pattern}"}
-
-
-class InterfaceName(PatternBase):
- """Network interface name."""
-
- _re = re.compile(r"^[a-zA-Z0-9]+(?:[-_][a-zA-Z0-9]+)*$")
-
-
-class IDPattern(PatternBase):
- """Alphanumerical ID for identifying systemd slice."""
-
- _re = re.compile(r"^(?!-)[a-z0-9-]*[a-z0-9]+$")
-
-
-class PinSha256(PatternBase):
- """A string that stores base64 encoded sha256."""
-
- _re = re.compile(r"^[A-Za-z\d+/]{43}=$")
-
-
-class InterfacePort(StrBase):
- addr: Union[None, ipaddress.IPv4Address, ipaddress.IPv6Address] = None
- if_name: Optional[InterfaceName] = None
- port: PortNumber
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
-
- parts = self._value.split("@")
- if len(parts) == 2:
- try:
- self.addr = ipaddress.ip_address(parts[0])
- except ValueError as e1:
- try:
- self.if_name = InterfaceName(parts[0])
- except ValueError as e2:
- raise ValueError(f"expected IP address or interface name, got '{parts[0]}'.", object_path) from (
- e1 and e2
- )
- self.port = PortNumber.from_str(parts[1], object_path)
- else:
- raise ValueError(f"expected '<ip-address|interface-name>@<port>', got '{source_value}'.", object_path)
-
-
-class InterfaceOptionalPort(StrBase):
- addr: Union[None, ipaddress.IPv4Address, ipaddress.IPv6Address] = None
- if_name: Optional[InterfaceName] = None
- port: Optional[PortNumber] = None
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
-
- parts = self._value.split("@")
- if 0 < len(parts) < 3:
- try:
- self.addr = ipaddress.ip_address(parts[0])
- except ValueError as e1:
- try:
- self.if_name = InterfaceName(parts[0])
- except ValueError as e2:
- raise ValueError(f"expected IP address or interface name, got '{parts[0]}'.", object_path) from (
- e1 and e2
- )
- if len(parts) == 2:
- self.port = PortNumber.from_str(parts[1], object_path)
- else:
- raise ValueError(f"expected '<ip-address|interface-name>[@<port>]', got '{parts}'.", object_path)
-
-
-class IPAddressPort(StrBase):
- addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
- port: PortNumber
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path)
-
- parts = self._value.split("@")
- if len(parts) == 2:
- self.port = PortNumber.from_str(parts[1], object_path)
- try:
- self.addr = ipaddress.ip_address(parts[0])
- except ValueError as e:
- raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e
- else:
- raise ValueError(f"expected '<ip-address>@<port>', got '{source_value}'.", object_path)
-
-
-class IPAddressOptionalPort(StrBase):
- addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
- port: Optional[PortNumber] = None
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
-        super().__init__(source_value, object_path)
- parts = source_value.split("@")
- if 0 < len(parts) < 3:
- try:
- self.addr = ipaddress.ip_address(parts[0])
- except ValueError as e:
- raise ValueError(f"failed to parse IP address '{parts[0]}'.", object_path) from e
- if len(parts) == 2:
- self.port = PortNumber.from_str(parts[1], object_path)
- else:
- raise ValueError(f"expected '<ip-address>[@<port>]', got '{parts}'.", object_path)
-
-
-class IPv4Address(BaseValueType):
- _value: ipaddress.IPv4Address
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, str):
- try:
- self._value: ipaddress.IPv4Address = ipaddress.IPv4Address(source_value)
- except ValueError as e:
- raise ValueError("failed to parse IPv4 address.") from e
- else:
- raise ValueError(
- "Unexpected value for a IPv4 address."
- f" Expected string, got '{source_value}' with type '{type(source_value)}'",
- object_path,
- )
-
- def to_std(self) -> ipaddress.IPv4Address:
- return self._value
-
- def __str__(self) -> str:
- return str(self._value)
-
- def __int__(self) -> int:
- raise ValueError("Can't convert IPv4 address to an integer")
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- """Two instances of IPv4Address are equal when they represent same IPv4 address as string."""
- return isinstance(o, IPv4Address) and str(o._value) == str(self._value)
-
- def serialize(self) -> Any:
- return str(self._value)
-
- @classmethod
- def json_schema(cls: Type["IPv4Address"]) -> Dict[Any, Any]:
- return {
- "type": "string",
- }
-
-
-class IPv6Address(BaseValueType):
- _value: ipaddress.IPv6Address
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, str):
- try:
- self._value: ipaddress.IPv6Address = ipaddress.IPv6Address(source_value)
- except ValueError as e:
- raise ValueError("failed to parse IPv6 address.") from e
- else:
- raise ValueError(
- "Unexpected value for a IPv6 address."
- f" Expected string, got '{source_value}' with type '{type(source_value)}'",
- object_path,
- )
-
- def to_std(self) -> ipaddress.IPv6Address:
- return self._value
-
- def __str__(self) -> str:
- return str(self._value)
-
- def __int__(self) -> int:
- raise ValueError("Can't convert IPv6 address to an integer")
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- """Two instances of IPv6Address are equal when they represent same IPv6 address as string."""
- return isinstance(o, IPv6Address) and str(o._value) == str(self._value)
-
- def serialize(self) -> Any:
- return str(self._value)
-
- @classmethod
- def json_schema(cls: Type["IPv6Address"]) -> Dict[Any, Any]:
- return {
- "type": "string",
- }
-
-
-IPAddress = Union[IPv4Address, IPv6Address]
-
-
-class IPAddressEM(BaseValueType):
- """IP address with exclamation mark suffix, e.g. '127.0.0.1!'."""
-
- _value: str
- _addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, str):
- if source_value.endswith("!"):
- addr, suff = source_value.split("!", 1)
- if suff != "":
- raise ValueError(f"suffix '{suff}' found after '!'.")
- else:
- raise ValueError("string does not end with '!'.")
- try:
- self._addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = ipaddress.ip_address(addr)
- self._value = source_value
- except ValueError as e:
- raise ValueError("failed to parse IP address.") from e
- else:
- raise ValueError(
- "Unexpected value for a IPv6 address."
- f" Expected string, got '{source_value}' with type '{type(source_value)}'",
- object_path,
- )
-
- def to_std(self) -> str:
- return self._value
-
- def __str__(self) -> str:
- return self._value
-
- def __int__(self) -> int:
- raise ValueError("Can't convert to an integer")
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- """Two instances of IPAddressEM are equal when they represent same string."""
- return isinstance(o, IPAddressEM) and o._value == self._value
-
- def serialize(self) -> Any:
- return self._value
-
- @classmethod
- def json_schema(cls: Type["IPAddressEM"]) -> Dict[Any, Any]:
- return {
- "type": "string",
- }
-
-
-class IPNetwork(BaseValueType):
- _value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, str):
- try:
- self._value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network] = ipaddress.ip_network(source_value)
- except ValueError as e:
- raise ValueError("failed to parse IP network.") from e
- else:
- raise ValueError(
- "Unexpected value for a network subnet."
- f" Expected string, got '{source_value}' with type '{type(source_value)}'"
- )
-
- def __int__(self) -> int:
- raise ValueError("Can't convert network prefix to an integer")
-
- def __str__(self) -> str:
- return self._value.with_prefixlen
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, IPNetwork) and o._value == self._value
-
- def to_std(self) -> Union[ipaddress.IPv4Network, ipaddress.IPv6Network]:
- return self._value
-
- def serialize(self) -> Any:
- return self._value.with_prefixlen
-
- @classmethod
- def json_schema(cls: Type["IPNetwork"]) -> Dict[Any, Any]:
- return {
- "type": "string",
- }
-
-
-class IPv6Network(BaseValueType):
- _value: ipaddress.IPv6Network
-
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- if isinstance(source_value, str):
- try:
- self._value: ipaddress.IPv6Network = ipaddress.IPv6Network(source_value)
- except ValueError as e:
- raise ValueError("failed to parse IPv6 network.") from e
- else:
- raise ValueError(
- "Unexpected value for a IPv6 network subnet."
- f" Expected string, got '{source_value}' with type '{type(source_value)}'"
- )
-
- def to_std(self) -> ipaddress.IPv6Network:
- return self._value
-
- def __str__(self) -> str:
- return self._value.with_prefixlen
-
- def __int__(self) -> int:
- raise ValueError("Can't convert network prefix to an integer")
-
- def __repr__(self) -> str:
- return f'{type(self).__name__}("{self._value}")'
-
- def __eq__(self, o: object) -> bool:
- return isinstance(o, IPv6Network) and o._value == self._value
-
- def serialize(self) -> Any:
- return self._value.with_prefixlen
-
- @classmethod
- def json_schema(cls: Type["IPv6Network"]) -> Dict[Any, Any]:
- return {
- "type": "string",
- }
-
-
-class IPv6Network96(IPv6Network):
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value, object_path=object_path)
- if self._value.prefixlen == 128:
- raise ValueError(
- "Expected IPv6 network address with /96 prefix length."
- " Submitted address has been interpreted as /128."
-                " Maybe you forgot to add /96 after the base address?"
- )
-
- if self._value.prefixlen != 96:
- raise ValueError(
-                "Expected IPv6 network address with /96 prefix length." f" Got prefix length of {self._value.prefixlen}."
- )
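A minimal standalone sketch of the /96 check performed by IPv6Network96 above, using only the standard ipaddress module (the helper name is hypothetical and not part of the deleted module):

import ipaddress

def parse_ipv6_net96(value: str) -> ipaddress.IPv6Network:
    # a bare address such as "64:ff9b::" parses as a /128 host route
    net = ipaddress.IPv6Network(value)
    if net.prefixlen != 96:
        raise ValueError(f"expected a /96 prefix, got /{net.prefixlen}")
    return net

# parse_ipv6_net96("64:ff9b::/96")  -> IPv6Network('64:ff9b::/96')
# parse_ipv6_net96("64:ff9b::")     -> ValueError (interpreted as /128)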
+++ /dev/null
-from typing import List, Literal, Optional
-
-from knot_resolver.datamodel.types import FloatNonNegative, IDPattern, IPNetwork
-from knot_resolver.utils.modeling import ConfigSchema
-
-
-class ViewOptionsSchema(ConfigSchema):
- """
- Configuration options for clients identified by the view.
-
- ---
-    minimize: Send the minimum amount of information in recursive queries to enhance privacy.
- dns64: Enable/disable DNS64.
- price_factor: Multiplies rate-limiting and defer prices of operations, use 0 to whitelist.
- fallback: Enable/disable fallback on resolution failure.
- """
-
- minimize: bool = True
- dns64: bool = True
- price_factor: FloatNonNegative = FloatNonNegative(1.0)
- fallback: bool = True
-
-
-class ViewSchema(ConfigSchema):
- """
-    Configuration parameters that allow you to create personalized policy rules and other settings.
-
- ---
-    subnets: Identifies the client based on its subnet. The rule with the most specific subnet takes priority.
- dst_subnet: Destination subnet, as an additional condition.
- protocols: Transport protocol, as an additional condition.
- tags: Tags to link with other policy rules.
-    answer: Direct way to handle requests from clients identified by the view.
- options: Configuration options for clients identified by the view.
- """
-
- subnets: List[IPNetwork]
- dst_subnet: Optional[IPNetwork] = None # could be a list as well, iterated in template
- protocols: Optional[List[Literal["udp53", "tcp53", "dot", "doh", "doq"]]] = None
- tags: Optional[List[IDPattern]] = None
- answer: Optional[Literal["allow", "refused", "noanswer"]] = None
- options: ViewOptionsSchema = ViewOptionsSchema()
-
- def _validate(self) -> None:
- if bool(self.tags) == bool(self.answer):
- raise ValueError("exactly one of 'tags' and 'answer' must be configured")
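As a quick illustration of the _validate() rule above, a view must set exactly one of 'tags' and 'answer'; a tiny truth-table sketch with a hypothetical helper name:

def exactly_one(tags, answer) -> bool:
    # True only when exactly one of the two is set (non-empty / non-None)
    return bool(tags) != bool(answer)

# exactly_one(["malware"], None)      -> True   (tags only: accepted)
# exactly_one(None, "refused")        -> True   (answer only: accepted)
# exactly_one(None, None)             -> False  (neither: rejected)
# exactly_one(["malware"], "refused") -> False  (both: rejected)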
+++ /dev/null
-class KresBaseError(Exception):
-    """Base class for all custom errors used in the knot_resolver module."""
+++ /dev/null
-from knot_resolver.manager.main import main
-
-if __name__ == "__main__":
-    main()
+++ /dev/null
-import asyncio
-from asyncio import Lock
-from typing import Any, Awaitable, Callable, List, Tuple
-
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.utils.functional import Result
-from knot_resolver.utils.modeling.exceptions import DataParsingError
-from knot_resolver.utils.modeling.types import NoneType
-
-from .exceptions import KresManagerBaseError
-
-VerifyCallback = Callable[[KresConfig, KresConfig, bool], Awaitable[Result[None, str]]]
-UpdateCallback = Callable[[KresConfig, bool], Awaitable[None]]
-
-
-class ConfigStore:
- def __init__(self, initial_config: KresConfig) -> None:
- self._config = initial_config
- self._verifiers: List[VerifyCallback] = []
- self._callbacks: List[UpdateCallback] = []
- self._update_lock: Lock = Lock()
-
- async def update(self, config: KresConfig, force: bool = False) -> None:
- # invoke pre-change verifiers
- results: Tuple[Result[None, str], ...] = tuple(
- await asyncio.gather(*[ver(self._config, config, force) for ver in self._verifiers])
- )
- err_res = filter(lambda r: r.is_err(), results)
- errs = list(map(lambda r: r.unwrap_err(), err_res))
- if len(errs) > 0:
- raise KresManagerBaseError("Configuration validation failed. The reasons are:\n - " + "\n - ".join(errs))
-
- async with self._update_lock:
- # update the stored config with the new version
- self._config = config
-
- # invoke change callbacks
- for call in self._callbacks:
- await call(config, force)
-
- async def renew(self, force: bool = False) -> None:
- await self.update(self._config, force)
-
- async def register_verifier(self, verifier: VerifyCallback) -> None:
- self._verifiers.append(verifier)
- res = await verifier(self.get(), self.get(), False)
- if res.is_err():
- raise DataParsingError(f"Initial config verification failed with error: {res.unwrap_err()}")
-
- async def register_on_change_callback(self, callback: UpdateCallback) -> None:
-        """Register a new callback and immediately call it with the current config."""
- self._callbacks.append(callback)
- await callback(self.get(), False)
-
- def get(self) -> KresConfig:
- return self._config
-
-
-def only_on_no_changes_update(selector: Callable[[KresConfig], Any]) -> Callable[[UpdateCallback], UpdateCallback]:
- def decorator(orig_func: UpdateCallback) -> UpdateCallback:
- original_value_set: Any = False
- original_value: Any = None
-
- async def new_func_update(config: KresConfig, force: bool = False) -> None:
- nonlocal original_value_set
- nonlocal original_value
- if not original_value_set:
- original_value_set = True
- elif original_value == selector(config) or force:
- await orig_func(config, force)
- original_value = selector(config)
-
- return new_func_update
-
- return decorator
-
-
-def only_on_real_changes_update(selector: Callable[[KresConfig], Any]) -> Callable[[UpdateCallback], UpdateCallback]:
- def decorator(orig_func: UpdateCallback) -> UpdateCallback:
- original_value_set: Any = False
- original_value: Any = None
-
- async def new_func_update(config: KresConfig, force: bool) -> None:
- nonlocal original_value_set
- nonlocal original_value
- if not original_value_set:
- original_value_set = True
- await orig_func(config, force)
- elif original_value != selector(config) or force:
- await orig_func(config, force)
- original_value = selector(config)
-
- return new_func_update
-
- return decorator
-
-
-def only_on_real_changes_verifier(selector: Callable[[KresConfig], Any]) -> Callable[[VerifyCallback], VerifyCallback]:
- def decorator(orig_func: VerifyCallback) -> VerifyCallback:
- original_value_set: Any = False
- original_value: Any = None
-
- async def new_func_verifier(old: KresConfig, new: KresConfig, force: bool) -> Result[NoneType, str]:
- nonlocal original_value_set
- nonlocal original_value
- if not original_value_set:
- original_value_set = True
- original_value = selector(new)
- await orig_func(old, new, force)
- elif original_value != selector(new):
- original_value = selector(new)
- await orig_func(old, new, force)
- return Result.ok(None)
-
- return new_func_verifier
-
- return decorator
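A minimal usage sketch of the machinery above (assuming ConfigStore, Result and only_on_real_changes_update as defined or imported in this module); FakeConfig and demo are illustrative stand-ins, not part of the real KresConfig code:

import asyncio
from dataclasses import dataclass

@dataclass
class FakeConfig:          # hypothetical stand-in for KresConfig
    workers: int = 1

async def demo() -> None:
    store = ConfigStore(FakeConfig())

    async def verify(old, new, force):
        # reject configs we consider invalid
        return Result.err("too many workers") if new.workers > 64 else Result.ok(None)

    @only_on_real_changes_update(lambda c: c.workers)
    async def on_workers_change(config, force=False):
        print(f"workers changed to {config.workers}")

    await store.register_verifier(verify)                        # runs once immediately
    await store.register_on_change_callback(on_workers_change)   # runs once immediately
    await store.update(FakeConfig(workers=4))                     # verifiers run, then callbacks

# asyncio.run(demo())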
+++ /dev/null
-import logging
-from pathlib import Path
-from typing import TYPE_CHECKING, Optional
-
-if TYPE_CHECKING:
- from knot_resolver.controller.interface import KresID
- from knot_resolver.datamodel.config_schema import KresConfig
- from knot_resolver.manager.config_store import ConfigStore
-
-
-LOGGING_LEVEL_STARTUP = logging.DEBUG
-PID_FILE_NAME = "knot-resolver.pid"
-
-FIX_COUNTER_ATTEMPTS_MAX = 2
-FIX_COUNTER_DECREASE_INTERVAL_SEC = 30 * 60
-PROCESSES_WATCHDOG_INTERVAL_SEC: float = 5
-
-
-def kres_cache_dir(config: "KresConfig") -> Path:
- return config.cache.storage.to_path()
-
-
-def policy_loader_config_file(_config: "KresConfig") -> Path:
- return Path("policy-loader.conf")
-
-
-def kresd_config_file(_config: "KresConfig", kres_id: "KresID") -> Path:
- return Path(f"kresd{int(kres_id)}.conf")
-
-
-def kresd_config_file_supervisord_pattern(_config: "KresConfig") -> Path:
- return Path("kresd%(process_num)d.conf")
-
-
-def supervisord_config_file(_config: "KresConfig") -> Path:
- return Path("supervisord.conf")
-
-
-def supervisord_config_file_tmp(_config: "KresConfig") -> Path:
- return Path("supervisord.conf.tmp")
-
-
-def supervisord_pid_file(_config: "KresConfig") -> Path:
- return Path("supervisord.pid")
-
-
-def supervisord_sock_file(_config: "KresConfig") -> Path:
- return Path("supervisord.sock")
-
-
-def supervisord_subprocess_log_dir(_config: "KresConfig") -> Path:
- return Path("logs")
-
-
-class _UserConstants:
- """Class for accessing constants, which are technically not constants as they are user configurable."""
-
- def __init__(self, config_store: "ConfigStore", working_directory_on_startup: str) -> None:
- self._config_store = config_store
- self.working_directory_on_startup = working_directory_on_startup
-
-
-_user_constants: Optional[_UserConstants] = None
-
-
-async def init_user_constants(config_store: "ConfigStore", working_directory_on_startup: str) -> None:
- global _user_constants
- _user_constants = _UserConstants(config_store, working_directory_on_startup)
-
-
-def user_constants() -> _UserConstants:
- assert _user_constants is not None
- return _user_constants
+++ /dev/null
-from knot_resolver import KresBaseError
-
-
-class KresManagerBaseError(KresBaseError):
-    """Base class for all errors used in the manager module."""
+++ /dev/null
-from .reload import files_reload
-from .watchdog import init_files_watchdog
-
-__all__ = ["files_reload", "init_files_watchdog"]
+++ /dev/null
-import logging
-
-from knot_resolver.controller.registered_workers import command_registered_workers
-from knot_resolver.datamodel import KresConfig
-
-logger = logging.getLogger(__name__)
-
-
-async def files_reload(config: KresConfig, force: bool = False) -> None:
- cert_file = config.network.tls.cert_file
- key_file = config.network.tls.key_file
-
- if cert_file and key_file:
- if not cert_file.to_path().exists():
-            logger.error(f"TLS cert files reload failed: cert-file '{cert_file}' does not exist")
-        elif not key_file.to_path().exists():
-            logger.error(f"TLS cert files reload failed: key-file '{key_file}' does not exist")
- else:
- logger.info("TLS cert files reload triggered")
- cmd = f"net.tls('{cert_file}', '{key_file}')"
- await command_registered_workers(cmd)
+++ /dev/null
-import logging
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-
-from knot_resolver.constants import WATCHDOG_LIB
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
-from knot_resolver.manager.triggers import cancel_cmd, trigger_cmd, trigger_renew
-
-logger = logging.getLogger(__name__)
-
-FilesToWatch = Dict[Path, Optional[str]]
-
-
-def watched_files_config(config: KresConfig) -> List[Any]:
- return [
- config.network.tls.watchdog,
- config.network.tls.cert_file,
- config.network.tls.key_file,
- config.local_data.rpz,
- ]
-
-
-if WATCHDOG_LIB:
- from watchdog.events import (
- FileSystemEvent,
- FileSystemEventHandler,
- )
- from watchdog.observers import Observer
-
- class FilesWatchdogEventHandler(FileSystemEventHandler):
- def __init__(self, files: FilesToWatch, config: KresConfig) -> None:
- self._files = files
- self._config = config
-
- def _trigger(self, cmd: Optional[str]) -> None:
- if cmd:
- trigger_cmd(self._config, cmd)
- trigger_renew(self._config)
-
- def on_created(self, event: FileSystemEvent) -> None:
- src_path = Path(str(event.src_path))
- if src_path in self._files.keys():
- logger.info(f"Watched file '{src_path}' has been created")
- self._trigger(self._files[src_path])
-
- def on_deleted(self, event: FileSystemEvent) -> None:
- src_path = Path(str(event.src_path))
- if src_path in self._files.keys():
- logger.warning(f"Watched file '{src_path}' has been deleted")
- cmd = self._files[src_path]
- if cmd:
- cancel_cmd(cmd)
- for file in self._files.keys():
- if file.parent == src_path:
- logger.warning(f"Watched directory '{src_path}' has been deleted")
- cmd = self._files[file]
- if cmd:
- cancel_cmd(cmd)
-
- def on_moved(self, event: FileSystemEvent) -> None:
- src_path = Path(str(event.src_path))
- if src_path in self._files.keys():
- logger.info(f"Watched file '{src_path}' has been moved")
- self._trigger(self._files[src_path])
-
- def on_modified(self, event: FileSystemEvent) -> None:
- src_path = Path(str(event.src_path))
- if src_path in self._files.keys():
- logger.info(f"Watched file '{src_path}' has been modified")
- self._trigger(self._files[src_path])
-
- _files_watchdog: Optional["FilesWatchdog"] = None
-
- class FilesWatchdog:
- def __init__(self, files_to_watch: FilesToWatch, config: KresConfig) -> None:
- self._observer = Observer()
-
- event_handler = FilesWatchdogEventHandler(files_to_watch, config)
- dirs_to_watch: List[Path] = []
- for file in files_to_watch.keys():
- if file.parent not in dirs_to_watch:
- dirs_to_watch.append(file.parent)
-
- for d in dirs_to_watch:
- self._observer.schedule(
- event_handler,
- str(d),
- recursive=False,
- )
- logger.info(f"Directory '{d}' scheduled for watching")
-
- def start(self) -> None:
- self._observer.start()
-
- def stop(self) -> None:
- self._observer.stop()
- self._observer.join()
-
-
-@only_on_real_changes_update(watched_files_config)
-async def _init_files_watchdog(config: KresConfig, force: bool = False) -> None:
- if WATCHDOG_LIB:
- global _files_watchdog
-
- if _files_watchdog:
- _files_watchdog.stop()
- files_to_watch: FilesToWatch = {}
-
- # network.tls
- if config.network.tls.watchdog and config.network.tls.cert_file and config.network.tls.key_file:
- net_tls = f"net.tls('{config.network.tls.cert_file}', '{config.network.tls.key_file}')"
- files_to_watch[config.network.tls.cert_file.to_path()] = net_tls
- files_to_watch[config.network.tls.key_file.to_path()] = net_tls
-
- # local-data.rpz
- if config.local_data.rpz:
- for rpz in config.local_data.rpz:
- if rpz.watchdog:
- files_to_watch[rpz.file.to_path()] = None
-
- if files_to_watch:
- logger.info("Initializing files watchdog")
- _files_watchdog = FilesWatchdog(files_to_watch, config)
- _files_watchdog.start()
-
-
-async def init_files_watchdog(config_store: ConfigStore) -> None:
- # register files watchdog callback
- await config_store.register_on_change_callback(_init_files_watchdog)
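A standalone sketch of the same pattern used above (schedule the observer on a file's parent directory, non-recursively, so create/move/delete of the watched file are also seen), assuming the third-party 'watchdog' package; names like OneFileHandler and watch are illustrative:

import time
from pathlib import Path

from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer

class OneFileHandler(FileSystemEventHandler):
    def __init__(self, path: Path) -> None:
        self._path = path

    def on_modified(self, event: FileSystemEvent) -> None:
        if Path(str(event.src_path)) == self._path:
            print(f"{self._path} changed - reload it here")

def watch(path: Path) -> "Observer":
    observer = Observer()
    # watching the parent directory also catches create/delete/rename of the file itself
    observer.schedule(OneFileHandler(path), str(path.parent), recursive=False)
    observer.start()
    return observer

# obs = watch(Path("/tmp/tls.crt")); time.sleep(30); obs.stop(); obs.join()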
+++ /dev/null
-import logging
-import logging.handlers
-import os
-import sys
-
-from knot_resolver.datamodel.config_schema import KresConfig
-from knot_resolver.datamodel.logging_schema import LogGroupsManagerEnum
-from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
-from knot_resolver.utils.modeling.types import get_generic_type_arguments
-
-from .constants import LOGGING_LEVEL_STARTUP
-
-STDOUT = "stdout"
-SYSLOG = "syslog"
-STDERR = "stderr"
-
-NOTICE_LEVEL = (logging.WARNING + logging.INFO) // 2
-NOTICE_NAME = "NOTICE"
-
-_config_to_level = {
- "crit": logging.CRITICAL,
- "err": logging.ERROR,
- "warning": logging.WARNING,
- "notice": NOTICE_LEVEL,
- "info": logging.INFO,
- "debug": logging.DEBUG,
-}
-
-_level_to_name = {
- logging.CRITICAL: "CRITICAL",
- logging.ERROR: "ERROR",
- logging.WARNING: "WARNING",
- NOTICE_LEVEL: NOTICE_NAME,
- logging.INFO: "INFO",
- logging.DEBUG: "DEBUG",
-}
-
-logger = logging.getLogger(__name__)
-
-
-def get_log_format(config: KresConfig) -> str:
- """Based on an environment variable $KRES_SUPRESS_LOG_PREFIX, returns the appropriate format string for logger."""
- if os.environ.get("KRES_SUPRESS_LOG_PREFIX") == "true":
- # In this case, we are running under supervisord and it's adding prefixes to our output
- return "[%(levelname)s] %(name)s: %(message)s"
-    # In this case, we are running standalone during initialization and we need to add a prefix to each line
- # by ourselves to make it consistent
- assert config.logging.target != SYSLOG
- stream = ""
- if config.logging.target == STDERR:
- stream = f" ({STDERR})"
-
- pid = os.getpid()
- return f"%(asctime)s manager[{pid}]{stream}: [%(levelname)s] %(name)s: %(message)s"
-
-
-async def _set_log_level(config: KresConfig) -> None:
- groups = config.logging.groups
- target = _config_to_level[config.logging.level]
-
-    # when the 'manager' logging group is configured, log with DEBUG
- if groups and "manager" in groups:
- target = logging.DEBUG
-
- # expect exactly one existing log handler on the root
- logger.warning(f"Changing logging level to '{_level_to_name[target]}'")
- logging.getLogger().setLevel(target)
-
- # set debug groups
- if groups:
- package_name = __name__.rsplit(".", 1)[0]
- for group in groups:
- if group in get_generic_type_arguments(LogGroupsManagerEnum):
- logger_name = f"{package_name}.{group}"
- logger.warning(f"Changing logging level of '{logger_name}' group to '{_level_to_name[logging.DEBUG]}'")
- logging.getLogger(logger_name).setLevel(logging.DEBUG)
-
-
-async def _set_logging_handler(config: KresConfig) -> None:
- target = config.logging.target
-
- handler: logging.Handler
- if target == SYSLOG:
- handler = logging.handlers.SysLogHandler(address="/dev/log")
- handler.setFormatter(logging.Formatter("%(name)s: %(message)s"))
- elif target == STDOUT:
- handler = logging.StreamHandler(sys.stdout)
- handler.setFormatter(logging.Formatter(get_log_format(config)))
- elif target == STDERR:
- handler = logging.StreamHandler(sys.stderr)
- handler.setFormatter(logging.Formatter(get_log_format(config)))
- else:
- raise RuntimeError(f"Unexpected value '{target}' for log target in the config")
-
- root = logging.getLogger()
-
-    # if we had a MemoryHandler before, point it at the new handler so its buffer can be flushed there
- if isinstance(root.handlers[0], logging.handlers.MemoryHandler):
- root.handlers[0].setTarget(handler)
-
- # stop the old handler
- root.handlers[0].flush()
- root.handlers[0].close()
- root.removeHandler(root.handlers[0])
-
- # configure the new handler
- root.addHandler(handler)
-
-
-@only_on_real_changes_update(lambda config: config.logging)
-async def _configure_logger(config: KresConfig, force: bool = False) -> None:
- await _set_logging_handler(config)
- await _set_log_level(config)
-
-
-async def logger_init(config_store: ConfigStore) -> None:
- await config_store.register_on_change_callback(_configure_logger)
-
-
-def logger_startup() -> None:
- logging.getLogger().setLevel(LOGGING_LEVEL_STARTUP)
- err_handler = logging.StreamHandler(sys.stderr)
- err_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
- logging.getLogger().addHandler(logging.handlers.MemoryHandler(10_000, logging.ERROR, err_handler))
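The startup-buffering trick above (logger_startup buffers into a MemoryHandler, _set_logging_handler later retargets and replaces it) can be reduced to this standard-library-only sketch; handler names and the format string are illustrative:

import logging
import logging.handlers
import sys

root = logging.getLogger()
root.setLevel(logging.DEBUG)

# phase 1: buffer records in memory; flush early to stderr only on ERROR or worse
fallback = logging.StreamHandler(sys.stderr)
memory = logging.handlers.MemoryHandler(10_000, logging.ERROR, fallback)
root.addHandler(memory)
logging.getLogger("demo").info("buffered until the real target is known")

# phase 2: the real target is known now - retarget, flush and swap the handlers
real = logging.StreamHandler(sys.stdout)
real.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s"))
memory.setTarget(real)
memory.flush()
memory.close()
root.removeHandler(memory)
root.addHandler(real)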
+++ /dev/null
-"""
-Effectively the same as normal __main__.py.
-
-However, we moved its content over to this
-file to allow us to exclude the __main__.py file from black's autoformatting.
-"""
-
-import argparse
-import sys
-from typing import NoReturn
-
-from knot_resolver.constants import CONFIG_FILE, VERSION
-from knot_resolver.manager.logger import logger_startup
-from knot_resolver.manager.server import start_server
-from knot_resolver.utils import compat
-
-
-def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(description="Knot Resolver - caching DNS resolver")
- parser.add_argument(
- "-V",
- "--version",
- help="Get version",
- action="version",
- version=VERSION,
- )
- parser.add_argument(
- "-c",
- "--config",
- help="One or more configuration files to load."
-        f" Overrides the default configuration file location at '{str(CONFIG_FILE)}'."
- " Files must not contain the same options."
- " However, they may extend individual subsections."
- " The location of the first configuration file determines"
-        " the prefix for every relative path in the configuration.",
- type=str,
- nargs="+",
- required=False,
- default=[str(CONFIG_FILE)],
- )
- return parser.parse_args()
-
-
-def main() -> NoReturn:
- # initial logging is to memory until we read the config
- logger_startup()
-
- # parse arguments
- args = parse_args()
-
- exit_code = compat.asyncio.run(start_server(config=args.config))
- sys.exit(exit_code)
+++ /dev/null
-import asyncio
-import logging
-import os
-import sys
-import time
-from secrets import token_hex
-from subprocess import SubprocessError
-from typing import Any, Callable, List, Optional
-
-from knot_resolver.controller.exceptions import KresSubprocessControllerError
-from knot_resolver.controller.interface import Subprocess, SubprocessController, SubprocessStatus, SubprocessType
-from knot_resolver.controller.registered_workers import command_registered_workers, get_registered_workers_kresids
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.manager.config_store import (
- ConfigStore,
- only_on_no_changes_update,
- only_on_real_changes_update,
- only_on_real_changes_verifier,
-)
-from knot_resolver.manager.files import files_reload
-from knot_resolver.utils.compat.asyncio import create_task
-from knot_resolver.utils.functional import Result
-from knot_resolver.utils.modeling.types import NoneType
-
-from .constants import FIX_COUNTER_ATTEMPTS_MAX, FIX_COUNTER_DECREASE_INTERVAL_SEC, PROCESSES_WATCHDOG_INTERVAL_SEC
-
-logger = logging.getLogger(__name__)
-
-
-class _FixCounter:
- def __init__(self) -> None:
- self._counter = 0
- self._timestamp = time.time()
-
- def increase(self) -> None:
- self._counter += 1
- self._timestamp = time.time()
-
- def try_decrease(self) -> None:
- if time.time() - self._timestamp > FIX_COUNTER_DECREASE_INTERVAL_SEC and self._counter > 0:
- logger.info(
-                "Enough time has passed since the last detected instability,"
-                f" decreasing fix attempt counter to {self._counter - 1}"
- )
- self._counter -= 1
- self._timestamp = time.time()
-
- def __str__(self) -> str:
- return str(self._counter)
-
- def is_too_high(self) -> bool:
- return self._counter >= FIX_COUNTER_ATTEMPTS_MAX
-
-
-async def _subprocess_desc(subprocess: Subprocess) -> object:
- return {
- "type": subprocess.type.name,
- "pid": await subprocess.get_pid(),
- "status": subprocess.status().name,
- }
-
-
-class KresManager: # pylint: disable=too-many-instance-attributes
- """
- Core of the whole operation. Orchestrates individual instances under some service manager like systemd.
-
- Instantiate with `KresManager.create()`, not with the usual constructor!
- """
-
- def __init__(self, _i_know_what_i_am_doing: bool = False) -> None:
- if not _i_know_what_i_am_doing:
- logger.error(
-                "Trying to create an instance of KresManager using the normal constructor. Please use "
-                "`KresManager.create()` instead"
- )
- raise AssertionError
-
- self._workers: List[Subprocess] = []
- self._gc: Optional[Subprocess] = None
- self._policy_loader: Optional[Subprocess] = None
- self._manager_lock = asyncio.Lock()
- self._workers_reset_needed: bool = False
- self._controller: SubprocessController
- self._processes_watchdog_task: Optional["asyncio.Task[None]"] = None
- self._fix_counter: _FixCounter = _FixCounter()
- self._config_store: ConfigStore
- self._shutdown_triggers: List[Callable[[int], None]] = []
-
- @staticmethod
- async def create(
- subprocess_controller: SubprocessController,
- config_store: ConfigStore,
- ) -> "KresManager":
- """Create new instance of KresManager."""
- inst = KresManager(_i_know_what_i_am_doing=True)
- await inst._async_init(subprocess_controller, config_store) # noqa: SLF001
- return inst
-
- async def _async_init(self, subprocess_controller: SubprocessController, config_store: ConfigStore) -> None:
- self._controller = subprocess_controller
- self._config_store = config_store
-
- # initialize subprocess controller
- logger.debug("Starting controller")
- await self._controller.initialize_controller(config_store.get())
- self._processes_watchdog_task = create_task(self._processes_watchdog())
- logger.debug("Looking for already running workers")
- await self._collect_already_running_workers()
-
- # register and immediately call a verifier that loads policy rules into the rules database
- await config_store.register_verifier(self.load_policy_rules)
-
- # configuration nodes that are relevant to kresd workers and the cache garbage collector
- def config_nodes(config: KresConfig) -> List[Any]:
- return [
- config.nsid,
- config.hostname,
- # config.rundir not allowed to change
- config.workers,
- # config.management not allowed to change and not affecting workers anyway
- config.options,
- config.network,
- # config.views fully handled by policy-loader
- # config.local_data fully handled by policy-loader
- config.forward,
- config.fallback,
- config.cache,
- config.dnssec,
- config.dns64,
- config.logging,
- config.monitoring,
- config.rate_limiting,
- config.defer,
- config.lua,
- ]
-
- # register and immediately call a verifier that validates config with 'canary' kresd process
- await config_store.register_verifier(only_on_real_changes_verifier(config_nodes)(self.validate_config))
-
- # register and immediately call a callback to apply config to all 'kresd' workers and 'cache-gc'
- await config_store.register_on_change_callback(only_on_real_changes_update(config_nodes)(self.apply_config))
-
- # register callback to reset policy rules for each 'kresd' worker
- await config_store.register_on_change_callback(self.reset_workers_policy_rules)
-
- # register and immediately call a callback to set new TLS session ticket secret for 'kresd' workers
- await config_store.register_on_change_callback(
- only_on_real_changes_update(config_nodes)(self.set_new_tls_sticket_secret)
- )
-
- # register callback that reloads files (TLS cert files) if selected configuration has not been changed
- await config_store.register_on_change_callback(only_on_no_changes_update(config_nodes)(files_reload))
-
- async def _spawn_new_worker(self, config: KresConfig) -> None:
- subprocess = await self._controller.create_subprocess(config, SubprocessType.KRESD)
- await subprocess.start()
- self._workers.append(subprocess)
-
- async def _stop_a_worker(self) -> None:
- if len(self._workers) == 0:
-            raise IndexError("Can't stop a kresd worker when there are none running")
-
- subprocess = self._workers.pop()
- await subprocess.stop()
-
- async def _collect_already_running_workers(self) -> None:
- for subp in await self._controller.get_all_running_instances():
- if subp.type == SubprocessType.KRESD:
- self._workers.append(subp)
- elif subp.type == SubprocessType.GC:
- assert self._gc is None
- self._gc = subp
- elif subp.type == SubprocessType.POLICY_LOADER:
- assert self._policy_loader is None
- self._policy_loader = subp
- else:
- raise RuntimeError("unexpected subprocess type")
-
- async def _rolling_restart(self, new_config: KresConfig) -> None:
- for kresd in self._workers:
- await kresd.apply_new_config(new_config)
-
- async def _ensure_number_of_children(self, config: KresConfig, n: int) -> None:
- # kill children that are not needed
- while len(self._workers) > n:
- await self._stop_a_worker()
-
- # spawn new children if needed
- while len(self._workers) < n:
- await self._spawn_new_worker(config)
-
- async def _run_policy_loader(self, config: KresConfig) -> None:
- if self._policy_loader:
- await self._policy_loader.start(config)
- else:
- subprocess = await self._controller.create_subprocess(config, SubprocessType.POLICY_LOADER)
- await subprocess.start()
- self._policy_loader = subprocess
-
- def _is_policy_loader_exited(self) -> bool:
- if self._policy_loader:
- return self._policy_loader.status() is SubprocessStatus.EXITED
- return False
-
- def _is_gc_running(self) -> bool:
- return self._gc is not None
-
- async def _start_gc(self, config: KresConfig) -> None:
- subprocess = await self._controller.create_subprocess(config, SubprocessType.GC)
- await subprocess.start()
- self._gc = subprocess
-
- async def _stop_gc(self) -> None:
- assert self._gc is not None
- await self._gc.stop()
- self._gc = None
-
- def add_shutdown_trigger(self, trigger: Callable[[int], None]) -> None:
- self._shutdown_triggers.append(trigger)
-
- async def validate_config(self, _old: KresConfig, new: KresConfig, force: bool = False) -> Result[NoneType, str]:
- async with self._manager_lock:
- if _old.rate_limiting != new.rate_limiting:
- logger.debug("Unlinking shared ratelimiting memory")
- try:
- os.unlink(str(_old.rundir) + "/ratelimiting")
- except FileNotFoundError:
- pass
- if _old.workers != new.workers or _old.defer != new.defer:
- logger.debug("Unlinking shared defer memory")
- try:
- os.unlink(str(_old.rundir) + "/defer")
- except FileNotFoundError:
- pass
- logger.debug("Testing the new config with a canary process")
- try:
-                # technically, this has side effects of leaving a new process running
- # but it's practically not a problem, because
- # if it keeps running, the config is valid and others will soon join as well
- # if it crashes and the startup fails, then well, it's not running anymore... :)
- await self._spawn_new_worker(new)
- except (SubprocessError, KresSubprocessControllerError):
- logger.error("Kresd with the new config failed to start, rejecting config")
- return Result.err("canary kresd process failed to start. Config might be invalid.")
-
- logger.debug("Canary process test passed.")
- return Result.ok(None)
-
- async def get_processes(self, proc_type: Optional[SubprocessType]) -> List[object]:
- processes = await self._controller.get_all_running_instances()
- return [await _subprocess_desc(pr) for pr in processes if proc_type is None or pr.type == proc_type]
-
- async def _reload_system_state(self) -> None:
- async with self._manager_lock:
- self._workers = []
- self._policy_loader = None
- self._gc = None
- await self._collect_already_running_workers()
-
- async def reset_workers_policy_rules(self, _config: KresConfig, force: bool = False) -> None:
- # command all running 'kresd' workers to reset their old policy rules,
- # unless the workers have already been started with a new config so reset is not needed
- if self._workers_reset_needed and get_registered_workers_kresids():
- logger.debug("Resetting policy rules for all running 'kresd' workers")
- cmd_results = await command_registered_workers("require('ffi').C.kr_rules_reset()")
- for worker, res in cmd_results.items():
- if res != 0:
- logger.error("Failed to reset policy rules in %s: %s", worker, res)
- else:
- logger.debug(
- "Skipped resetting policy rules for all running 'kresd' workers:"
-                " the workers are already running with the new configuration"
- )
-
- async def set_new_tls_sticket_secret(self, config: KresConfig, force: bool = False) -> None:
- if int(config.workers) == 1:
- logger.info(
- "There is no need to synchronize the TLS session secret across all workers"
- " because only one kresd worker is configured - skipping auto-generation"
- )
- return
-
- if config.network.tls.sticket_secret or config.network.tls.sticket_secret_file:
- logger.debug("User-configured TLS resumption secret found - skipping auto-generation")
- return
-
- logger.debug("Creating TLS session ticket secret")
- secret = token_hex(32)
- logger.debug("Setting TLS session ticket secret for all running 'kresd' workers")
- cmd_results = await command_registered_workers(f"net.tls_sticket_secret('{secret}')")
- for worker, res in cmd_results.items():
- if res not in (0, True):
- logger.error("Failed to set TLS session ticket secret in %s: %s", worker, res)
-
- async def apply_config(self, config: KresConfig, force: bool = False, _noretry: bool = False) -> None:
- try:
- async with self._manager_lock:
- logger.debug("Applying config to all workers")
- await self._rolling_restart(config)
- await self._ensure_number_of_children(config, int(config.workers))
-
- if self._is_gc_running() != config.cache.garbage_collector.enable:
- if config.cache.garbage_collector.enable:
- logger.debug("Starting cache GC")
- await self._start_gc(config)
- else:
- logger.debug("Stopping cache GC")
- await self._stop_gc()
- except KresSubprocessControllerError as e:
- if _noretry:
- raise
- if self._fix_counter.is_too_high():
- logger.error(f"Failed to apply config: {e}")
- logger.error("There have already been problems recently, refusing to try to fix it.")
- await (
- self.forced_shutdown()
- ) # possible improvement - the person who requested this change won't get a response this way
- else:
- logger.error(f"Failed to apply config: {e}")
- logger.warning("Reloading system state and trying again.")
- self._fix_counter.increase()
- await self._reload_system_state()
- await self.apply_config(config, _noretry=True)
-
- logger.info("Config applied successfully to all workers")
- self._workers_reset_needed = False
-
- async def load_policy_rules(self, _old: KresConfig, new: KresConfig, force: bool = False) -> Result[NoneType, str]:
- try:
- async with self._manager_lock:
- if _old.cache.size_max != new.cache.size_max:
- logger.debug("Unlinking shared cache top memory")
- try:
- os.unlink(str(_old.cache.storage) + "/top")
- except FileNotFoundError:
- pass
-
- logger.debug("Running kresd 'policy-loader'")
- await self._run_policy_loader(new)
-
- # wait for 'policy-loader' to finish
- logger.debug("Waiting for 'policy-loader' to finish loading policy rules")
- while not self._is_policy_loader_exited(): # noqa: ASYNC110
- await asyncio.sleep(1)
-
- # Clean up policy-loader configuration.
- # If we don't do this, we may start with
- # an old configuration and fail to detect a bug.
- if self._policy_loader:
- await self._policy_loader.cleanup()
-
- except (SubprocessError, KresSubprocessControllerError) as e:
- logger.error(f"Failed to load policy rules: {e}")
- return Result.err("kresd 'policy-loader' process failed to start. Config might be invalid.")
-
- self._workers_reset_needed = True
- logger.debug("Loading policy rules has been successfully completed")
- return Result.ok(None)
-
- async def stop(self) -> None:
- if self._processes_watchdog_task is not None:
- try:
- self._processes_watchdog_task.cancel() # cancel it
- await self._processes_watchdog_task # and let it really finish
- except asyncio.CancelledError:
- pass
-
- async with self._manager_lock:
-            # we could stop all the children one by one right now,
-            # but we leave that up to the subprocess controller to do while it is shutting down
- await self._controller.shutdown_controller()
- # now, when everything is stopped, let's clean up all the remains
- await asyncio.gather(*[w.cleanup() for w in self._workers])
-
- async def forced_shutdown(self) -> None:
- logger.warning("Collecting all remaining workers...")
- await self._reload_system_state()
- logger.warning("Terminating...")
- for trigger in self._shutdown_triggers:
- trigger(1)
-
- async def _instability_handler(self) -> None:
- if self._fix_counter.is_too_high():
- logger.error(
- "Already attempted too many times to fix system state. Refusing to try again and shutting down."
- )
- await self.forced_shutdown()
- return
-
- try:
- logger.warning("Instability detected. Dropping known list of workers and reloading it from the system.")
- self._fix_counter.increase()
- await self._reload_system_state()
- logger.warning("Workers reloaded. Applying old config....")
- await self._config_store.renew()
- logger.warning(f"System stability hopefully renewed. Fix attempt counter is currently {self._fix_counter}")
- except BaseException:
- logger.error("Failed attempting to fix an error. Forcefully shutting down.", exc_info=True)
- await self.forced_shutdown()
-
- async def _processes_watchdog(self) -> None: # noqa: C901, PLR0912
- while True:
- await asyncio.sleep(PROCESSES_WATCHDOG_INTERVAL_SEC)
-
- self._fix_counter.try_decrease()
-
- try:
- # gather current state
- async with self._manager_lock:
- detected_subprocesses = await self._controller.get_subprocess_status()
- expected_ids = [x.id for x in self._workers]
- if self._gc:
- expected_ids.append(self._gc.id)
-
- invoke_callback = False
-
- if self._policy_loader:
- expected_ids.append(self._policy_loader.id)
-
- for eid in expected_ids:
- if eid not in detected_subprocesses:
- logger.error("Subprocess with id '%s' was not found in the system!", eid)
- invoke_callback = True
- continue
-
- if detected_subprocesses[eid] is SubprocessStatus.FATAL:
- if self._policy_loader and self._policy_loader.id == eid:
- logger.info(
- "Subprocess '%s' is skipped by WatchDog"
- " because its status is monitored in a different way.",
- eid,
- )
- continue
- logger.error("Subprocess '%s' is in FATAL state!", eid)
- invoke_callback = True
- continue
-
- if detected_subprocesses[eid] is SubprocessStatus.UNKNOWN:
- logger.warning("Subprocess '%s' is in UNKNOWN state!", eid)
-
- non_registered_ids = detected_subprocesses.keys() - set(expected_ids)
- if len(non_registered_ids) != 0:
- logger.error(
- "Found additional process in the system, which shouldn't be there - %s",
- non_registered_ids,
- )
- invoke_callback = True
-
- except KresSubprocessControllerError as e:
-                # wait a few seconds and see if the 'processes_watchdog' task is cancelled (during shutdown)
- # otherwise it is an error
- await asyncio.sleep(3)
- invoke_callback = True
- logger.error(f"Processes watchdog failed with SubprocessControllerError: {e}")
- except asyncio.CancelledError:
- raise
- except BaseException:
- invoke_callback = True
- logger.error("Processes watchdog failed with an unexpected exception.", exc_info=True)
-
- if invoke_callback:
- try:
- await self._instability_handler()
- except Exception:
- logger.error("Processes watchdog failed while invoking instability callback", exc_info=True)
- logger.error("Violently terminating!")
- sys.exit(1)
+++ /dev/null
-from .collect import report_json
-from .prometheus import init_prometheus, report_prometheus
-
-__all__ = ["init_prometheus", "report_json", "report_prometheus"]
+++ /dev/null
-import logging
-from typing import Dict, Optional
-
-from knot_resolver.controller.interface import KresID
-from knot_resolver.controller.registered_workers import command_registered_workers, get_registered_workers_kresids
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.utils.modeling.parsing import DataFormat
-
-logger = logging.getLogger(__name__)
-
-
-async def collect_kresd_workers_metrics(config: KresConfig) -> Optional[Dict[KresID, object]]:
- if config.monitoring.metrics == "manager-only":
- logger.debug("Skipping kresd stat collection due to configuration")
- return None
-
- cmd = "collect_statistics()"
- if config.monitoring.metrics == "lazy":
- cmd = "collect_lazy_statistics()"
- logger.debug(f"Collecting stats from all kresd workers using method '{cmd}'")
-
- return await command_registered_workers(cmd)
-
-
-async def report_json(config: KresConfig) -> bytes:
- metrics_raw = await collect_kresd_workers_metrics(config)
- metrics_dict: Dict[str, Optional[object]] = {}
-
- if metrics_raw:
- for kresd_id, kresd_metrics in metrics_raw.items():
- metrics_dict[str(kresd_id)] = kresd_metrics
- else:
- # if we have no metrics, return None for every kresd worker
- for kresd_id in get_registered_workers_kresids():
- metrics_dict[str(kresd_id)] = None
-
- return DataFormat.JSON.dict_dump(metrics_dict).encode()
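For orientation, the payload produced by report_json() above maps each registered worker id to its raw statistics, or every worker id to null when collection is skipped; a standard-library sketch of the same shape (the numbers are made up):

import json

# when collection succeeds, each registered worker id maps to its raw stats
metrics_by_worker = {"kresd1": {"answer": {"total": 100, "cached": 40}}}
# when collection is skipped (e.g. 'manager-only'), every worker id maps to None
payload = json.dumps(metrics_by_worker).encode()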
+++ /dev/null
-import asyncio
-import logging
-from typing import Any, Dict, Generator, List, Optional, Tuple
-
-from knot_resolver.constants import PROMETHEUS_LIB
-from knot_resolver.controller.interface import KresID
-from knot_resolver.controller.registered_workers import get_registered_workers_kresids
-from knot_resolver.datamodel.config_schema import KresConfig
-from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
-from knot_resolver.utils import compat
-from knot_resolver.utils.functional import Result
-
-from .collect import collect_kresd_workers_metrics
-
-logger = logging.getLogger(__name__)
-
-if PROMETHEUS_LIB:
- from prometheus_client import exposition
- from prometheus_client.bridge.graphite import GraphiteBridge
- from prometheus_client.core import (
- REGISTRY,
- CounterMetricFamily,
- GaugeMetricFamily,
- HistogramMetricFamily,
- Metric,
- )
-
- _graphite_bridge: Optional[GraphiteBridge] = None
-
- _metrics_collector: Optional["KresPrometheusMetricsCollector"] = None
-
- def _counter(name: str, description: str, label: Tuple[str, str], value: float) -> CounterMetricFamily:
- c = CounterMetricFamily(name, description, labels=(label[0],))
- c.add_metric((label[1],), value)
- return c
-
- def _gauge(name: str, description: str, label: Tuple[str, str], value: float) -> GaugeMetricFamily:
- c = GaugeMetricFamily(name, description, labels=(label[0],))
- c.add_metric((label[1],), value)
- return c
-
- def _histogram(
- name: str, description: str, label: Tuple[str, str], buckets: List[Tuple[str, int]], sum_value: float
- ) -> HistogramMetricFamily:
- c = HistogramMetricFamily(name, description, labels=(label[0],))
- c.add_metric((label[1],), buckets, sum_value=sum_value)
- return c
-
- def _parse_resolver_metrics(instance_id: "KresID", metrics: Any) -> Generator[Metric, None, None]:
- sid = str(instance_id)
-
- # response latency histogram
- bucket_names_in_resolver = ("1ms", "10ms", "50ms", "100ms", "250ms", "500ms", "1000ms", "1500ms", "slow")
- bucket_names_in_prometheus = ("0.001", "0.01", "0.05", "0.1", "0.25", "0.5", "1.0", "1.5", "+Inf")
-
- # add smaller bucket counts
- def _bucket_count(answer: Dict[str, int], duration: str) -> int:
- index = bucket_names_in_resolver.index(duration)
- return sum([int(answer[bucket_names_in_resolver[i]]) for i in range(index + 1)])
-
- yield _histogram(
- "resolver_response_latency",
- "Time it takes to respond to queries in seconds",
- label=("instance_id", sid),
- buckets=[
- (bnp, _bucket_count(metrics["answer"], duration))
- for bnp, duration in zip(bucket_names_in_prometheus, bucket_names_in_resolver)
- ],
- sum_value=metrics["answer"]["sum_ms"] / 1_000,
- )
-
- # "request" metrics
- yield _counter(
- "resolver_request_total",
- "total number of DNS requests (including internal client requests)",
- label=("instance_id", sid),
- value=metrics["request"]["total"],
- )
- yield _counter(
- "resolver_request_total4",
- "total number of IPv4 DNS requests",
- label=("instance_id", sid),
- value=metrics["request"]["total4"],
- )
- yield _counter(
- "resolver_request_total6",
- "total number of IPv6 DNS requests",
- label=("instance_id", sid),
- value=metrics["request"]["total6"],
- )
- yield _counter(
- "resolver_request_internal",
- "number of internal requests generated by Knot Resolver (e.g. DNSSEC trust anchor updates)",
- label=("instance_id", sid),
- value=metrics["request"]["internal"],
- )
- yield _counter(
- "resolver_request_udp",
- "number of external requests received over plain UDP (RFC 1035)",
- label=("instance_id", sid),
- value=metrics["request"]["udp"],
- )
- yield _counter(
- "resolver_request_udp4",
- "number of external requests received over IPv4 plain UDP (RFC 1035)",
- label=("instance_id", sid),
- value=metrics["request"]["udp4"],
- )
- yield _counter(
- "resolver_request_udp6",
- "number of external requests received over IPv6 plain UDP (RFC 1035)",
- label=("instance_id", sid),
- value=metrics["request"]["udp6"],
- )
- yield _counter(
- "resolver_request_tcp",
- "number of external requests received over plain TCP (RFC 1035)",
- label=("instance_id", sid),
- value=metrics["request"]["tcp"],
- )
- yield _counter(
- "resolver_request_tcp4",
- "number of external requests received over IPv4 plain TCP (RFC 1035)",
- label=("instance_id", sid),
- value=metrics["request"]["tcp4"],
- )
- yield _counter(
- "resolver_request_tcp6",
- "number of external requests received over IPv6 plain TCP (RFC 1035)",
- label=("instance_id", sid),
- value=metrics["request"]["tcp6"],
- )
- yield _counter(
- "resolver_request_dot",
- "number of external requests received over DNS-over-TLS (RFC 7858)",
- label=("instance_id", sid),
- value=metrics["request"]["dot"],
- )
- yield _counter(
- "resolver_request_dot4",
- "number of external requests received over IPv4 DNS-over-TLS (RFC 7858)",
- label=("instance_id", sid),
- value=metrics["request"]["dot4"],
- )
- yield _counter(
- "resolver_request_dot6",
- "number of external requests received over IPv6 DNS-over-TLS (RFC 7858)",
- label=("instance_id", sid),
- value=metrics["request"]["dot6"],
- )
- yield _counter(
- "resolver_request_doh",
- "number of external requests received over DNS-over-HTTP (RFC 8484)",
- label=("instance_id", sid),
- value=metrics["request"]["doh"],
- )
- yield _counter(
- "resolver_request_doh4",
- "number of external requests received over IPv4 DNS-over-HTTP (RFC 8484)",
- label=("instance_id", sid),
- value=metrics["request"]["doh4"],
- )
- yield _counter(
- "resolver_request_doh6",
- "number of external requests received over IPv6 DNS-over-HTTP (RFC 8484)",
- label=("instance_id", sid),
- value=metrics["request"]["doh6"],
- )
- yield _counter(
- "resolver_request_xdp",
- "number of external requests received over plain UDP via an AF_XDP socket",
- label=("instance_id", sid),
- value=metrics["request"]["xdp"],
- )
- yield _counter(
- "resolver_request_xdp4",
- "number of external requests received over IPv4 plain UDP via an AF_XDP socket",
- label=("instance_id", sid),
- value=metrics["request"]["xdp4"],
- )
- yield _counter(
- "resolver_request_xdp6",
- "number of external requests received over IPv6 plain UDP via an AF_XDP socket",
- label=("instance_id", sid),
- value=metrics["request"]["xdp6"],
- )
- yield _counter(
- "resolver_request_doq",
- "number of external requests received over DNS-over-QUIC (RFC 9250)",
- label=("instance_id", sid),
- value=metrics["request"]["doq"],
- )
- yield _counter(
- "resolver_request_doq4",
- "number of external requests received over IPv4 DNS-over-QUIC (RFC 9250)",
- label=("instance_id", sid),
- value=metrics["request"]["doq4"],
- )
- yield _counter(
- "resolver_request_doq6",
-                "number of external requests received over IPv6 DNS-over-QUIC (RFC 9250)",
- label=("instance_id", sid),
- value=metrics["request"]["doq6"],
- )
-
- # "answer" metrics
- yield _counter(
- "resolver_answer_total",
- "total number of answered queries",
- label=("instance_id", sid),
- value=metrics["answer"]["total"],
- )
- yield _counter(
- "resolver_answer_cached",
- "number of queries answered from cache",
- label=("instance_id", sid),
- value=metrics["answer"]["cached"],
- )
- yield _counter(
- "resolver_answer_stale",
- "number of queries that utilized stale data",
- label=("instance_id", sid),
- value=metrics["answer"]["stale"],
- )
- yield _counter(
- "resolver_answer_rcode_noerror",
- "number of NOERROR answers",
- label=("instance_id", sid),
- value=metrics["answer"]["noerror"],
- )
- yield _counter(
- "resolver_answer_rcode_nodata",
- "number of NOERROR answers without any data",
- label=("instance_id", sid),
- value=metrics["answer"]["nodata"],
- )
- yield _counter(
- "resolver_answer_rcode_nxdomain",
- "number of NXDOMAIN answers",
- label=("instance_id", sid),
- value=metrics["answer"]["nxdomain"],
- )
- yield _counter(
- "resolver_answer_rcode_servfail",
- "number of SERVFAIL answers",
- label=("instance_id", sid),
- value=metrics["answer"]["servfail"],
- )
- yield _counter(
- "resolver_answer_flag_aa",
- "number of authoritative answers",
- label=("instance_id", sid),
- value=metrics["answer"]["aa"],
- )
- yield _counter(
- "resolver_answer_flag_tc",
- "number of truncated answers",
- label=("instance_id", sid),
- value=metrics["answer"]["tc"],
- )
- yield _counter(
- "resolver_answer_flag_ra",
- "number of answers with recursion available flag",
- label=("instance_id", sid),
- value=metrics["answer"]["ra"],
- )
- yield _counter(
- "resolver_answer_flag_rd",
- "number of recursion desired (in answer!)",
- label=("instance_id", sid),
- value=metrics["answer"]["rd"],
- )
- yield _counter(
- "resolver_answer_flag_ad",
- "number of authentic data (DNSSEC) answers",
- label=("instance_id", sid),
- value=metrics["answer"]["ad"],
- )
- yield _counter(
- "resolver_answer_flag_cd",
- "number of checking disabled (DNSSEC) answers",
- label=("instance_id", sid),
- value=metrics["answer"]["cd"],
- )
- yield _counter(
- "resolver_answer_flag_do",
- "number of DNSSEC answer OK",
- label=("instance_id", sid),
- value=metrics["answer"]["do"],
- )
- yield _counter(
- "resolver_answer_flag_edns0",
- "number of answers with EDNS0 present",
- label=("instance_id", sid),
- value=metrics["answer"]["edns0"],
- )
-
- # "query" metrics
- yield _counter(
- "resolver_query_edns",
- "number of queries with EDNS present",
- label=("instance_id", sid),
- value=metrics["query"]["edns"],
- )
- yield _counter(
- "resolver_query_dnssec",
- "number of queries with DNSSEC DO=1",
- label=("instance_id", sid),
- value=metrics["query"]["dnssec"],
- )
-
- # "predict" metrics (optional)
- if "predict" in metrics:
- if "epoch" in metrics["predict"]:
- yield _counter(
- "resolver_predict_epoch",
- "current prediction epoch (based on time of day and sampling window)",
- label=("instance_id", sid),
- value=metrics["predict"]["epoch"],
- )
- yield _counter(
- "resolver_predict_queue",
- "number of queued queries in current window",
- label=("instance_id", sid),
- value=metrics["predict"]["queue"],
- )
- yield _counter(
- "resolver_predict_learned",
- "number of learned queries in current window",
- label=("instance_id", sid),
- value=metrics["predict"]["learned"],
- )
-
- def _create_resolver_metrics_loaded_gauge(kresid: "KresID", loaded: bool) -> GaugeMetricFamily:
- return _gauge(
- "resolver_metrics_loaded",
- "0 if metrics from resolver instance were not loaded, otherwise 1",
- label=("instance_id", str(kresid)),
- value=int(loaded),
- )
-
- class KresPrometheusMetricsCollector:
- def __init__(self, config_store: ConfigStore) -> None:
- self._stats_raw: "Optional[Dict[KresID, object]]" = None
- self._config_store: ConfigStore = config_store
- self._collection_task: "Optional[asyncio.Task[None]]" = None
- self._skip_immediate_collection: bool = False
-
- def collect(self) -> Generator[Metric, None, None]:
- # schedule new stats collection
- self._trigger_stats_collection()
-
- # if we have no data, return metrics with information about it and exit
- if self._stats_raw is None:
- for kresid in get_registered_workers_kresids():
- yield _create_resolver_metrics_loaded_gauge(kresid, False)
- return
-
- # if we have data, parse them
- for kresid in get_registered_workers_kresids():
- success = False
- try:
- if kresid in self._stats_raw:
- metrics = self._stats_raw[kresid]
- yield from _parse_resolver_metrics(kresid, metrics)
- success = True
- except KeyError as e:
- logger.warning(
- "Failed to load metrics from resolver instance %s: attempted to read missing statistic %s",
- str(kresid),
- str(e),
- )
-
- yield _create_resolver_metrics_loaded_gauge(kresid, success)
-
- def describe(self) -> List[Metric]:
- # this function prevents the collector registry from invoking the collect function on startup
- return []
-
- async def collect_kresd_stats(self, _triggered_from_prometheus_library: bool = False) -> None:
- if self._skip_immediate_collection:
- # this would happen because we are calling this function first manually before stat generation,
- # and once again immediately afterwards caused by the prometheus library's stat collection
- #
-                # this is code written to work around calling async functions from sync methods
- self._skip_immediate_collection = False
- return
-
- config = self._config_store.get()
- self._stats_raw = await collect_kresd_workers_metrics(config)
-
- # if this function was not called by the prometheus library and calling collect() is imminent,
- # we should block the next collection cycle as it would be useless
- if not _triggered_from_prometheus_library:
- self._skip_immediate_collection = True
-
- def _trigger_stats_collection(self) -> None:
- # we are running inside an event loop, but in a synchronous function and that sucks a lot
- # it means that we shouldn't block the event loop by performing a blocking stats collection
- # but it also means that we can't yield to the event loop as this function is synchronous
- # therefore we can only start a new task, but we can't wait for it
- # which causes the metrics to be delayed by one collection pass (not the best, but probably good enough)
- #
- # this issue can be prevented by calling the `collect_kresd_stats()` function manually before entering
- # the Prometheus library. We just have to prevent the library from invoking it again. See the mentioned
- # function for details
-
- if compat.asyncio.is_event_loop_running():
- # when running, we can schedule the new data collection
- if self._collection_task is not None and not self._collection_task.done():
- logger.warning("Statistics collection task is still running. Skipping scheduling of a new one!")
- else:
- self._collection_task = compat.asyncio.create_task(
- self.collect_kresd_stats(_triggered_from_prometheus_library=True)
- )
-
- else:
- # when not running, we can start a new loop (we are not in the manager's main thread)
- compat.asyncio.run(self.collect_kresd_stats(_triggered_from_prometheus_library=True))
-
- @only_on_real_changes_update(lambda c: c.monitoring.graphite)
- async def _init_graphite_bridge(config: KresConfig, force: bool = False) -> None:
- """Start graphite bridge if required."""
- global _graphite_bridge
- if config.monitoring.graphite.enable and _graphite_bridge is None:
- logger.info(
- "Starting Graphite metrics exporter for [%s]:%d",
- str(config.monitoring.graphite.host),
- int(config.monitoring.graphite.port),
- )
- _graphite_bridge = GraphiteBridge(
- (str(config.monitoring.graphite.host), int(config.monitoring.graphite.port))
- )
- _graphite_bridge.start(
- interval=config.monitoring.graphite.interval.seconds(), prefix=str(config.monitoring.graphite.prefix)
- )
-
- async def _deny_turning_off_graphite_bridge(
- old_config: KresConfig, new_config: KresConfig, force: bool = False
- ) -> Result[None, str]:
- if old_config.monitoring.graphite.enable and not new_config.monitoring.graphite.enable:
- return Result.err(
- "You can't turn off graphite monitoring dynamically."
- " If you really want this feature, please let the developers know."
- )
-
- if (
- old_config.monitoring.graphite.enable
- and new_config.monitoring.graphite.enable
- and old_config.monitoring.graphite != new_config.monitoring.graphite
- ):
- return Result.err("Changing graphite exporter configuration in runtime is not allowed.")
-
- return Result.ok(None)
-
-
-async def init_prometheus(config_store: ConfigStore) -> None:
- """Initialize metrics collection. Must be called before any other function from this module."""
- if PROMETHEUS_LIB:
- # init and register metrics collector
- global _metrics_collector
- _metrics_collector = KresPrometheusMetricsCollector(config_store)
- REGISTRY.register(_metrics_collector) # type: ignore[arg-type]
-
- # register graphite bridge
- await config_store.register_verifier(_deny_turning_off_graphite_bridge)
- await config_store.register_on_change_callback(_init_graphite_bridge)
-
-
-async def report_prometheus() -> Optional[bytes]:
- if PROMETHEUS_LIB:
- # manually trigger stat collection so that we do not have to wait for it
- if _metrics_collector is not None:
- await _metrics_collector.collect_kresd_stats()
- else:
- raise RuntimeError("Function invoked before initializing the module!")
- return exposition.generate_latest()
- return None
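The least obvious step in _parse_resolver_metrics() above is the latency histogram: the resolver reports per-bucket counts while Prometheus expects cumulative counts per upper bound. A worked standalone sketch of that conversion with made-up counts:

resolver_buckets = ("1ms", "10ms", "50ms", "100ms", "250ms", "500ms", "1000ms", "1500ms", "slow")
prometheus_bounds = ("0.001", "0.01", "0.05", "0.1", "0.25", "0.5", "1.0", "1.5", "+Inf")

answer = {"1ms": 5, "10ms": 20, "50ms": 10, "100ms": 3, "250ms": 1,
          "500ms": 0, "1000ms": 0, "1500ms": 0, "slow": 1}   # illustrative counts

cumulative = []
running = 0
for bound, name in zip(prometheus_bounds, resolver_buckets):
    running += answer[name]                 # cumulative count up to this upper bound
    cumulative.append((bound, running))

# cumulative == [('0.001', 5), ('0.01', 25), ('0.05', 35), ('0.1', 38),
#                ('0.25', 39), ('0.5', 39), ('1.0', 39), ('1.5', 39), ('+Inf', 40)]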
+++ /dev/null
-import asyncio
-import errno
-import json
-import logging
-import os
-import signal
-import sys
-from functools import partial
-from http import HTTPStatus
-from pathlib import Path
-from pwd import getpwuid
-from time import time
-from typing import Any, Dict, List, Literal, Optional, Set, Union, cast
-
-from aiohttp import web
-from aiohttp.web import middleware
-from aiohttp.web_app import Application
-from aiohttp.web_response import json_response
-from aiohttp.web_runner import AppRunner, TCPSite, UnixSite
-
-from knot_resolver.constants import USER
-from knot_resolver.controller import get_best_controller_implementation
-from knot_resolver.controller.exceptions import KresSubprocessControllerError, KresSubprocessControllerExec
-from knot_resolver.controller.interface import SubprocessType
-from knot_resolver.controller.registered_workers import command_single_registered_worker
-from knot_resolver.datamodel import kres_config_json_schema
-from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema
-from knot_resolver.datamodel.config_schema import KresConfig, get_rundir_without_validation
-from knot_resolver.datamodel.globals import Context, set_global_validation_context
-from knot_resolver.datamodel.management_schema import ManagementSchema
-from knot_resolver.manager import files, metrics
-from knot_resolver.utils import custom_atexit as atexit
-from knot_resolver.utils import ignore_exceptions_optional
-from knot_resolver.utils.async_utils import readfile
-from knot_resolver.utils.compat import asyncio as asyncio_compat
-from knot_resolver.utils.etag import structural_etag
-from knot_resolver.utils.functional import Result
-from knot_resolver.utils.modeling.exceptions import AggregateDataValidationError, DataParsingError, DataValidationError
-from knot_resolver.utils.modeling.parsing import DataFormat, data_combine, try_to_parse
-from knot_resolver.utils.modeling.query import query
-from knot_resolver.utils.modeling.types import NoneType
-from knot_resolver.utils.systemd_notify import systemd_notify
-
-from .config_store import ConfigStore
-from .constants import PID_FILE_NAME, init_user_constants
-from .exceptions import KresManagerBaseError
-from .logger import logger_init
-from .manager import KresManager
-
-logger = logging.getLogger(__name__)
-
-
-@middleware
-async def error_handler(request: web.Request, handler: Any) -> web.Response:
- """
- Handle errors in route handlers.
-
- If an exception is thrown during request processing, this middleware catches it
- and responds accordingly.
- """
- try:
- return await handler(request)
- except (AggregateDataValidationError, DataValidationError) as e:
- return web.Response(text=str(e), status=HTTPStatus.BAD_REQUEST)
- except DataParsingError as e:
- return web.Response(text=f"request processing error:\n{e}", status=HTTPStatus.BAD_REQUEST)
- except KresManagerBaseError as e:
- return web.Response(text=f"request processing failed:\n{e}", status=HTTPStatus.INTERNAL_SERVER_ERROR)
-
-
-def from_mime_type(mime_type: str) -> DataFormat:
- formats = {
- "application/json": DataFormat.JSON,
- "application/octet-stream": DataFormat.JSON, # default in aiohttp
- }
- if mime_type not in formats:
- raise DataParsingError(f"unsupported MIME type '{mime_type}', expected: {str(formats)[1:-1]}")
- return formats[mime_type]
-
-
-def parse_from_mime_type(data: str, mime_type: str) -> Any:
- return from_mime_type(mime_type).parse_to_dict(data)
-
-
-class Server:
- # pylint: disable=too-many-instance-attributes
- # This is the top-level class containing pretty much everything. Instead of global
- # variables, we use instance attributes. That's why there are so many of them, and it's
- # ok.
- def __init__(self, store: ConfigStore, config_path: Optional[List[Path]], manager: KresManager) -> None:
- # config store & server dynamic reconfiguration
- self.config_store = store
-
- # HTTP server
- self.app = Application(middlewares=[error_handler])
- self.runner = AppRunner(self.app)
- self.listen: Optional[ManagementSchema] = None
- self.site: Union[NoneType, TCPSite, UnixSite] = None
- self.listen_lock = asyncio.Lock()
- self._config_path: Optional[List[Path]] = config_path
- self._exit_code: int = 0
- self._shutdown_event = asyncio.Event()
- self._manager = manager
-
- async def _reconfigure(self, config: KresConfig, force: bool = False) -> None:
- await self._reconfigure_listen_address(config)
-
- async def _deny_management_changes(
- self, config_old: KresConfig, config_new: KresConfig, force: bool = False
- ) -> Result[None, str]:
- if config_old.management != config_new.management:
- return Result.err(
- "/management: Changing the management API configuration dynamically is not allowed."
- " If you really need this feature, please contact the developers and explain why. Technically,"
- " there are no problems in supporting it. We are only blocking the dynamic changes because"
- " we think the consequences of leaving this footgun unprotected are worse than its usefulness."
- )
- return Result.ok(None)
-
- async def _deny_cache_garbage_collector_changes(
- self, config_old: KresConfig, config_new: KresConfig, _force: bool = False
- ) -> Result[None, str]:
- if config_old.cache.garbage_collector != config_new.cache.garbage_collector:
- return Result.err(
- "/cache/garbage-collector/*: Changing configuration dynamically is not allowed."
- " To change this configuration, you must edit the configuration file and restart the entire resolver."
- )
- return Result.ok(None)
-
- async def _reload_config(self, force: bool = False) -> None:
- if self._config_path is None:
- logger.warning("The manager was started with inlined configuration - can't reload")
- else:
- try:
- data: Dict[str, Any] = {}
- for file in self._config_path:
- file_data = try_to_parse(await readfile(file))
- data = data_combine(data, file_data)
-
- config = KresConfig(data)
- await self.config_store.update(config, force)
- logger.info("Configuration file successfully reloaded")
- except FileNotFoundError:
- logger.error(
- f"Configuration file was not found at '{file}'."
- " Something must have happened to it while we were running."
- )
- logger.error("Configuration has NOT been changed.")
- except (DataParsingError, DataValidationError) as e:
- logger.error(f"Failed to parse the updated configuration file: {e}")
- logger.error("Configuration has NOT been changed.")
- except KresManagerBaseError as e:
- logger.error(f"Reloading of the configuration file failed: {e}")
- logger.error("Configuration has NOT been changed.")
-
- async def _renew_config(self, force: bool = False) -> None:
- try:
- await self.config_store.renew(force)
- logger.info("Configuration successfully renewed")
- except KresManagerBaseError as e:
- logger.error(f"Renewing the configuration failed: {e}")
- logger.error("Configuration has NOT been renewed.")
-
- async def sigint_handler(self) -> None:
- logger.info("Received SIGINT, triggering graceful shutdown")
- self.trigger_shutdown(0)
-
- async def sigterm_handler(self) -> None:
- logger.info("Received SIGTERM, triggering graceful shutdown")
- self.trigger_shutdown(0)
-
- async def sighup_handler(self) -> None:
- logger.info("Received SIGHUP, reloading configuration file")
- systemd_notify(RELOADING="1")
- await self._reload_config()
- systemd_notify(READY="1")
-
- @staticmethod
- def all_handled_signals() -> Set[signal.Signals]:
- return {signal.SIGHUP, signal.SIGINT, signal.SIGTERM}
-
- def bind_signal_handlers(self) -> None:
- asyncio_compat.add_async_signal_handler(signal.SIGTERM, self.sigterm_handler)
- asyncio_compat.add_async_signal_handler(signal.SIGINT, self.sigint_handler)
- asyncio_compat.add_async_signal_handler(signal.SIGHUP, self.sighup_handler)
-
- def unbind_signal_handlers(self) -> None:
- asyncio_compat.remove_signal_handler(signal.SIGTERM)
- asyncio_compat.remove_signal_handler(signal.SIGINT)
- asyncio_compat.remove_signal_handler(signal.SIGHUP)
-
- async def start(self) -> None:
- self._setup_routes()
- await self.runner.setup()
- await self.config_store.register_verifier(self._deny_management_changes)
- await self.config_store.register_verifier(self._deny_cache_garbage_collector_changes)
- await self.config_store.register_on_change_callback(self._reconfigure)
-
- async def wait_for_shutdown(self) -> None:
- await self._shutdown_event.wait()
-
- def trigger_shutdown(self, exit_code: int) -> None:
- self._shutdown_event.set()
- self._exit_code = exit_code
-
- async def _handler_index(self, _request: web.Request) -> web.Response:
- """Indicate that the server is indeed running (dummy index handler)."""
- return json_response(
- {
- "msg": "Knot Resolver Manager is running! The configuration endpoint is at /config",
- "status": "RUNNING",
- }
- )
-
- async def _handler_config_query(self, request: web.Request) -> web.Response:
- """Route handler for changing resolver configuration."""
- # There are a lot of local variables in here, but they are usually immutable (almost SSA form :) )
- # pylint: disable=too-many-locals
-
- # parse the incoming data
- if request.method == "GET":
- update_with: Optional[Dict[str, Any]] = None
- else:
- update_with = parse_from_mime_type(await request.text(), request.content_type)
- document_path = request.match_info["path"]
- getheaders = ignore_exceptions_optional(List[str], None, KeyError)(request.headers.getall)
- etags = getheaders("if-match")
- not_etags = getheaders("if-none-match")
- current_config: Dict[str, Any] = self.config_store.get().get_unparsed_data()
-
- # stop processing early based on ETag preconditions (If-Match / If-None-Match)
- def strip_quotes(s: str) -> str:
- return s.strip('"')
-
- # WARNING: this check is prone to race conditions. When changing, make sure that the current config
- # is really the latest current config (i.e. no await in between obtaining the config and the checks)
- status = HTTPStatus.NOT_MODIFIED if request.method in ("GET", "HEAD") else HTTPStatus.PRECONDITION_FAILED
- if etags is not None and structural_etag(current_config) not in map(strip_quotes, etags):
- return web.Response(status=status)
- if not_etags is not None and structural_etag(current_config) in map(strip_quotes, not_etags):
- return web.Response(status=status)
-
- # run query
- op = cast(Literal["get", "delete", "patch", "put"], request.method.lower())
- new_config, to_return = query(current_config, op, document_path, update_with)
-
- # update the config
- if request.method != "GET":
- # validate
- config_validated = KresConfig(new_config)
- # apply
- await self.config_store.update(config_validated)
-
- # serialize the response (the `to_return` object is a Dict/list/scalar, we want to return json)
- resp_text: Optional[str] = json.dumps(to_return) if to_return is not None else None
-
- # create the response and return it
- res = web.Response(status=HTTPStatus.OK, text=resp_text, content_type="application/json")
- res.headers.add("ETag", f'"{structural_etag(new_config)}"')
- return res
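-
- # Illustrative client flow for the handler above (a sketch; the listen address and the
- # configuration subtree are placeholders, only the /v1/config routes and the ETag
- # semantics come from this module):
- #
- #   curl http://<addr>:<port>/v1/config                # the response carries an ETag header
- #   curl -X PUT http://<addr>:<port>/v1/config/<path> \
- #        -H 'If-Match: "<etag>"' -H 'Content-Type: application/json' -d '<json value>'
- #
- # A request whose If-Match does not match the current configuration yields
- # 412 Precondition Failed (304 Not Modified for GET).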
-
- async def _handler_metrics(self, request: web.Request) -> web.Response:
- raise web.HTTPMovedPermanently("/metrics/json")
-
- async def _handler_metrics_json(self, _request: web.Request) -> web.Response:
- config = self.config_store.get()
-
- return web.Response(
- body=await metrics.report_json(config),
- content_type="application/json",
- charset="utf8",
- )
-
- async def _handler_metrics_prometheus(self, _request: web.Request) -> web.Response:
- metrics_report = await metrics.report_prometheus()
- if not metrics_report:
- raise web.HTTPNotFound()
-
- return web.Response(
- body=metrics_report,
- content_type="text/plain",
- charset="utf8",
- )
-
- async def _handler_cache_clear(self, request: web.Request) -> web.Response:
- data = parse_from_mime_type(await request.text(), request.content_type)
- config = CacheClearRPCSchema(data)
-
- _, result = await command_single_registered_worker(config.render_lua())
- return web.Response(
- body=json.dumps(result),
- content_type="application/json",
- charset="utf8",
- )
-
- async def _handler_schema(self, _request: web.Request) -> web.Response:
- return web.json_response(
- kres_config_json_schema(), headers={"Access-Control-Allow-Origin": "*"}, dumps=partial(json.dumps, indent=4)
- )
-
- async def _handle_view_schema(self, _request: web.Request) -> web.Response:
- """
- Provide a UI for visualising and understanding the JSON schema.
-
- Rendering schemas inside the Knot Resolver Manager itself is unwanted, as it's completely
- out of scope. However, it can be convenient, so we rely on a public web-based viewer
- and provide just a redirect. If this feature ever breaks due to the disappearance of the public
- service, we can fix it. But we do not guarantee that this will always work.
- """
-
- return web.Response(
- text="""
- <html>
- <head><title>Redirect to schema viewer</title></head>
- <body>
- <script>
- // we are using JS in order to use proper host
- let protocol = window.location.protocol;
- let host = window.location.host;
- let url = encodeURIComponent(`${protocol}//${host}/schema`);
- window.location.replace(`https://json-schema.app/view/%23?url=${url}`);
- </script>
- <h1>JavaScript required for a dynamic redirect...</h1>
- </body>
- </html>
- """,
- content_type="text/html",
- )
-
- async def _handler_stop(self, _request: web.Request) -> web.Response:
- """Route handler for shutting down the server (and whole manager)."""
- self._shutdown_event.set()
- logger.info("Shutdown event triggered...")
- return web.Response(text="Shutting down...")
-
- async def _handler_reload(self, request: web.Request) -> web.Response:
- """Route handler for reloading the configuration."""
- logger.info("Reloading event triggered...")
- await self._reload_config(force=bool(request.path.endswith("/force")))
- return web.Response(text="Reloading...")
-
- async def _handler_renew(self, request: web.Request) -> web.Response:
- """Route handler for renewing the configuration."""
- logger.info("Renewing configuration event triggered...")
- await self._renew_config(force=bool(request.path.endswith("/force")))
- return web.Response(text="Renewing configuration...")
-
- async def _handler_processes(self, request: web.Request) -> web.Response:
- """Route handler for listing PIDs of subprocesses."""
- proc_type: Optional[SubprocessType] = None
-
- if "path" in request.match_info and len(request.match_info["path"]) > 0:
- ptstr = request.match_info["path"]
- if ptstr == "/kresd":
- proc_type = SubprocessType.KRESD
- elif ptstr == "/gc":
- proc_type = SubprocessType.GC
- elif ptstr == "/all":
- proc_type = None
- else:
- return web.Response(text=f"Invalid process type '{ptstr}'", status=400)
-
- return web.json_response(
- await self._manager.get_processes(proc_type),
- headers={"Access-Control-Allow-Origin": "*"},
- dumps=partial(json.dumps, indent=4),
- )
-
- def _setup_routes(self) -> None:
- self.app.add_routes(
- [
- web.get("/", self._handler_index),
- web.get(r"/v1/config{path:.*}", self._handler_config_query),
- web.put(r"/v1/config{path:.*}", self._handler_config_query),
- web.delete(r"/v1/config{path:.*}", self._handler_config_query),
- web.patch(r"/v1/config{path:.*}", self._handler_config_query),
- web.post("/stop", self._handler_stop),
- web.post("/reload", self._handler_reload),
- web.post("/reload/force", self._handler_reload),
- web.post("/renew", self._handler_renew),
- web.post("/renew/force", self._handler_renew),
- web.get("/schema", self._handler_schema),
- web.get("/schema/ui", self._handle_view_schema),
- web.get("/metrics", self._handler_metrics),
- web.get("/metrics/json", self._handler_metrics_json),
- web.get("/metrics/prometheus", self._handler_metrics_prometheus),
- web.post("/cache/clear", self._handler_cache_clear),
- web.get("/processes{path:.*}", self._handler_processes),
- ]
- )
-
- async def _reconfigure_listen_address(self, config: KresConfig) -> None:
- async with self.listen_lock:
- mgn = config.management
-
- # if the listen address did not change, do nothing
- if self.listen == mgn:
- return
-
- # start the new listen address
- nsite: Union[web.TCPSite, web.UnixSite]
- if mgn.unix_socket:
- nsite = web.UnixSite(self.runner, str(mgn.unix_socket))
- logger.info(f"Starting API HTTP server on http+unix://{mgn.unix_socket}")
- elif mgn.interface:
- nsite = web.TCPSite(self.runner, str(mgn.interface.addr), int(mgn.interface.port))
- logger.info(f"Starting API HTTP server on http://{mgn.interface.addr}:{mgn.interface.port}")
- else:
- raise KresManagerBaseError("Requested API on unsupported configuration format.")
- await nsite.start()
-
- # stop the old listen
- assert (self.listen is None) == (self.site is None)
- if self.listen is not None and self.site is not None:
- if self.listen.unix_socket:
- logger.info(f"Stopping API HTTP server on http+unix://{mgn.unix_socket}")
- elif self.listen.interface:
- logger.info(
- f"Stopping API HTTP server on http://{self.listen.interface.addr}:{self.listen.interface.port}"
- )
- await self.site.stop()
-
- # save new state
- self.listen = mgn
- self.site = nsite
-
- async def shutdown(self) -> None:
- if self.site is not None:
- await self.site.stop()
- await self.runner.cleanup()
-
- def get_exit_code(self) -> int:
- return self._exit_code
-
-
-async def _load_raw_config(config: Union[Path, Dict[str, Any]]) -> Dict[str, Any]:
- # Initial configuration of the manager
- if isinstance(config, Path):
- if not config.exists():
- raise KresManagerBaseError(
- f"Manager is configured to load config file at {config} on startup, but the file does not exist."
- )
- logger.info(f"Loading configuration from '{config}' file.")
- config = try_to_parse(await readfile(config))
-
- # validate the initial configuration
- assert isinstance(config, dict)
- return config
-
-
-async def _load_config(config: Dict[str, Any]) -> KresConfig:
- return KresConfig(config)
-
-
-async def _init_config_store(config: Dict[str, Any]) -> ConfigStore:
- config_validated = await _load_config(config)
- return ConfigStore(config_validated)
-
-
-async def _init_manager(config_store: ConfigStore) -> KresManager:
- """Call asynchronously when the application initializes."""
- # Instantiate subprocess controller (if we wanted to, we could switch it at this point)
- controller = await get_best_controller_implementation(config_store.get())
-
- # Create KresManager. This will perform autodetection of available service managers and
- # select the most appropriate to use (or use the one configured directly)
- manager = await KresManager.create(controller, config_store)
-
- logger.info("Initial configuration applied. Process manager initialized...")
- return manager
-
-
-async def _deny_working_directory_changes(
- config_old: KresConfig, config_new: KresConfig, force: bool = False
-) -> Result[None, str]:
- if config_old.rundir != config_new.rundir:
- return Result.err("Changing manager's `rundir` during runtime is not allowed.")
-
- return Result.ok(None)
-
-
-def _set_working_directory(config_raw: Dict[str, Any]) -> None:
- try:
- rundir = get_rundir_without_validation(config_raw)
- except ValueError as e:
- raise DataValidationError(str(e), "/rundir") from e
-
- logger.debug(f"Changing working directory to '{rundir.to_path().absolute()}'.")
- os.chdir(rundir.to_path())
-
-
-def _lock_working_directory(attempt: int = 0) -> None:
- # the following syscall is atomic, it's essentially the same as acquiring a lock
- try:
- pidfile_fd = os.open(PID_FILE_NAME, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
- except OSError as e:
- if e.errno == errno.EEXIST and attempt == 0:
- # the pid file exists, let's check PID
- with open(PID_FILE_NAME, "r", encoding="utf-8") as f:
- pid = int(f.read().strip())
- try:
- os.kill(pid, 0)
- except OSError as e2:
- if e2.errno == errno.ESRCH:
- os.unlink(PID_FILE_NAME)
- _lock_working_directory(attempt=attempt + 1)
- return
- raise KresManagerBaseError(
- "Another manager is running in the same working directory."
- f" PID file is located at {os.getcwd()}/{PID_FILE_NAME}"
- ) from e
- raise KresManagerBaseError(
- "Another manager is running in the same working directory."
- f" PID file is located at {os.getcwd()}/{PID_FILE_NAME}"
- ) from e
-
- # now we know that we are the only manager running in this directory
-
- # write PID to the pidfile and close it afterwards
- pidfile = os.fdopen(pidfile_fd, "w")
- pid = os.getpid()
- pidfile.write(f"{pid}\n")
- pidfile.close()
-
- # make sure that the file is deleted on shutdown
- atexit.register(lambda: os.unlink(PID_FILE_NAME))
-
-
-async def _sigint_while_shutting_down() -> None:
- logger.warning(
- "Received SIGINT while already shutting down. Ignoring."
- " If you want to forcefully stop the manager right now, use SIGTERM."
- )
-
-
-async def _sigterm_while_shutting_down() -> None:
- logger.warning("Received SIGTERM. Invoking dirty shutdown!")
- sys.exit(128 + signal.SIGTERM)
-
-
-async def start_server(config: List[str]) -> int: # noqa: C901, PLR0915
- # This function is quite long, but it describes how the manager runs. So let's silence pylint
- # pylint: disable=too-many-statements
-
- start_time = time()
- working_directory_on_startup = os.getcwd()
- manager: Optional[KresManager] = None
-
- # Block signals during initialization to force their processing once everything is ready
- signal.pthread_sigmask(signal.SIG_BLOCK, Server.all_handled_signals())
-
- # Check if we are running under the intended user; if not, log a warning message
- pw_username = getpwuid(os.getuid()).pw_name
- if pw_username != USER:
- logger.warning(
- f"Knot Resolver does not run as the default '{USER}' user, but as '{pw_username}' instead."
- " This may or may not affect the configuration validation and the proper functioning of the resolver."
- )
- if os.geteuid() == 0:
- logger.warning(" It is not recommended to run under root privileges unless there is no other option.")
-
- # Before starting the server, initialize the subprocess controller, config store, etc. Any errors during
- # initialization are fatal
- try:
- # Make sure that the config paths do not change meaning when we change the working directory
- config_absolute = [Path(path).absolute() for path in config]
-
- config_data: Dict[str, Any] = {}
- for file in config_absolute:
- # Warn about config files whose parent directory differs from the first one's,
- # which is used as the prefix path for relative paths
- if config_absolute[0].parent != file.parent:
- logger.warning(
- f"The configuration file '{file}' has a parent directory that is different"
- f" from '{config_absolute[0]}', which is used as the prefix for relative paths."
- "This can cause issues with files that are configured with relative paths."
- )
-
- # Preprocess config - load from file or in general take it to the last step before validation.
- config_raw = await _load_raw_config(file)
-
- # combine data from all config files
- config_data = data_combine(config_data, config_raw)
-
- # before processing any configuration, set validation context
- # - resolve_root: root against which all relative paths will be resolved
- # - strict_validation: check for path existence during configuration validation
- # - permissions_default: validate dirs/files rwx permissions against default user:group in constants
- set_global_validation_context(Context(config_absolute[0].parent, True, False))
-
- # We want to change cwd as soon as possible. Some parts of the codebase are using os.getcwd() to get the
- # working directory.
- #
- # If we fail to read rundir from unparsed config, the first config validation error comes from here
- _set_working_directory(config_data)
-
- # We don't want more than one manager in a single working directory. So we lock it with a PID file.
- # Warning: this does not prevent multiple managers that use the same naming for the kresd services.
- _lock_working_directory()
-
- # set_global_validation_context(Context(config.parent))
-
- # After the working directory is set, we can initialize proper config store with a newly parsed configuration.
- config_store = await _init_config_store(config_data)
-
- # Some "constants" need to be loaded from the initial config,
- # some need to be stored from the initial run conditions
- await init_user_constants(config_store, working_directory_on_startup)
-
- # The path behaviour described above means that we MUST NOT allow a `rundir` change after initialization.
- # It would cause strange problems, because every other path configuration depends on it. Therefore, we have to
- # add a check to the config store which disallows such changes.
- await config_store.register_verifier(_deny_working_directory_changes)
-
- # Up to this point, we have been logging to a memory buffer. But now that we have the configuration loaded, we
- # can flush the buffer into the proper place
- await logger_init(config_store)
-
- # With the configuration at hand, we can initialize monitoring. We want to do this before any subprocesses are
- # started, therefore before initializing the manager
- await metrics.init_prometheus(config_store)
-
- await files.init_files_watchdog(config_store)
-
- # After we have loaded the configuration, we can start worrying about subprocess management.
- manager = await _init_manager(config_store)
-
- # prepare instance of the server (no side effects)
- server = Server(config_store, config_absolute, manager)
-
- # add Server's shutdown trigger to the manager
- manager.add_shutdown_trigger(server.trigger_shutdown)
-
- except KresSubprocessControllerExec as e:
- # If we caught this exception, some component wants to perform a reexec during startup. Most likely, it would
- # be a subprocess manager like supervisord, which wants to make sure the manager runs under supervisord in
- # the process tree. So now we stop everything and exec what we are told to. We are assuming that whatever
- # we exec will invoke us again.
- logger.info("Exec requested with arguments: %s", str(e.exec_args))
-
- # unblock signals, this could actually terminate us straight away
- signal.pthread_sigmask(signal.SIG_UNBLOCK, Server.all_handled_signals())
-
- # run exit functions
- atexit.run_callbacks()
-
- # and finally exec what we were told to exec
- os.execl(*e.exec_args)
-
- except KresSubprocessControllerError as e:
- logger.error(f"Server initialization failed: {e}")
- return 1
-
- except KresManagerBaseError as e:
- # We caught an error with a pretty error message. Just print it and exit.
- logger.error(e)
- return 1
-
- except BaseException:
- logger.error("Uncaught generic exception during manager inicialization...", exc_info=True)
- return 1
-
- # At this point, all backend functionality-providing components are initialized. It's therefore safe to start
- # the API server.
- try:
- await server.start()
- except OSError as e:
- if e.errno in (errno.EADDRINUSE, errno.EADDRNOTAVAIL):
- # fancy error reporting of network binding errors
- logger.error(str(e))
- await manager.stop()
- return 1
- raise
-
- # At this point, pretty much everything is ready to go. We should just make sure the user can shut
- # the manager down with signals.
- server.bind_signal_handlers()
- signal.pthread_sigmask(signal.SIG_UNBLOCK, Server.all_handled_signals())
-
- logger.info(f"Manager fully initialized and running in {round(time() - start_time, 3)} seconds")
-
- # notify systemd/anything compatible that we are ready
- systemd_notify(READY="1")
-
- await server.wait_for_shutdown()
-
- # notify systemd that we are shutting down
- systemd_notify(STOPPING="1")
-
- # Ok, now we are tearing everything down.
-
- # First of all, let's block all unwanted interruptions. We don't want to be reconfiguring kresd instances while
- # shutting down.
- signal.pthread_sigmask(signal.SIG_BLOCK, Server.all_handled_signals())
- server.unbind_signal_handlers()
- # on the other hand, we want to stop immediately when the user really wants us to stop
- asyncio_compat.add_async_signal_handler(signal.SIGTERM, _sigterm_while_shutting_down)
- asyncio_compat.add_async_signal_handler(signal.SIGINT, _sigint_while_shutting_down)
- signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGTERM, signal.SIGINT})
-
- # After triggering shutdown, we need to clean everything up
- logger.info("Stopping API service...")
- await server.shutdown()
- logger.info("Stopping kresd manager...")
- await manager.stop()
- logger.info(f"The manager run for {round(time() - start_time)} seconds...")
- return server.get_exit_code()
+++ /dev/null
-import logging
-from threading import Timer
-from typing import Dict, Optional
-from urllib.parse import quote
-
-from knot_resolver.controller.registered_workers import command_registered_workers
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.utils import compat
-from knot_resolver.utils.requests import SocketDesc, request
-
-logger = logging.getLogger(__name__)
-
-_triggers: Optional["Triggers"] = None
-
-
-class Triggers:
- def __init__(self, config: KresConfig) -> None:
- self._config = config
-
- self._reload_force = False
- self._renew_force = False
- self._renew_timer: Optional[Timer] = None
- self._reload_timer: Optional[Timer] = None
- self._cmd_timers: Dict[str, Timer] = {}
-
- management = config.management
- socket = SocketDesc(
- f'http+unix://{quote(str(management.unix_socket), safe="")}/',
- 'Key "/management/unix-socket" in validated configuration',
- )
- if management.interface:
- socket = SocketDesc(
- f"http://{management.interface.addr}:{management.interface.port}",
- 'Key "/management/interface" in validated configuration',
- )
- self._socket = socket
-
- def trigger_cmd(self, cmd: str) -> None:
- def _cmd() -> None:
- if compat.asyncio.is_event_loop_running():
- compat.asyncio.create_task(command_registered_workers(cmd))
- else:
- compat.asyncio.run(command_registered_workers(cmd))
- logger.info(f"Sending '{cmd}' command to reload watched files has finished")
-
- # skipping if command was already triggered
- if cmd in self._cmd_timers and self._cmd_timers[cmd].is_alive():
- logger.info(f"Skipping sending '{cmd}' command, it was already triggered")
- return
- # start a 5sec timer
- logger.info(f"Delayed send of '{cmd}' command has started")
- self._cmd_timers[cmd] = Timer(5, _cmd)
- self._cmd_timers[cmd].start()
-
- def cancel_cmd(self, cmd: str) -> None:
- if cmd in self._cmd_timers:
- self._cmd_timers[cmd].cancel()
-
- def trigger_renew(self, force: bool = False) -> None:
- def _renew() -> None:
- response = request(self._socket, "POST", "renew/force" if force else "renew")
- if response.status != 200:
- logger.error(f"Failed to renew configuration: {response.body}")
- logger.info("Renewing configuration has finished")
- self._renew_force = False
-
- # do not trigger renew if reload is scheduled
- if self._reload_timer and self._reload_timer.is_alive() and self._reload_force >= force:
- logger.info("Skipping renewing configuration, reload was already triggered")
- return
-
- # skipping if renew was already triggered
- if self._renew_timer and self._renew_timer.is_alive():
- if self._renew_force >= force:
- logger.info("Skipping renewing configuration, it was already triggered")
- return
- self._renew_timer.cancel()
- self._renew_force = False
-
- logger.info("Delayed configuration renew has started")
- # start a 5sec timer
- self._renew_timer = Timer(5, _renew)
- self._renew_timer.start()
- self._renew_force = force
-
- def trigger_reload(self, force: bool = False) -> None:
- def _reload() -> None:
- response = request(self._socket, "POST", "reload/force" if force else "reload")
- if response.status != 200:
- logger.error(f"Failed to reload configuration: {response.body}")
- logger.info("Reloading configuration has finished")
- self._reload_force = False
-
- # cancel renew
- if self._renew_timer and self._renew_timer.is_alive() and force >= self._renew_force:
- self._renew_timer.cancel()
- self._renew_force = False
-
- # skipping if reload was already triggered
- if self._reload_timer and self._reload_timer.is_alive():
- if self._reload_force >= force:
- logger.info("Skipping reloading configuration, it was already triggered")
- return
- logger.info("Cancelling already scheduled configuration reload, force reload triggered")
- self._reload_timer.cancel()
- self._reload_force = False
-
- logger.info("Delayed configuration reload has started")
- # start a 5sec timer
- self._reload_timer = Timer(5, _reload)
- self._reload_timer.start()
- self._reload_force = force
-
-
-def trigger_cmd(config: KresConfig, cmd: str) -> None:
- global _triggers
- if not _triggers:
- _triggers = Triggers(config)
- _triggers.trigger_cmd(cmd)
-
-
-def cancel_cmd(cmd: str) -> None:
- global _triggers # noqa: PLW0602
- if _triggers:
- _triggers.cancel_cmd(cmd)
-
-
-def trigger_renew(config: KresConfig, force: bool = False) -> None:
- global _triggers
- if not _triggers:
- _triggers = Triggers(config)
- _triggers.trigger_renew(force)
-
-
-def trigger_reload(config: KresConfig, force: bool = False) -> None:
- global _triggers
- if not _triggers:
- _triggers = Triggers(config)
- _triggers.trigger_reload(force)
+++ /dev/null
-from typing import Any, Callable, Optional, Type, TypeVar
-
-T = TypeVar("T")
-
-
-def ignore_exceptions_optional(
- _tp: Type[T], default: Optional[T], *exceptions: Type[BaseException]
-) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]:
- """
- Wrap function preventing it from raising exceptions and instead returning the configured default value.
-
- :param type[T] _tp: Return type of the function. Essentially only a template argument for type-checking
- :param T default: The value to return as a default
- :param list[Type[BaseException]] exceptions: The list of exceptions to catch
- :return: value of the decorated function, or default if exception raised
- :rtype: T
- """
-
- def decorator(func: Callable[..., Optional[T]]) -> Callable[..., Optional[T]]:
- def f(*nargs: Any, **nkwargs: Any) -> Optional[T]:
- try:
- return func(*nargs, **nkwargs)
- except BaseException as e:
- if isinstance(e, exceptions):
- return default
- raise
-
- return f
-
- return decorator
-
-
-def ignore_exceptions(
- default: T, *exceptions: Type[BaseException]
-) -> Callable[[Callable[..., Optional[T]]], Callable[..., Optional[T]]]:
- return ignore_exceptions_optional(type(default), default, *exceptions)
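-
-# Example usage of the decorators above (illustrative only; not part of the original module):
-#   safe_int = ignore_exceptions_optional(int, None, ValueError)(int)
-#   safe_int("42")    # -> 42
-#   safe_int("oops")  # -> None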
-
-
-def phantom_use(var: Any) -> None: # pylint: disable=unused-argument
- """
- Consume the argument, doing absolutely nothing with it.
-
- Useful for convincing pylint that we need the variable even when it's unused.
- """
+++ /dev/null
-import asyncio
-import os
-import pkgutil
-import signal
-import sys
-import time
-from asyncio import create_subprocess_exec, create_subprocess_shell
-from pathlib import PurePath
-from threading import Thread
-from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
-
-from knot_resolver.utils.compat.asyncio import to_thread
-
-
-def unblock_signals() -> None:
- if sys.version_info >= (3, 8):
- signal.pthread_sigmask(signal.SIG_UNBLOCK, signal.valid_signals())
- else:
- # the list of signals is not exhaustive, but it should cover all signals we might ever want to block
- signal.pthread_sigmask(
- signal.SIG_UNBLOCK,
- {
- signal.SIGHUP,
- signal.SIGINT,
- signal.SIGTERM,
- signal.SIGUSR1,
- signal.SIGUSR2,
- },
- )
-
-
-async def call(
- cmd: Union[str, bytes, List[str], List[bytes]], shell: bool = False, discard_output: bool = False
-) -> int:
- """Async alternative to subprocess.call()."""
- kwargs: Dict[str, Any] = {
- "preexec_fn": unblock_signals,
- }
- if discard_output:
- kwargs["stdout"] = asyncio.subprocess.DEVNULL
- kwargs["stderr"] = asyncio.subprocess.DEVNULL
-
- if shell:
- if isinstance(cmd, list):
- raise RuntimeError("can't use list of arguments with shell=True")
- proc = await create_subprocess_shell(cmd, **kwargs)
- else:
- if not isinstance(cmd, list):
- raise RuntimeError(
- "Please use list of arguments, not a single string. It will prevent ambiguity when parsing"
- )
- proc = await create_subprocess_exec(*cmd, **kwargs)
-
- return await proc.wait()
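-
-# Example usage (illustrative; the commands are placeholders, not something this module runs):
-#   exit_code = await call(["ls", "-l"], discard_output=True)
-#   exit_code = await call("ls -l | wc -l", shell=True)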
-
-
-async def readfile(path: Union[str, PurePath]) -> str:
- """Asynchronously read whole file and return its content."""
-
- def readfile_sync(path: Union[str, PurePath]) -> str:
- with open(path, "r", encoding="utf8") as f:
- return f.read()
-
- return await to_thread(readfile_sync, path)
-
-
-async def writefile(path: Union[str, PurePath], content: str) -> None:
- """Asynchronously set content of a file to a given string `content`."""
-
- def writefile_sync(path: Union[str, PurePath], content: str) -> int:
- with open(path, "w", encoding="utf8") as f:
- return f.write(content)
-
- await to_thread(writefile_sync, path, content)
-
-
-async def wait_for_process_termination(pid: int, sleep_sec: float = 0) -> None:
- """
- Wait for process termination.
-
- Waits for any process (it does not have to be a child process) given by its
- PID to terminate. `sleep_sec` configures the granularity with which we poll
- and therefore how quickly we can return.
- """
-
- def wait_sync(pid: int, sleep_sec: float) -> None:
- while True:
- try:
- os.kill(pid, 0)
- if sleep_sec == 0:
- os.sched_yield()
- else:
- time.sleep(sleep_sec)
- except ProcessLookupError:
- break
-
- await to_thread(wait_sync, pid, sleep_sec)
-
-
-async def read_resource(package: str, filename: str) -> Optional[bytes]:
- return await to_thread(pkgutil.get_data, package, filename)
-
-
-T = TypeVar("T")
-
-
-class BlockingEventDispatcher(Thread, Generic[T]):
- def __init__(self, name: str = "blocking_event_dispatcher") -> None:
- super().__init__(name=name, daemon=True)
- # warning: the asyncio queue is not thread safe
- self._removed_unit_names: "asyncio.Queue[T]" = asyncio.Queue()
- self._main_event_loop = asyncio.get_event_loop()
-
- def dispatch_event(self, event: T) -> None:
- """Dispatch events from the blocking thread."""
-
- async def add_to_queue() -> None:
- await self._removed_unit_names.put(event)
-
- self._main_event_loop.call_soon_threadsafe(add_to_queue)
-
- async def next_event(self) -> T:
- return await self._removed_unit_names.get()
+++ /dev/null
-from . import asyncio, typing
-
-__all__ = ["asyncio", "typing"]
+++ /dev/null
-# We disable pylint checks, because it can't find methods in newer Python versions.
-#
-# pylint: disable=no-member
-
-import asyncio
-import functools
-import logging
-import sys
-from asyncio import AbstractEventLoop, coroutines, events, tasks
-from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar
-
-logger = logging.getLogger(__name__)
-
-T = TypeVar("T")
-
-
-async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
- # version 3.9 and higher, call directly
- if sys.version_info >= (3, 9):
- return await asyncio.to_thread(func, *args, **kwargs) # type: ignore[attr-defined]
-
- # earlier versions, run with default executor
- loop = asyncio.get_event_loop()
- pfunc = functools.partial(func, *args, **kwargs)
- return await loop.run_in_executor(None, pfunc)
-
-
-def async_in_a_thread(func: Callable[..., T]) -> Callable[..., Coroutine[None, None, T]]:
- async def wrapper(*args: Any, **kwargs: Any) -> T:
- return await to_thread(func, *args, **kwargs)
-
- return wrapper
-
-
-def create_task(coro: Awaitable[T], name: Optional[str] = None) -> "asyncio.Task[T]":
- # version 3.8 and higher, call directly
- if sys.version_info >= (3, 8):
- # pylint: disable=unexpected-keyword-arg
- return asyncio.create_task(coro, name=name) # type: ignore[attr-defined,arg-type,call-arg]
-
- # version 3.7, call directly without the name argument
- if sys.version_info >= (3, 7):
- return asyncio.create_task(coro) # type: ignore[attr-defined,arg-type]
-
- # earlier versions, use older function
- return asyncio.ensure_future(coro)
-
-
-def is_event_loop_running() -> bool:
- loop = events._get_running_loop() # noqa: SLF001
- return loop is not None and loop.is_running()
-
-
-def run(coro: Awaitable[T], debug: Optional[bool] = None) -> T:
- # Adapted version of this:
- # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L8
-
- # version 3.7 and higher, call asyncio.run() directly
- if sys.version_info >= (3, 7):
- return asyncio.run(coro, debug=debug) # type: ignore[attr-defined,arg-type]
-
- # earlier versions, use backported version of the function
- if events._get_running_loop() is not None: # noqa: SLF001
- raise RuntimeError("asyncio.run() cannot be called from a running event loop")
-
- if not coroutines.iscoroutine(coro):
- raise ValueError(f"a coroutine was expected, got {repr(coro)}")
-
- loop = events.new_event_loop()
- try:
- events.set_event_loop(loop)
- if debug is not None:
- loop.set_debug(debug)
- return loop.run_until_complete(coro)
- finally:
- try:
- _cancel_all_tasks(loop)
- loop.run_until_complete(loop.shutdown_asyncgens())
- if hasattr(loop, "shutdown_default_executor"):
- loop.run_until_complete(loop.shutdown_default_executor()) # type: ignore[attr-defined]
- finally:
- events.set_event_loop(None)
- loop.close()
-
-
-def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
- # Backported from:
- # https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py#L55-L74
- #
- to_cancel = tasks.all_tasks(loop)
- if not to_cancel:
- return
-
- for task in to_cancel:
- task.cancel()
-
- if sys.version_info >= (3, 7):
- # since 3.7, the loop argument is implicitly the running loop
- # since 3.10, the loop argument is removed
- loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
- else:
- loop.run_until_complete(tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) # type: ignore[call-overload]
-
- for task in to_cancel:
- if task.cancelled():
- continue
- if task.exception() is not None:
- loop.call_exception_handler(
- {
- "message": "unhandled exception during asyncio.run() shutdown",
- "exception": task.exception(),
- "task": task,
- }
- )
-
-
-def add_async_signal_handler(signal: int, callback: Callable[[], Coroutine[Any, Any, None]]) -> None:
- loop = asyncio.get_event_loop()
- loop.add_signal_handler(signal, lambda: create_task(callback()))
-
-
-def remove_signal_handler(signal: int) -> bool:
- loop = asyncio.get_event_loop()
- return loop.remove_signal_handler(signal)
+++ /dev/null
-# 'typing.Pattern' has been deprecated since Python 3.8 and is removed in version 3.12.
-# https://docs.python.org/3.9/library/typing.html#typing.Pattern
-try:
- from typing import Pattern
-except ImportError:
- from re import Pattern
-
-__all__ = ["Pattern"]
+++ /dev/null
-"""
-Custom replacement for standard module `atexit`.
-
-We use `atexit` behind the scenes; we just add the option
-to invoke the exit functions manually.
-"""
-
-import atexit
-from typing import Callable, List
-
-_at_exit_functions: List[Callable[[], None]] = []
-
-
-def register(func: Callable[[], None]) -> None:
- _at_exit_functions.append(func)
- atexit.register(func)
-
-
-def run_callbacks() -> None:
- for func in _at_exit_functions:
- func()
- atexit.unregister(func)
+++ /dev/null
-import base64
-import json
-from hashlib import blake2b
-from typing import Any
-
-
-def structural_etag(obj: Any) -> str:
- m = blake2b(digest_size=15)
- m.update(json.dumps(obj, sort_keys=True).encode("utf8"))
- return base64.urlsafe_b64encode(m.digest()).decode("utf8")
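-
-# The tag depends only on the structure of the object, not on dict key order, because the
-# JSON dump is produced with sorted keys. For illustration:
-#   structural_etag({"a": 1, "b": 2}) == structural_etag({"b": 2, "a": 1})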
+++ /dev/null
-from enum import Enum, auto
-from typing import Any, Callable, Generic, Iterable, TypeVar, Union
-
-T = TypeVar("T")
-
-
-def foldl(oper: Callable[[T, T], T], default: T, arr: Iterable[T]) -> T:
- val = default
- for x in arr:
- val = oper(val, x)
- return val
-
-
-def contains_element_matching(cond: Callable[[T], bool], arr: Iterable[T]) -> bool:
- return foldl(lambda x, y: x or y, False, map(cond, arr))
-
-
-def all_matches(cond: Callable[[T], bool], arr: Iterable[T]) -> bool:
- return foldl(lambda x, y: x and y, True, map(cond, arr))
-
-
-Succ = TypeVar("Succ")
-Err = TypeVar("Err")
-
-
-class _Status(Enum):
- OK = auto()
- ERROR = auto()
-
-
-class _ResultSentinel:
- pass
-
-
-_RESULT_SENTINEL = _ResultSentinel()
-
-
-class Result(Generic[Succ, Err]):
- @staticmethod
- def ok(succ: T) -> "Result[T, Any]":
- return Result(_Status.OK, succ=succ)
-
- @staticmethod
- def err(err: T) -> "Result[Any, T]":
- return Result(_Status.ERROR, err=err)
-
- def __init__(
- self,
- status: _Status,
- succ: Union[Succ, _ResultSentinel] = _RESULT_SENTINEL,
- err: Union[Err, _ResultSentinel] = _RESULT_SENTINEL,
- ) -> None:
- super().__init__()
- self._status: _Status = status
- self._succ: Union[_ResultSentinel, Succ] = succ
- self._err: Union[_ResultSentinel, Err] = err
-
- def unwrap(self) -> Succ:
- assert self._status is _Status.OK
- assert not isinstance(self._succ, _ResultSentinel)
- return self._succ
-
- def unwrap_err(self) -> Err:
- assert self._status is _Status.ERROR
- assert not isinstance(self._err, _ResultSentinel)
- return self._err
-
- def is_ok(self) -> bool:
- return self._status is _Status.OK
-
- def is_err(self) -> bool:
- return self._status is _Status.ERROR
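-
-# Example usage (illustrative only; not part of the original module):
-#   def check(x: int) -> Result[int, str]:
-#       return Result.ok(x) if x >= 0 else Result.err("negative number")
-#
-#   res = check(5)
-#   if res.is_ok():
-#       print(res.unwrap())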
+++ /dev/null
-# Modeling utils
-
-These utilities are used to model schemas for data stored in a Python dictionary or in the YAML and JSON formats.
-The utilities also take care of parsing and validation, and of creating JSON schemas and basic documentation.
-
-## Creating schema
-
-A schema is created using the `ConfigSchema` class. The schema structure is specified using annotations.
-
-```python
-from .modeling import ConfigSchema
-
-class SimpleSchema(ConfigSchema):
- integer: int = 5 # a default value can be specified
- string: str
- boolean: bool
-```
-Even more complex types can be used in a schema, and schemas can also be nested.
-Words in multi-word names are separated by an underscore `_` (e.g. `simple_schema`).
-
-```python
-from typing import Dict, List, Optional, Union
-
-class ComplexSchema(ConfigSchema):
- optional: Optional[str] # this field is optional
- union: Union[int, str] # integer and string are both valid
- list: List[int] # list of integers
- dictionary: Dict[str, bool] = {"key": False}
- simple_schema: SimpleSchema # nested schema
-```
-
-
-### Additional validation
-
-If some additional validation needs to be done, there is the `_validate()` method for that.
-A `ValueError` exception should be raised in case of a validation error.
-
-```python
-class FieldsSchema(ConfigSchema):
- field1: int
- field2: int
-
- def _validate(self) -> None:
- if self.field1 > self.field2:
- raise ValueError("field1 is bigger than field2")
-```
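-
-As a rough illustration (a sketch; the exact error reporting for invalid data is handled by the
-modeling utilities around the `ValueError` raised in `_validate()`):
-
-```python
-FieldsSchema({"field1": 1, "field2": 2})  # passes validation
-FieldsSchema({"field1": 2, "field2": 1})  # rejected: field1 is bigger than field2
-```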
-
-
-### Additional layers, transformation methods
-
-It is possible to add layers to a schema and use a transformation method between the layers to process the value.
-The transformation method must be named after the field (`value` in this example) with an underscore (`_`) prefix.
-In this example, `Layer2Schema` is the structure for the input data and `Layer1Schema` is the structure for the resulting data.
-
-```python
-class Layer1Schema(ConfigSchema):
- class Layer2Schema(ConfigSchema):
- value: Union[str, int]
-
- _LAYER = Layer2Schema
-
- value: int
-
- def _value(self, obj: Layer2Schema) -> Any:
- if isinstance(obj.value, str):
- return len(obj.value)  # transform str values to int; this is just an example
- return obj.value
-```
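-
-A minimal usage sketch (illustrative; it assumes the input matches `Layer2Schema` and is transformed
-during instantiation):
-
-```python
-layered = Layer1Schema({"value": "abc"})
-layered.value  # -> 3, the string was transformed by _value()
-```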
-
-### Documentation and JSON schema
-
-A created schema can be documented using a simple docstring. The JSON schema is created by calling the `json_schema()` method on the schema class. The JSON schema includes the descriptions from the docstring, defaults, etc.
-
-```python
-class SimpleSchema(ConfigSchema):
- """
- This is description for SimpleSchema itself.
-
- ---
- integer: description for integer field
- string: description for string field
- boolean: description for boolean field
- """
-
- integer: int = 5
- string: str
- boolean: bool
-
-json_schema = SimpleSchema.json_schema()
-```
-
-
-## Creating custom type
-
-Custom types can be made by extending the `BaseValueType` class, which is integrated into the parsing and validation process.
-Use `DataValidationError` to raise an exception during validation. `object_path` is used to track the node in more complex/nested schemas and to create a useful logging message.
-
-```python
-from .modeling import BaseValueType
-from .modeling.exceptions import DataValidationError
-
-class IntNonNegative(BaseValueType):
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- super().__init__(source_value)
- if isinstance(source_value, int) and not isinstance(source_value, bool):
- if source_value < 0:
- raise DataValidationError(f"value {source_value} is negative number.", object_path)
- self._value = source_value
- else:
- raise DataValidationError(
- f"expected integer, got '{type(source_value)}'",
- object_path,
- )
-```
-
-For the JSON schema, you should implement the `json_schema` method.
-It should return a [JSON schema representation](https://json-schema.org/understanding-json-schema/index.html) of the custom type.
-
-```python
- @classmethod
- def json_schema(cls: Type["IntNonNegative"]) -> Dict[Any, Any]:
- return {"type": "integer", "minimum": 0}
-```
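-
-The custom type can then be used as an ordinary field type (a sketch; the schema below is made up
-for illustration):
-
-```python
-class WorkersSchema(ConfigSchema):
- workers: IntNonNegative
-```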
-
-
-## Parsing JSON/YAML
-
-For example, YAML data for `ComplexSchema` can look like this.
-Words in multi-word names are separated by a hyphen `-` (e.g. `simple-schema`).
-
-```yaml
-# data.yaml
-union: here could also be a number
-list: [1,2,3,]
-dictionary:
- key": false
-simple-schema:
- integer: 55
- string: this is string
- boolean: false
-```
-
-To parse data in the YAML format, just use the `parse_yaml` function, or `parse_json` for the JSON format.
-Parsed data is stored in a dict-like object that takes care of the `-`/`_` conversion.
-
-```python
-from .modeling import parse_yaml
-
-# read data from file
-with open("data.yaml") as f:
- str_data = f.read()
-
-dict_data = parse_yaml(str_data)
-validated_data = ComplexSchema(dict_data)
-```
\ No newline at end of file
+++ /dev/null
-from .base_generic_type_wrapper import BaseGenericTypeWrapper
-from .base_schema import BaseSchema, ConfigSchema
-from .base_value_type import BaseValueType
-from .parsing import parse_json, parse_yaml, try_to_parse
-
-__all__ = [
- "BaseGenericTypeWrapper",
- "BaseValueType",
- "BaseSchema",
- "ConfigSchema",
- "parse_yaml",
- "parse_json",
- "try_to_parse",
-]
+++ /dev/null
-from typing import Generic, TypeVar
-
-from .base_value_type import BaseTypeABC
-
-T = TypeVar("T")
-
-
-class BaseGenericTypeWrapper(Generic[T], BaseTypeABC): # pylint: disable=abstract-method
- pass
+++ /dev/null
-import enum
-import inspect
-import sys
-from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
-from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union, cast
-
-import yaml
-
-from knot_resolver.utils.functional import all_matches
-
-from .base_generic_type_wrapper import BaseGenericTypeWrapper
-from .base_value_type import BaseValueType
-from .exceptions import AggregateDataValidationError, DataDescriptionError, DataValidationError
-from .renaming import Renamed, renamed
-from .types import (
- get_annotations,
- get_generic_type_argument,
- get_generic_type_arguments,
- get_generic_type_wrapper_argument,
- get_optional_inner_type,
- is_dict,
- is_enum,
- is_generic_type_wrapper,
- is_internal_field_name,
- is_list,
- is_literal,
- is_none_type,
- is_optional,
- is_tuple,
- is_union,
-)
-
-T = TypeVar("T")
-
-
-def is_obj_type(obj: Any, types: Union[type, Tuple[Any, ...], Tuple[type, ...]]) -> bool:
- # To check specific type we are using 'type()' instead of 'isinstance()'
- # because for example 'bool' is instance of 'int', 'isinstance(False, int)' returns True.
- # pylint: disable=unidiomatic-typecheck
- if isinstance(types, tuple):
- return type(obj) in types
- return type(obj) is types
-
-
-class Serializable(ABC):
- """An interface for making classes serializable to a dictionary (and in turn into a JSON)."""
-
- @abstractmethod
- def to_dict(self) -> Dict[Any, Any]:
- raise NotImplementedError(f"...for class {self.__class__.__name__}")
-
- @staticmethod
- def is_serializable(typ: Type[Any]) -> bool:
- return (
- typ in {str, bool, int, float}
- or is_none_type(typ)
- or is_literal(typ)
- or is_dict(typ)
- or is_list(typ)
- or is_generic_type_wrapper(typ)
- or (inspect.isclass(typ) and issubclass(typ, Serializable))
- or (inspect.isclass(typ) and issubclass(typ, BaseValueType))
- or (inspect.isclass(typ) and issubclass(typ, BaseSchema))
- or (is_optional(typ) and Serializable.is_serializable(get_optional_inner_type(typ)))
- or (is_union(typ) and all_matches(Serializable.is_serializable, get_generic_type_arguments(typ)))
- )
-
- @staticmethod
- def serialize(obj: Any) -> Any:
- if isinstance(obj, Serializable):
- return obj.to_dict()
-
- if isinstance(obj, (BaseValueType, BaseGenericTypeWrapper)):
- o = obj.serialize()
- # if Serializable.is_serializable(o):
- return Serializable.serialize(o)
- # return o
-
- if isinstance(obj, list):
- res: List[Any] = [Serializable.serialize(i) for i in cast(List[Any], obj)]
- return res
-
- return obj
-
-
-class _LazyDefault(Generic[T], Serializable):
- """
- Wrapper for default values in BaseSchema classes.
-
- Defers the default instantiation until the schema
- itself is being instantiated.
- """
-
- def __init__(self, constructor: Callable[..., T], *args: Any, **kwargs: Any) -> None:
- # pylint: disable=[super-init-not-called]
- self._func = constructor
- self._args = args
- self._kwargs = kwargs
-
- def instantiate(self) -> T:
- return self._func(*self._args, **self._kwargs)
-
- def to_dict(self) -> Dict[Any, Any]:
- return Serializable.serialize(self.instantiate())
-
-
-def lazy_default(constructor: Callable[..., T], *args: Any, **kwargs: Any) -> T:
- """We use a factory function because you can't lie about the return type in `__new__`."""
- return _LazyDefault(constructor, *args, **kwargs) # type: ignore[return-value]
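-
-# Illustrative use of lazy_default (a sketch; the schema and field names are made up):
-#   class ExampleSchema(BaseSchema):
-#       tags: List[str] = lazy_default(list)
-#
-# The default is then constructed each time the schema is instantiated, so e.g. a mutable
-# default is not shared between instances.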
-
-
-def _split_docstring(docstring: str) -> Tuple[str, Optional[str]]:
- """Split docstring into description of the class and description of attributes."""
- if "---" not in docstring:
- return ("\n".join([s.strip() for s in docstring.splitlines()]).strip(), None)
-
- doc, attrs_doc = docstring.split("---", maxsplit=1)
- return (
- "\n".join([s.strip() for s in doc.splitlines()]).strip(),
- attrs_doc,
- )
-
-
-def _parse_attrs_docstrings(docstring: str) -> Optional[Dict[str, str]]:
- """Given a docstring of a BaseSchema, return a dict with descriptions of individual attributes."""
- _, attrs_doc = _split_docstring(docstring)
- if attrs_doc is None:
- return None
-
- # try to parse it as yaml:
- data = yaml.safe_load(attrs_doc)
- assert isinstance(data, dict), "Invalid format of attribute description"
- return cast(Dict[str, str], data)
-
-
-def _get_properties_schema(typ: Type[Any]) -> Dict[Any, Any]:
- schema: Dict[Any, Any] = {}
- annot = get_annotations(typ)
-
- docstring: str = typ.__dict__.get("__doc__", "") or ""
- attribute_documentation = _parse_attrs_docstrings(docstring)
- for field_name, python_type in annot.items():
- name = field_name.replace("_", "-")
- schema[name] = _describe_type(python_type)
-
- # description
- if attribute_documentation is not None:
- if field_name not in attribute_documentation:
- raise DataDescriptionError(f"The docstring does not describe field '{field_name}'", str(typ))
- schema[name]["description"] = attribute_documentation[field_name]
- del attribute_documentation[field_name]
-
- # default value
- if hasattr(typ, field_name):
- assert Serializable.is_serializable(
- python_type
- ), f"Type '{python_type}' does not appear to be JSON serializable"
- schema[name]["default"] = Serializable.serialize(getattr(typ, field_name))
-
- if attribute_documentation is not None and len(attribute_documentation) > 0:
- raise DataDescriptionError(
- f"The docstring describes attributes which are not present - {tuple(attribute_documentation.keys())}",
- str(typ),
- )
-
- return schema
-
-
-def _describe_type(typ: Type[Any]) -> Dict[Any, Any]: # noqa: C901, PLR0911, PLR0912
- # pylint: disable=too-many-branches
-
- if inspect.isclass(typ) and issubclass(typ, BaseSchema):
- return typ.json_schema(include_schema_definition=False)
-
- if inspect.isclass(typ) and issubclass(typ, BaseValueType):
- return typ.json_schema()
-
- if is_generic_type_wrapper(typ):
- wrapped = get_generic_type_wrapper_argument(typ)
- return _describe_type(wrapped)
-
- if is_none_type(typ):
- return {"type": "null"}
-
- if typ is int:
- return {"type": "integer"}
-
- if typ is bool:
- return {"type": "boolean"}
-
- if typ is str:
- return {"type": "string"}
-
- if is_literal(typ):
- lit: List[Any] = []
- args = get_generic_type_arguments(typ)
- for arg in args:
- if is_literal(arg):
- lit += get_generic_type_arguments(arg)
- else:
- lit.append(arg)
- return {"type": "string", "enum": lit}
-
- if is_optional(typ):
- desc = _describe_type(get_optional_inner_type(typ))
- if "type" in desc:
- desc["type"] = [desc["type"], "null"]
- return desc
- return {"anyOf": [{"type": "null"}, desc]}
-
- if is_union(typ):
- variants = get_generic_type_arguments(typ)
- return {"anyOf": [_describe_type(v) for v in variants]}
-
- if is_list(typ):
- return {"type": "array", "items": _describe_type(get_generic_type_argument(typ))}
-
- if is_dict(typ):
- key, val = get_generic_type_arguments(typ)
-
- if inspect.isclass(key) and issubclass(key, BaseValueType):
- assert (
- key.__str__ is not BaseValueType.__str__
- ), "To support derived 'BaseValueType', __str__ must be implemented."
- else:
- assert key is str, "We currently do not support any other keys then strings"
-
- return {"type": "object", "additionalProperties": _describe_type(val)}
-
- if inspect.isclass(typ) and issubclass(typ, enum.Enum): # same as our is_enum(typ), but inlined for type checker
- return {"type": "string", "enum": [str(v) for v in typ]}
-
- raise NotImplementedError(f"Trying to get JSON schema for type '{typ}', which is not implemented")
-
-
-TSource = Union[None, "BaseSchema", Dict[str, Any]]
-
-
-def _create_untouchable(name: str) -> object:
- class _Untouchable:
- def __getattribute__(self, item_name: str) -> Any:
- raise RuntimeError(f"You are not supposed to access object '{name}'.")
-
- def __setattr__(self, item_name: str, value: Any) -> None:
- raise RuntimeError(f"You are not supposed to access object '{name}'.")
-
- return _Untouchable()
-
-
-class ObjectMapper:
- def _create_tuple(self, tp: Type[Any], obj: Tuple[Any, ...], object_path: str) -> Tuple[Any, ...]:
- types = get_generic_type_arguments(tp)
- errs: List[DataValidationError] = []
- res: List[Any] = []
- for i, (t, val) in enumerate(zip(types, obj)):
- try:
- res.append(self.map_object(t, val, object_path=f"{object_path}[{i}]"))
- except DataValidationError as e:
- errs.append(e)
- if len(errs) == 1:
- raise errs[0]
- if len(errs) > 1:
- raise AggregateDataValidationError(object_path, child_exceptions=errs)
- return tuple(res)
-
- def _create_dict(self, tp: Type[Any], obj: Dict[Any, Any], object_path: str) -> Dict[Any, Any]:
- key_type, val_type = get_generic_type_arguments(tp)
- try:
- errs: List[DataValidationError] = []
- res: Dict[Any, Any] = {}
- for key, val in obj.items():
- try:
- nkey = self.map_object(key_type, key, object_path=f"{object_path}[{key}]")
- nval = self.map_object(val_type, val, object_path=f"{object_path}[{key}]")
- res[nkey] = nval
- except DataValidationError as e:
- errs.append(e)
- if len(errs) == 1:
- raise errs[0]
- if len(errs) > 1:
- raise AggregateDataValidationError(object_path, child_exceptions=errs)
- except AttributeError as e:
- raise DataValidationError(
- f"Expected dict-like object, but failed to access its .items() method. Value was {obj}", object_path
- ) from e
- else:
- return res
-
- def _create_list(self, tp: Type[Any], obj: List[Any], object_path: str) -> List[Any]:
- if isinstance(obj, str):
- raise DataValidationError("expected list, got string", object_path)
-
- inner_type = get_generic_type_argument(tp)
- errs: List[DataValidationError] = []
- res: List[Any] = []
-
- try:
- for i, val in enumerate(obj):
- res.append(self.map_object(inner_type, val, object_path=f"{object_path}[{i}]"))
- if len(res) == 0:
- raise DataValidationError("empty list is not allowed", object_path)
- except DataValidationError as e:
- errs.append(e)
- except TypeError as e:
- errs.append(DataValidationError(str(e), object_path))
-
- if len(errs) == 1:
- raise errs[0]
- if len(errs) > 1:
- raise AggregateDataValidationError(object_path, child_exceptions=errs)
- return res
-
- def _create_str(self, obj: Any, object_path: str) -> str:
- # we are willing to cast any primitive value to string, but no compound values are allowed
- if is_obj_type(obj, (str, float, int)) or isinstance(obj, BaseValueType):
- return str(obj)
- if is_obj_type(obj, bool):
- raise DataValidationError(
- "Expected str, found bool. Be careful, that YAML parsers consider even"
- ' "no" and "yes" as a bool. Search for the Norway Problem for more'
- " details. And please use quotes explicitly.",
- object_path,
- )
- raise DataValidationError(
- f"expected str (or number that would be cast to string), but found type {type(obj)}", object_path
- )
-
- def _create_int(self, obj: Any, object_path: str) -> int:
- # we don't want to make an int out of anything other than an int
- # except for BaseValueType class instances
- if is_obj_type(obj, int) or isinstance(obj, BaseValueType):
- return int(obj)
- raise DataValidationError(f"expected int, found {type(obj)}", object_path)
-
- def _create_union(self, tp: Type[T], obj: Any, object_path: str) -> T:
- variants = get_generic_type_arguments(tp)
- errs: List[DataValidationError] = []
- for v in variants:
- try:
- return self.map_object(v, obj, object_path=object_path)
- except DataValidationError as e:
- errs.append(e)
-
- raise DataValidationError("could not parse any of the possible variants", object_path, child_exceptions=errs)
-
- def _create_optional(self, tp: Type[Optional[T]], obj: Any, object_path: str) -> Optional[T]:
- inner: Type[Any] = get_optional_inner_type(tp)
- if obj is None:
- return None
- return self.map_object(inner, obj, object_path=object_path)
-
- def _create_bool(self, obj: Any, object_path: str) -> bool:
- if is_obj_type(obj, bool):
- return obj
- raise DataValidationError(f"expected bool, found {type(obj)}", object_path)
-
- def _create_literal(self, tp: Type[Any], obj: Any, object_path: str) -> Any:
- args = get_generic_type_arguments(tp)
-
- expected = []
- if sys.version_info < (3, 9):
- for arg in args:
- if is_literal(arg):
- expected += get_generic_type_arguments(arg)
- else:
- expected.append(arg)
- else:
- expected = args
-
- if obj in expected:
- return obj
- raise DataValidationError(f"'{obj}' does not match any of the expected values {expected}", object_path)
-
- def _create_base_schema_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseSchema":
- if isinstance(obj, (dict, BaseSchema)):
- return tp(obj, object_path=object_path)
- raise DataValidationError(f"expected 'dict' or 'NoRenameBaseSchema' object, found '{type(obj)}'", object_path)
-
- def create_value_type_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseValueType":
- if isinstance(obj, tp):
- # if we already have a custom value type, just pass it through
- return obj
- # no validation performed here, the implementation does it in the constructor
- try:
- return tp(obj, object_path=object_path)
- except ValueError as e:
- if len(e.args) > 0 and isinstance(e.args[0], str):
- msg = e.args[0]
- else:
- msg = f"Failed to validate value against {tp} type"
- raise DataValidationError(msg, object_path) from e
-
- def _create_default(self, obj: Any) -> Any:
- if isinstance(obj, _LazyDefault):
- return obj.instantiate()
- return obj
-
- def map_object( # noqa: C901, PLR0911, PLR0912
- self,
- tp: Type[Any],
- obj: Any,
- default: Any = ...,
- use_default: bool = False,
- object_path: str = "/",
- ) -> Any:
- """
- Map a value object `obj` onto the expected type `tp`.
-
- Return a new object of the given type with the fields of `obj` mapped into it.
- During the mapping procedure, runtime type checking is performed.
- """
- # Disabling these checks, because I think it's much more readable as a single function
- # and it's not that large at this point. If it got larger, then we should definitely split it
- # pylint: disable=too-many-branches,too-many-locals,too-many-statements
-
- # default values
- if obj is None and use_default:
- return self._create_default(default)
-
- # NoneType
- if is_none_type(tp):
- if obj is None:
- return None
- raise DataValidationError(f"expected None, found '{obj}'.", object_path)
-
- # Optional[T] (could be technically handled by Union[*variants], but this way we have better error reporting)
- if is_optional(tp):
- return self._create_optional(tp, obj, object_path)
-
- # Union[*variants]
- if is_union(tp):
- return self._create_union(tp, obj, object_path)
-
- # after this, there is no place for a None object
- if obj is None:
- raise DataValidationError(f"unexpected value 'None' for type {tp}", object_path)
-
- # int
- if tp is int:
- return self._create_int(obj, object_path)
-
- # str
- if tp is str:
- return self._create_str(obj, object_path)
-
- # bool
- if tp is bool:
- return self._create_bool(obj, object_path)
-
- # float
- if tp is float:
- raise NotImplementedError(
- "Floating point values are not supported in the object mapper."
- " Please implement them and be careful with type coercions"
- )
-
- # Literal[T]
- if is_literal(tp):
- return self._create_literal(tp, obj, object_path)
-
- # Dict[K,V]
- if is_dict(tp):
- return self._create_dict(tp, obj, object_path)
-
- # any Enums (probably used only internally in DataValidator)
- if is_enum(tp):
- if isinstance(obj, tp):
- return obj
- raise DataValidationError(f"unexpected value '{obj}' for enum '{tp}'", object_path)
-
- # List[T]
- if is_list(tp):
- return self._create_list(tp, obj, object_path)
-
- # Tuple[A,B,C,D,...]
- if is_tuple(tp):
- return self._create_tuple(tp, obj, object_path)
-
- # type of obj and cls type match
- if is_obj_type(obj, tp):
- return obj
-
- # when the specified type is Any, just return the given value
- # on mypy version 1.11.0 comparison-overlap error started popping up
- # https://github.com/python/mypy/issues/17665
- if tp == Any: # type: ignore[comparison-overlap]
- return obj
-
- # BaseValueType subclasses
- if inspect.isclass(tp) and issubclass(tp, BaseValueType):
- return self.create_value_type_object(tp, obj, object_path)
-
- # BaseGenericTypeWrapper subclasses
- if is_generic_type_wrapper(tp):
- inner_type = get_generic_type_wrapper_argument(tp)
- obj_valid = self.map_object(inner_type, obj, object_path)
- return tp(obj_valid, object_path=object_path)
-
- # nested BaseSchema subclasses
- if inspect.isclass(tp) and issubclass(tp, BaseSchema):
- return self._create_base_schema_object(tp, obj, object_path)
-
- # if the object matches, just pass it through
- if inspect.isclass(tp) and isinstance(obj, tp):
- return obj
-
- # default error handler
- raise DataValidationError(
- f"Type {tp} cannot be parsed. This is a implementation error. "
- "Please fix your types in the class or improve the parser/validator.",
- object_path,
- )
-
- def is_obj_type_valid(self, obj: Any, tp: Type[Any]) -> bool:
- """Runtime type checking. Validate, that a given object is of a given type."""
- try:
- self.map_object(tp, obj)
- except (DataValidationError, ValueError):
- return False
- else:
- return True
-
- def _assign_default(self, obj: Any, name: str, python_type: Any, object_path: str) -> None:
- cls = obj.__class__
-
- try:
- default = self._create_default(getattr(cls, name, None))
- except ValueError as e:
- raise DataValidationError(str(e), f"{object_path}/{name}") from e
-
- value = self.map_object(python_type, default, object_path=f"{object_path}/{name}")
- setattr(obj, name, value)
-
- def _assign_field(self, obj: Any, name: str, python_type: Any, value: Any, object_path: str) -> None:
- value = self.map_object(python_type, value, object_path=f"{object_path}/{name}")
- setattr(obj, name, value)
-
- def _assign_fields(self, obj: Any, source: Union[Dict[str, Any], "BaseSchema", None], object_path: str) -> Set[str]: # noqa: C901
- """
- Assign fields and values.
-
- Order of assignment:
- 1. all direct assignments
- 2. assignments with conversion method
- """
- cls = obj.__class__
- annot = get_annotations(cls)
- errs: List[DataValidationError] = []
-
- used_keys: Set[str] = set()
- for name, python_type in annot.items():
- try:
- if is_internal_field_name(name):
- continue
-
- # populate field
- if source is None:
- self._assign_default(obj, name, python_type, object_path)
-
- # check for invalid configuration with both transformation function and default value
- elif hasattr(obj, f"_{name}") and hasattr(obj, name):
- raise RuntimeError(
- f"Field '{obj.__class__.__name__}.{name}' has default value and transformation function at"
- " the same time. That is now allowed. Store the default in the transformation function."
- )
-
- # there is a transformation function to create the value
- elif hasattr(obj, f"_{name}") and callable(getattr(obj, f"_{name}")):
- val = self._get_converted_value(obj, name, source, object_path)
- self._assign_field(obj, name, python_type, val, object_path)
- used_keys.add(name)
-
- # source just contains the value
- elif name in source:
- val = source[name]
- self._assign_field(obj, name, python_type, val, object_path)
- used_keys.add(name)
-
- # there is a default value, or the type is optional => store the default or null
- elif hasattr(obj, name) or is_optional(python_type):
- self._assign_default(obj, name, python_type, object_path)
-
- # we expected a value but it was not there
- else:
- errs.append(DataValidationError(f"missing attribute '{name}'.", object_path))
- except DataValidationError as e:
- errs.append(e)
-
- if len(errs) == 1:
- raise errs[0]
- if len(errs) > 1:
- raise AggregateDataValidationError(object_path, errs)
- return used_keys
-
- def _get_converted_value(self, obj: Any, key: str, source: TSource, object_path: str) -> Any:
- """Get a value of a field by invoking appropriate transformation function."""
- try:
- func = getattr(obj.__class__, f"_{key}")
- argc = len(inspect.signature(func).parameters)
- if argc == 1:
- # it is a static method
- return func(source)
- if argc == 2:
- # it is an instance method
- return func(_create_untouchable("obj"), source)
- raise RuntimeError("Transformation function has wrong number of arguments")
- except ValueError as e:
- msg = e.args[0] if len(e.args) > 0 and isinstance(e.args[0], str) else "Failed to validate value type"
- raise DataValidationError(msg, object_path) from e
-
- def object_constructor(self, obj: Any, source: Union["BaseSchema", Dict[Any, Any]], object_path: str) -> None:
- """
- Construct the object. Delegated constructor for the BaseSchema class.
-
- This method is delegated to the mapper because of renaming: this way, we do not need
- a different BaseSchema class when we want dynamically renamed fields.
- """
- # As this is a delegated constructor, we must ignore protected access warnings
-
- # sanity check
- if not isinstance(source, (BaseSchema, dict)):
- raise DataValidationError(f"expected dict-like object, found '{type(source)}'", object_path)
-
- # construct lower level schema first if configured to do so
- if obj._LAYER is not None: # noqa: SLF001
- source = obj._LAYER(source, object_path=object_path) # pylint: disable=not-callable # noqa: SLF001
-
- # assign fields
- used_keys = self._assign_fields(obj, source, object_path)
-
- # check for unused keys in the source object
- if source and not isinstance(source, BaseSchema):
- unused = source.keys() - used_keys
- if len(unused) > 0:
- keys = ", ".join(f"'{u}'" for u in unused)
- raise DataValidationError(
- f"unexpected extra key(s) {keys}",
- object_path,
- )
-
- # validate the constructed value
- try:
- obj._validate() # noqa: SLF001
- except ValueError as e:
- raise DataValidationError(e.args[0] if len(e.args) > 0 else "Validation error", object_path or "/") from e
-
-
-class BaseSchema(Serializable):
- """
- Base class for modeling configuration schema.
-
- It somewhat resembles standard dataclasses with additional functionality:
-
- * type validation
- * data conversion
-
- To create an instance of this class, you have to provide source data in the form of a dict-like object,
- generally a raw dict or another `BaseSchema` instance. The provided data object is traversed, transformed
- and validated before being assigned to the appropriate fields (attributes).
-
- Fields (attributes)
- ===================
-
- The fields (or attributes) of the class are defined the same way as in a dataclass, by creating
- class-level type-annotated fields. For example:
-
- class A(BaseSchema):
- awesome_number: int
-
- If your `BaseSchema` instance has a field whose type is itself a BaseSchema, its value is recursively
- created from the nested input data. This way, you can specify a complex tree of BaseSchemas and use
- the root BaseSchema to create an instance of everything.
-
- Transformation
- ==============
-
- You can provide the BaseSchema class with a field and a function of the same name, but starting with
- an underscore ('_'). For example, you could have a field called `awesome_number` and a function called
- `_awesome_number(self, source)`. The function takes one argument: the source data (optionally with self,
- but you are not supposed to touch that). It can read any data from the source object and return a value of
- an appropriate type, which will be assigned to the field `awesome_number`. If you want to report an error
- during validation, raise a `ValueError` exception.
-
- Using this, you can convert any input values into any type and field you want. To make the conversion
- easier to write, you can also specify a special class variable called `_LAYER` pointing to another
- BaseSchema class. This causes the source object to be first parsed as the specified additional layer of
- BaseSchema and only after that used as a source for this class. This allows nesting of transformation functions.
-
- Validation
- ==========
-
- All assignments to fields during object construction are checked at runtime for proper types. This means
- you are free to use an untrusted source object and turn it into a data structure where you are sure
- what is what.
-
- You can also define a `_validate` method, which will be called once the whole data structure is built.
- You can validate the data there and raise a `ValueError` if it is invalid.
-
- Default values
- ==============
-
- If you create a field with a value, it will be used as the default whenever the data is not present in
- the source object. As a special case, the default value for an Optional type is None unless specified
- otherwise. A field is not allowed to have both a default value and a transformation function.
-
- Example
- =======
-
- See tests/utils/test_modelling.py for example usage.
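-
- A minimal illustrative sketch (the names below are made up for this docstring, they are not
- part of the real datamodel):
-
- class ServerSchema(BaseSchema):
- # plain typed field with a default value
- port: int = 53
-
- # field computed by a transformation function from the raw source data
- address: str
- def _address(self, source):
- # raise ValueError here to report a validation problem
- return str(source["address"])
-
- def _validate(self) -> None:
- if not 0 < self.port < 65536:
- raise ValueError("port number out of range")
-
- server = ServerSchema({"address": "192.0.2.1"})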
-
- """
-
- _LAYER: Optional[Type["BaseSchema"]] = None
- _MAPPER: ObjectMapper = ObjectMapper()
-
- def __init__(self, source: TSource = None, object_path: str = "") -> None: # pylint: disable=[super-init-not-called]
- # save source data (and drop information about nullness)
- source = source or {}
- self.__source: Union[Dict[str, Any], BaseSchema] = source
-
- # delegate the rest of the constructor
- self._MAPPER.object_constructor(self, source, object_path)
-
- def get_unparsed_data(self) -> Dict[str, Any]:
- if isinstance(self.__source, BaseSchema):
- return self.__source.get_unparsed_data()
- if isinstance(self.__source, Renamed):
- return self.__source.original()
- return self.__source
-
- def __getitem__(self, key: str) -> Any:
- if not hasattr(self, key):
- raise RuntimeError(f"Object '{self}' of type '{type(self)}' does not have field named '{key}'")
- return getattr(self, key)
-
- def __contains__(self, item: Any) -> bool:
- return hasattr(self, item)
-
- def _validate(self) -> None:
- """
- Additional validation procedure called after all fields are assigned.
-
- Should throw a ValueError in case of failure.
- """
-
- def __eq__(self, o: object) -> bool:
- cls = self.__class__
- if not isinstance(o, cls):
- return False
-
- annot = get_annotations(cls)
- return all(getattr(self, name) == getattr(o, name) for name in annot)
-
- @classmethod
- def json_schema(
- cls: Type["BaseSchema"],
- schema_id: Optional[str] = None,
- title: Optional[str] = None,
- description: Optional[str] = None,
- include_schema_definition: bool = True,
- ) -> Dict[Any, Any]:
- if cls._LAYER is not None:
- return cls._LAYER.json_schema(
- schema_id=schema_id,
- title=title,
- description=description,
- include_schema_definition=include_schema_definition,
- )
-
- schema: Dict[Any, Any] = {}
- if include_schema_definition:
- schema["$schema"] = "https://json-schema.org/draft/2020-12/schema"
- if schema_id is not None:
- schema["$id"] = schema_id
- if title is not None:
- schema["title"] = title
- if description is not None:
- schema["description"] = description
- elif cls.__doc__ is not None:
- schema["description"] = _split_docstring(cls.__doc__)[0]
- schema["type"] = "object"
- schema["properties"] = _get_properties_schema(cls)
-
- return schema
-
- def to_dict(self) -> Dict[Any, Any]:
- res: Dict[Any, Any] = {}
- cls = self.__class__
- annot = get_annotations(cls)
-
- for name in annot:
- res[name] = Serializable.serialize(getattr(self, name))
- return res
-
-
-class RenamingObjectMapper(ObjectMapper):
- """
- Same as ObjectMapper, but it uses collection wrappers from the `renaming` module to perform dynamic field renaming.
-
- More specifically:
- - it renames all properties in (nested) objects
- - it does not rename keys in dictionaries
- """
-
- def _create_dict(self, tp: Type[Any], obj: Dict[Any, Any], object_path: str) -> Dict[Any, Any]:
- if isinstance(obj, Renamed):
- obj = obj.original()
- return super()._create_dict(tp, obj, object_path)
-
- def _create_base_schema_object(self, tp: Type[Any], obj: Any, object_path: str) -> "BaseSchema":
- if isinstance(obj, dict):
- obj = renamed(obj)
- return super()._create_base_schema_object(tp, obj, object_path)
-
- def object_constructor(self, obj: Any, source: Union["BaseSchema", Dict[Any, Any]], object_path: str) -> None:
- if isinstance(source, dict):
- source = renamed(source)
- return super().object_constructor(obj, source, object_path)
-
-
-# export as standalone functions for simplicity and compatibility
-is_obj_type_valid = ObjectMapper().is_obj_type_valid
-map_object = ObjectMapper().map_object
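-
-# Illustrative sketch of the standalone helpers (not part of the original module):
-#
-# assert map_object(List[int], [1, 2, 3]) == [1, 2, 3]
-# assert is_obj_type_valid("not a number", int) is False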
-
-
-class ConfigSchema(BaseSchema):
- """Same as BaseSchema, but maps with RenamingObjectMapper."""
-
- _MAPPER: ObjectMapper = RenamingObjectMapper()
+++ /dev/null
-from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
-from typing import Any, Dict, Type
-
-
-class BaseTypeABC(ABC):
- @abstractmethod
- def __init__(self, source_value: Any, object_path: str = "/") -> None:
- pass
-
- @abstractmethod
- def __int__(self) -> int:
- raise NotImplementedError(f" return 'int()' value for {type(self).__name__} is not implemented.")
-
- @abstractmethod
- def __str__(self) -> str:
- raise NotImplementedError(f"return 'str()' value for {type(self).__name__} is not implemented.")
-
- @abstractmethod
- def serialize(self) -> Any:
- """
- Dump configuration to JSON-serializable object.
-
- Returns a JSON-serializable object from which the object
- can be recreated again using the constructor.
-
- It's not necessary to return the same structure that was given as an input. It only has
- to be the same semantically.
- """
- raise NotImplementedError(f"{type(self).__name__}'s' 'serialize()' not implemented.")
-
-
-class BaseValueType(BaseTypeABC):
- """
- Subclasses of this class can be used as type annotations in 'DataParser'.
-
- When a value is being parsed from a serialized format (e.g. JSON/YAML),
- an object will be created by calling the constructor of the appropriate
- type on the field value. The only limitation is that the value MUST NOT be `None`.
-
- There is no validation done on the wrapped value. The only condition is that
- it can't be `None`. If you want to perform any validation during creation,
- raise a `ValueError` in case of errors.
- """
-
- @classmethod
- @abstractmethod
- def json_schema(cls: Type["BaseValueType"]) -> Dict[Any, Any]:
- raise NotImplementedError()
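-
-
-# Illustrative sketch (hypothetical type, not part of the real datamodel): a BaseValueType
-# subclass wrapping a validated port number.
-#
-# class Port(BaseValueType):
-# def __init__(self, source_value: Any, object_path: str = "/") -> None:
-# if not isinstance(source_value, int) or not 0 < source_value < 65536:
-# raise ValueError(f"'{source_value}' is not a valid port number")
-# self._value = source_value
-#
-# def __int__(self) -> int:
-# return self._value
-#
-# def __str__(self) -> str:
-# return str(self._value)
-#
-# def serialize(self) -> Any:
-# return self._value
-#
-# @classmethod
-# def json_schema(cls) -> Dict[Any, Any]:
-# return {"type": "integer", "minimum": 1, "maximum": 65535}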
+++ /dev/null
-from typing import Iterable, Iterator
-
-from knot_resolver import KresBaseError
-
-
-class ModelingBaseError(KresBaseError):
- """Base class for all errors used in data modeling."""
-
-
-class DataDescriptionError(ModelingBaseError):
- """Class for errors that are raised when checking data description."""
-
-
-class DataParsingError(ModelingBaseError):
- """Class for errors that are raised when parsing data."""
-
-
-class DataValidationError(ModelingBaseError):
- """Class for errors that are raised when validating data."""
-
- def __init__(self, msg: str, tree_path: str, child_exceptions: Iterable["DataValidationError"] = ()) -> None:
- super().__init__(msg)
- self._tree_path = tree_path.replace("_", "-")
- self._child_exceptions = child_exceptions
-
- def where(self) -> str:
- return self._tree_path
-
- def msg(self) -> str:
- return f"[{self.where()}] {super().__str__()}"
-
- def recursive_msg(self, indentation_level: int = 0) -> str:
- def indented_lines(level: int) -> Iterator[str]:
- if level == 0:
- yield "Configuration validation error detected:"
- level += 1
-
- indent = "\t" * level
- yield f"{indent}{self.msg()}"
-
- for child in self._child_exceptions:
- yield from child.recursive_msg(level + 1).split("\n")
-
- return "\n".join(indented_lines(indentation_level))
-
- def __str__(self) -> str:
- return self.recursive_msg()
-
-
-class AggregateDataValidationError(DataValidationError):
- """Aggregation class for errors (DataValidationError) raised during data validation."""
-
- def __init__(self, object_path: str, child_exceptions: Iterable[DataValidationError]) -> None:
- super().__init__("error due to lower level exceptions", object_path, child_exceptions)
-
- def recursive_msg(self, indentation_level: int = 0) -> str:
- def indented_lines(level: int) -> Iterator[str]:
- inc = 0
- if level == 0:
- yield "Configuration validation errors detected:"
- inc = 1
-
- for child in self._child_exceptions:
- yield from child.recursive_msg(level + inc).split("\n")
-
- return "\n".join(indented_lines(indentation_level))
+++ /dev/null
-"""Implements JSON pointer resolution based on RFC 6901: https://www.rfc-editor.org/rfc/rfc6901."""
-
-from typing import Any, Optional, Tuple, Union
-
-JSONPtrAddressable = Any
-
-
-class _JSONPtr:
- @staticmethod
- def _decode_token(token: str) -> str:
- """Resolve escaped characters ~ and /."""
- # the order of the replace statements is important, do not change without
- # consulting the RFC
- return token.replace("~1", "/").replace("~0", "~")
-
- @staticmethod
- def _encode_token(token: str) -> str:
- return token.replace("~", "~0").replace("/", "~1")
-
- def __init__(self, ptr: str) -> None:
- if ptr == "":
- # pointer to the root
- self.tokens = []
-
- else:
- if ptr[0] != "/":
- raise SyntaxError(
- f"JSON pointer '{ptr}' invalid: the first character MUST be '/' or the pointer must be empty"
- )
-
- ptr = ptr[1:]
- self.tokens = [_JSONPtr._decode_token(tok) for tok in ptr.split("/")]
-
- def resolve(
- self, obj: JSONPtrAddressable
- ) -> Tuple[Optional[JSONPtrAddressable], JSONPtrAddressable, Union[str, int, None]]:
- parent: Optional[JSONPtrAddressable] = None
- current = obj
- current_ptr = ""
- token: Union[int, str, None] = None
-
- for token in self.tokens:
- if current is None:
- raise ValueError(
- f"JSON pointer cannot reference nested non-existent object: object at ptr '{current_ptr}'"
- f" already points to None, cannot nest deeper with token '{token}'"
- )
-
- if isinstance(current, (bool, int, float, str)):
- raise ValueError(f"object at '{current_ptr}' is a scalar, JSON pointer cannot point into it")
-
- parent = current
- if isinstance(current, list):
- if token == "-":
- current = None
- else:
- try:
- token_num = int(token)
- current = current[token_num]
- except ValueError as e:
- raise ValueError(
- f"invalid JSON pointer: list '{current_ptr}' require numbers as keys, instead got '{token}'"
- ) from e
-
- elif isinstance(current, dict):
- current = current.get(token, None)
-
- current_ptr += f"/{token}"
-
- return parent, current, token
-
-
-def json_ptr_resolve(
- obj: JSONPtrAddressable,
- ptr: str,
-) -> Tuple[Optional[JSONPtrAddressable], Optional[JSONPtrAddressable], Union[str, int, None]]:
- return _JSONPtr(ptr).resolve(obj)
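-
-
-# Illustrative sketch (not part of the original module): resolving a pointer into plain data.
-#
-# data = {"servers": [{"address": "192.0.2.1"}]}
-# parent, current, token = json_ptr_resolve(data, "/servers/0/address")
-# # parent == {"address": "192.0.2.1"}, current == "192.0.2.1", token == "address"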
+++ /dev/null
-import json
-from enum import Enum, auto
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import yaml
-from yaml.constructor import ConstructorError
-from yaml.nodes import MappingNode
-
-from .exceptions import DataParsingError, DataValidationError
-from .renaming import Renamed, renamed
-
-
-# custom hook for 'json.loads()' to detect duplicate keys in data
-# source: https://stackoverflow.com/q/14902299/12858520
-def _json_raise_duplicates(pairs: List[Tuple[Any, Any]]) -> Optional[Any]:
- dict_out: Dict[Any, Any] = {}
- for key, val in pairs:
- if key in dict_out:
- raise DataParsingError(f"Duplicate attribute key detected: {key}")
- dict_out[key] = val
- return dict_out
-
-
-# custom loader for 'yaml.load()' to detect duplicate keys in data
-# source: https://gist.github.com/pypt/94d747fe5180851196eb
-class _RaiseDuplicatesLoader(yaml.SafeLoader):
- def construct_mapping(self, node: Union[MappingNode, Any], deep: bool = False) -> Dict[Any, Any]:
- if not isinstance(node, MappingNode):
- raise ConstructorError(None, None, f"expected a mapping node, but found {node.id}", node.start_mark)
- mapping: Dict[Any, Any] = {}
- for key_node, value_node in node.value:
- key = self.construct_object(key_node, deep=deep)
- # we need to check that the key object is hashable (usable in a hash table)
- try:
- _ = hash(key)
- except TypeError as exc:
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- f"found unacceptable key ({exc})",
- key_node.start_mark,
- ) from exc
-
- # check for duplicate keys
- if key in mapping:
- raise DataParsingError(f"duplicate key detected: {key_node.start_mark}")
- value = self.construct_object(value_node, deep=deep)
- mapping[key] = value
- return mapping
-
-
-class DataFormat(Enum):
- YAML = auto()
- JSON = auto()
-
- def parse_to_dict(self, text: str) -> Any:
- if self is DataFormat.YAML:
- # RaiseDuplicatesLoader extends yaml.SafeLoader, so this should be safe
- # https://python.land/data-processing/python-yaml#PyYAML_safe_load_vs_load
- return renamed(yaml.load(text, Loader=_RaiseDuplicatesLoader)) # noqa: S506
- if self is DataFormat.JSON:
- return renamed(json.loads(text, object_pairs_hook=_json_raise_duplicates))
- raise NotImplementedError(f"Parsing of format '{self}' is not implemented")
-
- def dict_dump(self, data: Union[Dict[str, Any], Renamed], indent: Optional[int] = None) -> str:
- if isinstance(data, Renamed):
- data = data.original()
-
- if self is DataFormat.YAML:
- return yaml.safe_dump(data, indent=indent)
- if self is DataFormat.JSON:
- return json.dumps(data, indent=indent)
- raise NotImplementedError(f"Exporting to '{self}' format is not implemented")
-
-
-def parse_yaml(data: str) -> Any:
- return DataFormat.YAML.parse_to_dict(data)
-
-
-def parse_json(data: str) -> Any:
- return DataFormat.JSON.parse_to_dict(data)
-
-
-def try_to_parse(data: str) -> Any:
- """Attempt to parse the data as a JSON or YAML string."""
- try:
- return parse_json(data)
- except json.JSONDecodeError as je:
- try:
- return parse_yaml(data)
- except yaml.YAMLError as ye:
- # We do not raise-from here because there are two possible causes
- # and we may not know which one is the actual one.
- raise DataParsingError( # pylint: disable=raise-missing-from
- f"failed to parse data, JSON: {je}, YAML: {ye}"
- )
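-
-
-# Illustrative sketch: both of these inputs parse to the same dynamically-renamed dict.
-#
-# assert try_to_parse('{"exact-name": true}')["exact_name"] is True
-# assert try_to_parse("exact-name: true")["exact_name"] is True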
-
-
-def data_combine(data: Dict[Any, Any], additional_data: Dict[Any, Any], object_path: str = "") -> Dict[Any, Any]:
- """Combine dictionaries data."""
- for key in additional_data:
- if key in data:
- # if both are dictionaries we can try to combine them deeper
- if isinstance(data[key], dict) and isinstance(additional_data[key], dict):
- data[key] = data_combine(data[key], additional_data[key], f"{object_path}/{key}").copy()
- continue
- # otherwise we cannot combine them
- raise DataValidationError(f"duplicity key '{key}' with value in data", object_path)
- val = additional_data[key]
- data[key] = val.copy() if hasattr(val, "copy") else val
- return data
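-
-
-# Illustrative sketch: nested dictionaries are merged; a key that already holds a conflicting
-# non-dict value raises DataValidationError.
-#
-# base = {"logging": {"level": "info"}}
-# extra = {"logging": {"target": "syslog"}}
-# assert data_combine(base, extra) == {"logging": {"level": "info", "target": "syslog"}}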
+++ /dev/null
-import copy
-from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
-from typing import Any, List, Literal, Optional, Tuple, Union
-
-from knot_resolver.utils.modeling.base_schema import BaseSchema, map_object
-from knot_resolver.utils.modeling.json_pointer import json_ptr_resolve
-
-
-class PatchError(Exception):
- pass
-
-
-class Op(BaseSchema, ABC):
- @abstractmethod
- def eval(self, fakeroot: Any) -> Any:
- """Modify the given fakeroot, returns a new one."""
-
- def _resolve_ptr(self, fakeroot: Any, ptr: str) -> Tuple[Any, Any, Union[str, int, None]]:
- # Lookup tree part based on the given JSON pointer
- parent, obj, token = json_ptr_resolve(fakeroot["root"], ptr)
-
- # the pointer referred to the root itself, so use the fakeroot as the parent
- if parent is None:
- parent = fakeroot
- token = "root"
-
- assert token is not None
-
- return parent, obj, token
-
-
-class AddOp(Op):
- op: Literal["add"]
- path: str
- value: Any
-
- def eval(self, fakeroot: Any) -> Any:
- parent, _obj, token = self._resolve_ptr(fakeroot, self.path)
-
- if isinstance(parent, dict):
- parent[token] = self.value
- elif isinstance(parent, list):
- if token == "-":
- parent.append(self.value)
- else:
- assert isinstance(token, int)
- parent.insert(token, self.value)
- else:
- raise AssertionError("never happens")
-
- return fakeroot
-
-
-class RemoveOp(Op):
- op: Literal["remove"]
- path: str
-
- def eval(self, fakeroot: Any) -> Any:
- parent, _obj, token = self._resolve_ptr(fakeroot, self.path)
- del parent[token]
- return fakeroot
-
-
-class ReplaceOp(Op):
- op: Literal["replace"]
- path: str
- value: str
-
- def eval(self, fakeroot: Any) -> Any:
- parent, obj, token = self._resolve_ptr(fakeroot, self.path)
-
- if obj is None:
- raise PatchError("the value you are trying to replace is null")
- parent[token] = self.value
- return fakeroot
-
-
-class MoveOp(Op):
- op: Literal["move"]
- source: str
- path: str
-
- def _source(self, source: Any) -> Any:
- if "from" not in source:
- raise ValueError("missing property 'from' in 'move' JSON patch operation")
- return str(source["from"])
-
- def eval(self, fakeroot: Any) -> Any:
- if self.path.startswith(self.source):
- raise PatchError("can't move value into itself")
-
- _parent, obj, _token = self._resolve_ptr(fakeroot, self.source)
- newobj = copy.deepcopy(obj)
-
- fakeroot = RemoveOp({"op": "remove", "path": self.source}).eval(fakeroot)
- return AddOp({"path": self.path, "value": newobj, "op": "add"}).eval(fakeroot)
-
-
-class CopyOp(Op):
- op: Literal["copy"]
- source: str
- path: str
-
- def _source(self, source: Any) -> Any:
- if "from" not in source:
- raise ValueError("missing property 'from' in 'copy' JSON patch operation")
- return str(source["from"])
-
- def eval(self, fakeroot: Any) -> Any:
- _parent, obj, _token = self._resolve_ptr(fakeroot, self.source)
- newobj = copy.deepcopy(obj)
-
- return AddOp({"path": self.path, "value": newobj, "op": "add"}).eval(fakeroot)
-
-
-class TestOp(Op):
- op: Literal["test"]
- path: str
- value: Any
-
- def eval(self, fakeroot: Any) -> Any:
- _parent, obj, _token = self._resolve_ptr(fakeroot, self.path)
-
- if obj != self.value:
- raise PatchError("test failed")
-
- return fakeroot
-
-
-def query(
- original: Any, method: Literal["get", "delete", "put", "patch"], ptr: str, payload: Any
-) -> Tuple[Any, Optional[Any]]:
- ########################################
- # Prepare data we will be working on
-
- # First of all, we consider the original data to be immutable, so we need to make a copy
- # in order to mutate it freely
- dataroot = copy.deepcopy(original)
-
- # To simplify referencing the root, create a fake root node
- fakeroot = {"root": dataroot}
-
- #########################################
- # Handle the actual requested operation
-
- # get = return what the path selector picks
- if method == "get":
- parent, obj, token = json_ptr_resolve(fakeroot, f"/root{ptr}")
- return fakeroot["root"], obj
-
- if method == "delete":
- fakeroot = RemoveOp({"op": "remove", "path": ptr}).eval(fakeroot)
- return fakeroot["root"], None
-
- if method == "put":
- parent, obj, token = json_ptr_resolve(fakeroot, f"/root{ptr}")
- assert parent is not None # we know this due to the fakeroot
- if isinstance(parent, list) and token == "-":
- parent.append(payload)
- else:
- parent[token] = payload
- return fakeroot["root"], None
-
- if method == "patch":
- tp = List[Union[AddOp, RemoveOp, MoveOp, CopyOp, TestOp, ReplaceOp]]
- transaction: tp = map_object(tp, payload)
-
- for i, op in enumerate(transaction):
- try:
- fakeroot = op.eval(fakeroot)
- except PatchError as e:
- raise ValueError(f"json patch transaction failed on step {i}") from e
-
- return fakeroot["root"], None
-
- raise AssertionError("invalid operation, never happens")
+++ /dev/null
-"""
-Standard dict and list alternatives which can dynamically rename their keys, replacing `-` with `_`.
-
-They persist in nested data structures, meaning that if you obtain a dict from a Renamed variant, you will
-actually get a RenamedDict back instead.
-
-Usage:
-
-d = dict()
-l = list()
-
-rd = renamed(d)
-rl = renamed(l)
-
-assert isinstance(rd, Renamed)
-assert l == rl.original()
-"""
-
-from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
-from typing import Any, Dict, List, TypeVar
-
-
-class Renamed(ABC):
- @abstractmethod
- def original(self) -> Any:
- """Return a data structure, which is the source without dynamic renaming."""
-
- @staticmethod
- def map_public_to_private(name: Any) -> Any:
- if isinstance(name, str):
- return name.replace("_", "-")
- return name
-
- @staticmethod
- def map_private_to_public(name: Any) -> Any:
- if isinstance(name, str):
- return name.replace("-", "_")
- return name
-
-
-K = TypeVar("K")
-V = TypeVar("V")
-
-
-class RenamedDict(Dict[K, V], Renamed):
- def keys(self) -> Any:
- keys = super().keys()
- return {Renamed.map_private_to_public(key) for key in keys}
-
- def __getitem__(self, key: K) -> V:
- key = Renamed.map_public_to_private(key)
- res = super().__getitem__(key)
- return renamed(res)
-
- def __setitem__(self, key: K, value: V) -> None:
- key = Renamed.map_public_to_private(key)
- return super().__setitem__(key, value)
-
- def __contains__(self, key: object) -> bool:
- key = Renamed.map_public_to_private(key)
- return super().__contains__(key)
-
- def items(self) -> Any:
- for k, v in super().items():
- yield Renamed.map_private_to_public(k), renamed(v)
-
- def original(self) -> Dict[K, V]:
- return dict(super().items())
-
-
-class RenamedList(List[V], Renamed):
- def __getitem__(self, key: Any) -> Any:
- res = super().__getitem__(key)
- return renamed(res)
-
- def original(self) -> Any:
- return list(super().__iter__())
-
-
-def renamed(obj: Any) -> Any:
- if isinstance(obj, dict):
- return RenamedDict(**obj)
- if isinstance(obj, list):
- return RenamedList(obj)
- return obj
-
-
-__all__ = ["renamed", "Renamed"]
+++ /dev/null
-# pylint: disable=comparison-with-callable
-
-
-import enum
-import inspect
-import sys
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
-
-from .base_generic_type_wrapper import BaseGenericTypeWrapper
-
-NoneType = type(None)
-
-
-def get_annotations(obj: Any) -> Dict[str, Any]:
- if hasattr(inspect, "get_annotations"):
- return inspect.get_annotations(obj)
- # TODO(bump to py3.10): Safe to remove. This fallback exists for older versions
- return obj.__dict__.get("__annotations__", {})
-
-
-def is_optional(tp: Any) -> bool:
- origin = getattr(tp, "__origin__", None)
- args = get_generic_type_arguments(tp)
-
- return origin == Union and len(args) == 2 and args[1] == NoneType
-
-
-def is_dict(tp: Any) -> bool:
- return getattr(tp, "__origin__", None) in (Dict, dict)
-
-
-def is_enum(tp: Any) -> bool:
- return inspect.isclass(tp) and issubclass(tp, enum.Enum)
-
-
-def is_list(tp: Any) -> bool:
- return getattr(tp, "__origin__", None) in (List, list)
-
-
-def is_tuple(tp: Any) -> bool:
- return getattr(tp, "__origin__", None) in (Tuple, tuple)
-
-
-def is_union(tp: Any) -> bool:
- """Return true even for optional types, because they are just a Union[T, NoneType]."""
- return getattr(tp, "__origin__", None) == Union
-
-
-def is_literal(tp: Any) -> bool:
- if sys.version_info.minor == 6:
- return isinstance(tp, type(Literal))
- return getattr(tp, "__origin__", None) == Literal
-
-
-def is_generic_type_wrapper(tp: Any) -> bool:
- orig = getattr(tp, "__origin__", None)
- return inspect.isclass(orig) and issubclass(orig, BaseGenericTypeWrapper)
-
-
-def get_generic_type_arguments(tp: Any) -> List[Any]:
- default: List[Any] = []
- if sys.version_info.minor == 6 and is_literal(tp):
- return getattr(tp, "__values__")
- return getattr(tp, "__args__", default)
-
-
-def get_generic_type_argument(tp: Any) -> Any:
- """Same as function get_generic_type_arguments, but expects just one type argument.""" # noqa: D401
- args = get_generic_type_arguments(tp)
- assert len(args) == 1
- return args[0]
-
-
-def get_generic_type_wrapper_argument(tp: Type["BaseGenericTypeWrapper[Any]"]) -> Any:
- assert hasattr(tp, "__origin__")
- origin = getattr(tp, "__origin__")
-
- assert hasattr(origin, "__orig_bases__")
- orig_base: List[Any] = getattr(origin, "__orig_bases__", [])[0]
-
- arg = get_generic_type_argument(tp)
- return get_generic_type_argument(orig_base[arg])
-
-
-def is_none_type(tp: Any) -> bool:
- return tp is None or tp == NoneType
-
-
-def get_attr_type(obj: Any, attr_name: str) -> Any:
- assert hasattr(obj, attr_name)
- assert hasattr(obj, "__annotations__")
- annot = get_annotations(type(obj))
- assert attr_name in annot
- return annot[attr_name]
-
-
-T = TypeVar("T")
-
-
-def get_optional_inner_type(optional: Type[Optional[T]]) -> Type[T]:
- assert is_optional(optional)
- t: Type[T] = get_generic_type_arguments(optional)[0]
- return t
-
-
-def is_internal_field_name(field_name: str) -> bool:
- return field_name.startswith("_")
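-
-
-# Illustrative sketch of the typing helpers above (not part of the original module):
-#
-# assert is_optional(Optional[int]) and get_optional_inner_type(Optional[int]) is int
-# assert is_list(List[str]) and get_generic_type_argument(List[str]) is str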
+++ /dev/null
-import errno
-import socket
-import sys
-from http.client import HTTPConnection
-from typing import Any, Literal, Optional
-from urllib.error import HTTPError, URLError
-from urllib.parse import quote, unquote, urlparse
-from urllib.request import AbstractHTTPHandler, Request, build_opener, install_opener, urlopen
-
-
-class SocketDesc:
- def __init__(self, socket_def: str, source: str) -> None:
- self.source = source
- if ":" in socket_def:
- # `socket_def` contains a scheme, probably already URI-formatted, use directly
- self.uri = socket_def
- else:
- # `socket_def` is probably a path, convert to URI
- self.uri = f'http+unix://{quote(socket_def, safe="")}'
-
- while self.uri.endswith("/"):
- self.uri = self.uri[:-1]
-
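-# Illustrative examples (assumed values): a bare path becomes a percent-encoded http+unix URI,
-# e.g. SocketDesc("/run/knot-resolver/kres-api.sock", "default config").uri
-# == "http+unix://%2Frun%2Fknot-resolver%2Fkres-api.sock",
-# while SocketDesc("http://localhost:5000", "--socket argument").uri is kept as given.
-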
-
-class Response:
- def __init__(self, status: int, body: str) -> None:
- self.status = status
- self.body = body
-
- def __repr__(self) -> str:
- return f"status: {self.status}\nbody:\n{self.body}"
-
-
-def _print_conn_error(error_desc: str, url: str, socket_source: str) -> None:
- host: str
- try:
- parsed_url = urlparse(url)
- host = unquote(parsed_url.hostname or "(Unknown)")
- except Exception as e:
- host = f"(Invalid URL: {e})"
- msg = f"""
-{error_desc}
-\tURL: {url}
-\tHost/Path: {host}
-\tSourced from: {socket_source}
-Is the URL correct?
-\tA Unix socket URL starts with http+unix:// followed by a URL-encoded path.
-\tAn inet socket URL starts with http:// followed by a domain name or IP address.
- """
- print(msg, file=sys.stderr)
-
-
-def request(
- socket_desc: SocketDesc,
- method: Literal["GET", "POST", "HEAD", "PUT", "DELETE"],
- path: str,
- body: Optional[str] = None,
- content_type: str = "application/json",
-) -> Response:
- while path.startswith("/"):
- path = path[1:]
- url = f"{socket_desc.uri}/{path}"
- req = Request(
- url,
- method=method,
- data=body.encode("utf8") if body is not None else None,
- headers={"Content-Type": content_type},
- )
- # req.add_header("Authorization", _authorization_header)
-
- timeout_m = 5 # minutes
- try:
- with urlopen(req, timeout=timeout_m * 60) as response:
- return Response(response.status, response.read().decode("utf8"))
- except HTTPError as err:
- return Response(err.code, err.read().decode("utf8"))
- except URLError as err:
- if err.errno == errno.ECONNREFUSED or isinstance(err.reason, ConnectionRefusedError):
- _print_conn_error("Connection refused.", url, socket_desc.source)
- elif err.errno == errno.ENOENT or isinstance(err.reason, FileNotFoundError):
- _print_conn_error("No such file or directory.", url, socket_desc.source)
- else:
- _print_conn_error(str(err), url, socket_desc.source)
- sys.exit(1)
- except (TimeoutError, socket.timeout):
- _print_conn_error(
- f"Connection timed out after {timeout_m} minutes."
- "\nIt does not mean that the operation necessarily failed."
- "\nSee Knot Resolver's log for more information.",
- url,
- socket_desc.source,
- )
- sys.exit(1)
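-
-
-# Illustrative sketch (the socket path and API path are assumed, not taken from the real API):
-#
-# desc = SocketDesc("/run/knot-resolver/kres-api.sock", "example")
-# response = request(desc, "GET", "config")
-# print(response.status, response.body)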
-
-
-# Code heavily inspired by requests-unixsocket
-# https://github.com/msabramo/requests-unixsocket/blob/master/requests_unixsocket/adapters.py
-class UnixHTTPConnection(HTTPConnection):
- def __init__(self, unix_socket_url: str, timeout: float = 60) -> None:
- """
- Create an HTTP connection to a unix domain socket.
-
- :param unix_socket_url: A URL with a scheme of 'http+unix' and the
- netloc is a percent-encoded path to a unix domain socket. E.g.:
- 'http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid'
- """
- super().__init__("localhost", timeout=timeout)
- self.unix_socket_path = unix_socket_url
- self.timeout = timeout
- self.sock: Optional[socket.socket] = None
-
- def __del__(self) -> None: # base class does not have d'tor
- if self.sock:
- self.sock.close()
-
- def connect(self) -> None:
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- sock.settimeout(self.timeout)
- sock.connect(self.unix_socket_path)
- self.sock = sock
-
-
-class UnixHTTPHandler(AbstractHTTPHandler):
- def __init__(self) -> None:
- super().__init__()
-
- def open_(self: UnixHTTPHandler, req: Any) -> Any:
- return self.do_open(UnixHTTPConnection, req) # type: ignore[arg-type]
-
- setattr(UnixHTTPHandler, "http+unix_open", open_)
- setattr(UnixHTTPHandler, "http+unix_request", AbstractHTTPHandler.do_request_)
-
-
-opener = build_opener(UnixHTTPHandler())
-install_opener(opener)
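-
-# With the opener installed above, urlopen() also understands 'http+unix://' URLs,
-# e.g. urlopen("http+unix://%2Ftmp%2Fprofilesvc.sock/status/pid") connects over the Unix socket
-# (the socket path in this example comes from the UnixHTTPConnection docstring).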
+++ /dev/null
-import enum
-import logging
-import os
-import socket
-
-logger = logging.getLogger(__name__)
-
-
-class _Status(enum.Enum):
- NOT_INITIALIZED = 1
- FUNCTIONAL = 2
- FAILED = 3
-
-
-_status = _Status.NOT_INITIALIZED
-_socket = None
-
-
-def systemd_notify(**values: str) -> None:
- global _status
- global _socket
-
- if _status is _Status.NOT_INITIALIZED:
- socket_addr = os.getenv("NOTIFY_SOCKET")
- os.unsetenv("NOTIFY_SOCKET")
- if socket_addr is None:
- _status = _Status.FAILED
- return
- if socket_addr.startswith("@"):
- socket_addr = socket_addr.replace("@", "\0", 1)
-
- try:
- _socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
- _socket.connect(socket_addr)
- _status = _Status.FUNCTIONAL
- except Exception:
- _socket = None
- _status = _Status.FAILED
- logger.warning(f"Failed to connect to $NOTIFY_SOCKET at '{socket_addr}'", exc_info=True)
- return
-
- elif _status is _Status.FAILED:
- return
-
- if _status is _Status.FUNCTIONAL:
- assert _socket is not None
- payload = "\n".join((f"{key}={value}" for key, value in values.items()))
- try:
- _socket.send(payload.encode("utf8"))
- except Exception:
- logger.warning("Failed to send notification to systemd", exc_info=True)
- _status = _Status.FAILED
- _socket.close()
- _socket = None
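-
-
-# Illustrative usage sketch: when started by systemd with Type=notify, report readiness.
-#
-# systemd_notify(READY="1", STATUS="Configuration loaded")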
+++ /dev/null
-import functools
-import os
-from pathlib import Path
-
-
-@functools.lru_cache(maxsize=16)
-def which(binary_name: str) -> Path:
- """
- Search $PATH for the given binary name and return the absolute path of that executable.
-
- The results of this function are LRU cached.
-
- If not found, raises a RuntimeError.
- """
- possible_directories = os.get_exec_path()
- for dr in possible_directories:
- p = Path(dr, binary_name)
- if p.exists():
- return p.absolute()
-
- raise RuntimeError(f"Executable {binary_name} was not found in $PATH")
+++ /dev/null
-from pathlib import Path
-
-from knot_resolver.datamodel.globals import Context, set_global_validation_context
-
-set_global_validation_context(Context(Path("."), False))
+++ /dev/null
-from typing import Any
-
-import pytest
-
-from knot_resolver.datamodel.cache_schema import CacheClearRPCSchema
-from knot_resolver.datamodel.templates import template_from_str
-
-
-@pytest.mark.parametrize(
- "val,res",
- [
- ({}, "cache.clear(nil,false,nil,100)"),
- ({"chunk-size": 200}, "cache.clear(nil,false,nil,200)"),
- ({"name": "example.com.", "exact-name": True}, "cache.clear('example.com.',true,nil,nil)"),
- (
- {"name": "example.com.", "exact-name": True, "rr-type": "AAAA"},
- "cache.clear('example.com.',true,kres.type.AAAA,nil)",
- ),
- ],
-)
-def test_cache_clear(val: Any, res: Any):
- tmpl_str = "{% from 'macros/cache_macros.lua.j2' import cache_clear %}{{ cache_clear(x) }}"
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=CacheClearRPCSchema(val)) == res
+++ /dev/null
-from knot_resolver.datamodel.forward_schema import ForwardServerSchema
-from knot_resolver.datamodel.templates import template_from_str
-
-
-def test_boolean():
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import boolean %}
-{{ boolean(x) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=True) == "true"
- assert tmpl.render(x=False) == "false"
-
-
-def test_boolean_neg():
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import boolean %}
-{{ boolean(x,true) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=True) == "false"
- assert tmpl.render(x=False) == "true"
-
-
-def test_string_table():
- s = "any string"
- t = [s, "other string"]
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import string_table %}
-{{ string_table(x) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=s) == f"'{s}'"
- assert tmpl.render(x=t) == f"{{'{s}','{t[1]}',}}"
-
-
-def test_str2ip_table():
- s = "2001:DB8::d0c"
- t = [s, "192.0.2.1"]
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import str2ip_table %}
-{{ str2ip_table(x) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=s) == f"kres.str2ip('{s}')"
- assert tmpl.render(x=t) == f"{{kres.str2ip('{s}'),kres.str2ip('{t[1]}'),}}"
-
-
-def test_qtype_table():
- s = "AAAA"
- t = [s, "TXT"]
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import qtype_table %}
-{{ qtype_table(x) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=s) == f"kres.type.{s}"
- assert tmpl.render(x=t) == f"{{kres.type.{s},kres.type.{t[1]},}}"
-
-
-def test_servers_table():
- s = "2001:DB8::d0c"
- t = [s, "192.0.2.1"]
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import servers_table %}
-{{ servers_table(x) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=s) == f"'{s}'"
- assert tmpl.render(x=t) == f"{{'{s}','{t[1]}',}}"
- assert tmpl.render(x=[{"address": s}, {"address": t[1]}]) == f"{{'{s}','{t[1]}',}}"
-
-
-def test_tls_servers_table():
- d = ForwardServerSchema(
- # the ca-file is a dummy, because its existence is checked
- {"address": ["2001:DB8::d0c"], "hostname": "res.example.com", "ca-file": "/etc/passwd"}
- )
- t = [
- d,
- ForwardServerSchema(
- {
- "address": "192.0.2.1",
- "pin-sha256": "E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=",
- }
- ),
- ]
- tmpl_str = """{% from 'macros/common_macros.lua.j2' import tls_servers_table %}
-{{ tls_servers_table(x) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(x=[d.address, t[1].address]) == f"{{'{d.address}','{t[1].address}',}}"
- assert (
- tmpl.render(x=t)
- == f"{{{{'{d.address}',hostname='{d.hostname}',ca_file='{d.ca_file}',}},{{'{t[1].address}',pin_sha256={{'{t[1].pin_sha256}',}}}},}}"
- )
+++ /dev/null
-from knot_resolver.datamodel.forward_schema import ForwardSchema
-from knot_resolver.datamodel.templates import template_from_str
-from knot_resolver.datamodel.types import IPAddressOptionalPort
-
-
-def test_policy_rule_forward_add():
- tmpl_str = """{% from 'macros/forward_macros.lua.j2' import policy_rule_forward_add %}
-{{ policy_rule_forward_add(rule.subtree[0],rule.options,rule.servers) }}"""
-
- rule = ForwardSchema(
- {
- "subtree": ".",
- "servers": [{"address": ["2001:148f:fffe::1", "185.43.135.1"], "hostname": "odvr.nic.cz"}],
- "options": {
- "authoritative": False,
- "dnssec": True,
- },
- }
- )
- result = "policy.rule_forward_add('.',{dnssec=true,auth=false},{{'2001:148f:fffe::1',tls=false,hostname='odvr.nic.cz',},{'185.43.135.1',tls=false,hostname='odvr.nic.cz',},})"
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(rule=rule) == result
-
- rule.servers = [IPAddressOptionalPort("2001:148f:fffe::1"), IPAddressOptionalPort("185.43.135.1")]
- result = "policy.rule_forward_add('.',{dnssec=true,auth=false},{{'2001:148f:fffe::1'},{'185.43.135.1'},})"
- assert tmpl.render(rule=rule) == result
+++ /dev/null
-from knot_resolver.datamodel.network_schema import ListenSchema
-from knot_resolver.datamodel.templates import template_from_str
-
-
-def test_network_listen():
- tmpl_str = """{% from 'macros/network_macros.lua.j2' import network_listen %}
-{{ network_listen(listen) }}"""
- tmpl = template_from_str(tmpl_str)
-
- soc = ListenSchema({"unix-socket": "/tmp/kresd-socket", "kind": "dot"})
- assert tmpl.render(listen=soc) == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})\n"
- soc_list = ListenSchema({"unix-socket": [soc.unix_socket.to_std()[0], "/tmp/kresd-socket2"], "kind": "dot"})
- assert (
- tmpl.render(listen=soc_list)
- == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})\n"
- + "net.listen('/tmp/kresd-socket2',nil,{kind='tls',freebind=false})\n"
- )
-
- ip = ListenSchema({"interface": "::1@55", "freebind": True})
- assert tmpl.render(listen=ip) == "net.listen('::1',55,{kind='dns',freebind=true})\n"
- ip_list = ListenSchema({"interface": [ip.interface.to_std()[0], "127.0.0.1@5335"]})
- assert (
- tmpl.render(listen=ip_list)
- == "net.listen('::1',55,{kind='dns',freebind=false})\n"
- + "net.listen('127.0.0.1',5335,{kind='dns',freebind=false})\n"
- )
-
- intrfc = ListenSchema({"interface": "eth0", "kind": "doh2"})
- assert tmpl.render(listen=intrfc) == "net.listen(net['eth0'],443,{kind='doh2',freebind=false})\n"
- intrfc_list = ListenSchema({"interface": [intrfc.interface.to_std()[0], "lo"], "port": 5555, "kind": "doh2"})
- assert (
- tmpl.render(listen=intrfc_list)
- == "net.listen(net['eth0'],5555,{kind='doh2',freebind=false})\n"
- + "net.listen(net['lo'],5555,{kind='doh2',freebind=false})\n"
- )
+++ /dev/null
-from typing import List
-
-from knot_resolver.datamodel.network_schema import AddressRenumberingSchema
-from knot_resolver.datamodel.templates import template_from_str
-
-
-def test_policy_add():
- rule = "policy.all(policy.DENY)"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_add %}
-{{ policy_add(rule, postrule) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(rule=rule, postrule=False) == f"policy.add({rule})"
- assert tmpl.render(rule=rule, postrule=True) == f"policy.add({rule},true)"
-
-
-def test_policy_tags_assign():
- tags: List[str] = ["t01", "t02", "t03"]
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_tags_assign %}
-{{ policy_tags_assign(tags) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(tags=tags[1]) == f"policy.TAGS_ASSIGN('{tags[1]}')"
- assert tmpl.render(tags=tags) == "policy.TAGS_ASSIGN({" + ",".join([f"'{x}'" for x in tags]) + ",})"
-
-
-def test_policy_get_tagset():
- tags: List[str] = ["t01", "t02", "t03"]
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_get_tagset %}
-{{ policy_get_tagset(tags) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(tags=tags[1]) == f"policy.get_tagset('{tags[1]}')"
- assert tmpl.render(tags=tags) == "policy.get_tagset({" + ",".join([f"'{x}'" for x in tags]) + ",})"
-
-
-# Filters
-
-
-def test_policy_all():
- action = "policy.DENY"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_all %}
-{{ policy_all(action) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(action=action) == f"policy.all({action})"
-
-
-def test_policy_suffix():
- action = "policy.DROP"
- suffix = "policy.todnames({'example.com'})"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_suffix %}
-{{ policy_suffix(action, suffix) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(action=action, suffix=suffix) == f"policy.suffix({action},{suffix})"
-
-
-def test_policy_suffix_common():
- action = "policy.DROP"
- suffix = "policy.todnames({'first.example.com','second.example.com'})"
- common = "policy.todnames({'example.com'})"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_suffix_common %}
-{{ policy_suffix_common(action, suffix, common) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(action=action, suffix=suffix, common=None) == f"policy.suffix_common({action},{suffix})"
- assert (
- tmpl.render(action=action, suffix=suffix, common=common) == f"policy.suffix_common({action},{suffix},{common})"
- )
-
-
-def test_policy_pattern():
- action = "policy.DENY"
- pattern = "[0-9]+\2cz"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_pattern %}
-{{ policy_pattern(action, pattern) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(action=action, pattern=pattern) == f"policy.pattern({action},'{pattern}')"
-
-
-def test_policy_rpz():
- action = "policy.DENY"
- path = "/etc/knot-resolver/blocklist.rpz"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_rpz %}
-{{ policy_rpz(action, path, watch) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(action=action, path=path, watch=False) == f"policy.rpz({action},'{path}',false)"
- assert tmpl.render(action=action, path=path, watch=True) == f"policy.rpz({action},'{path}',true)"
-
-
-# Non-chain actions
-
-
-def test_policy_deny_msg():
- msg = "this is deny message"
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_deny_msg %}
-{{ policy_deny_msg(msg) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert tmpl.render(msg=msg) == f"policy.DENY_MSG('{msg}')"
-
-
-def test_policy_reroute():
- r: List[AddressRenumberingSchema] = [
- AddressRenumberingSchema({"source": "192.0.2.0/24", "destination": "127.0.0.0"}),
- AddressRenumberingSchema({"source": "10.10.10.0/24", "destination": "192.168.1.0"}),
- ]
- tmpl_str = """{% from 'macros/policy_macros.lua.j2' import policy_reroute %}
-{{ policy_reroute(reroute) }}"""
-
- tmpl = template_from_str(tmpl_str)
- assert (
- tmpl.render(reroute=r)
- == f"policy.REROUTE({{['{r[0].source}']='{r[0].destination}'}},{{['{r[1].source}']='{r[1].destination}'}},)"
- )
+++ /dev/null
-from typing import Any
-
-import pytest
-from jinja2 import Template
-
-from knot_resolver.datamodel.types import EscapedStr
-from knot_resolver.utils.modeling import ConfigSchema
-
-str_template = Template("'{{ string }}'")
-
-
-str_multiline_template = Template(
- """[[
-{{ string.multiline() }}
-]]"""
-)
-
-
-@pytest.mark.parametrize(
- "val,exp",
- [
- ("\a\b\f\n\r\t\v\\", "\a\b\f\n\r\t\v\\"),
- ("[[ test ]]", r"\[\[ test \]\]"),
- ("[ [ test ] ]", r"[ [ test ] ]"),
- ],
-)
-def test_escaped_str_multiline(val: Any, exp: str):
- class TestSchema(ConfigSchema):
- pattern: EscapedStr
-
- d = TestSchema({"pattern": val})
- assert (
- str_multiline_template.render(string=d.pattern)
- == f"""[[
-{exp}
-]]"""
- )
-
-
-@pytest.mark.parametrize(
- "val,exp",
- [
- ("", ""),
- ("string", "string"),
- (2000, "2000"),
- ('"\a\b\f\n\r\t\v\\"', r"\"\a\b\f\n\r\t\v\\\""),
- ('""', r"\"\""),
- ("''", r"\'\'"),
- # fmt: off
- ('""', r"\"\""),
- ("''", r"\'\'"),
- # fmt: on
- ],
-)
-def test_escaped_str(val: Any, exp: str):
- class TestSchema(ConfigSchema):
- pattern: EscapedStr
-
- d = TestSchema({"pattern": val})
- assert str_template.render(string=d.pattern) == f"'{exp}'"
+++ /dev/null
-from typing import Any
-
-import pytest
-
-from knot_resolver.datamodel.templates import template_from_str
-from knot_resolver.datamodel.view_schema import ViewOptionsSchema, ViewSchema
-
-
-def test_view_flags():
- tmpl_str = """{% from 'macros/view_macros.lua.j2' import view_flags %}
-{{ view_flags(options) }}"""
-
- tmpl = template_from_str(tmpl_str)
- options = ViewOptionsSchema({"dns64": False, "minimize": False})
- assert tmpl.render(options=options) == '"NO_MINIMIZE","DNS64_DISABLE",'
- assert tmpl.render(options=ViewOptionsSchema()) == ""
-
-
-def test_view_options_flags():
- tmpl_str = """{% from 'macros/view_macros.lua.j2' import view_options_flags %}
-{{ view_options_flags(options) }}"""
-
- tmpl = template_from_str(tmpl_str)
- options = ViewOptionsSchema({"dns64": False, "minimize": False})
- assert tmpl.render(options=options) == "policy.FLAGS({'NO_MINIMIZE','DNS64_DISABLE',})"
- assert tmpl.render(options=ViewOptionsSchema()) == "policy.FLAGS({})"
-
-
-@pytest.mark.parametrize(
- "val,res",
- [
- ("allow", "policy.TAGS_ASSIGN({})"),
- ("refused", "'policy.REFUSE'"),
- ("noanswer", "'policy.NO_ANSWER'"),
- ],
-)
-def test_view_answer(val: Any, res: Any):
- tmpl_str = """{% from 'macros/view_macros.lua.j2' import view_answer %}
-{{ view_answer(view.answer) }}"""
-
- tmpl = template_from_str(tmpl_str)
- view = ViewSchema({"subnets": ["10.0.0.0/8"], "answer": val})
- assert tmpl.render(view=view) == res
+++ /dev/null
-import inspect
-import json
-from typing import Any, Dict, Type, cast
-
-from knot_resolver.datamodel import KresConfig
-from knot_resolver.datamodel.lua_schema import LuaSchema
-from knot_resolver.utils.modeling import BaseSchema
-from knot_resolver.utils.modeling.types import (
- get_annotations,
- get_generic_type_argument,
- get_generic_type_arguments,
- get_optional_inner_type,
- is_dict,
- is_list,
- is_optional,
- is_union,
-)
-
-
-def test_config_check_str_type():
- # check that there is no 'str' type in the datamodel schema (except for LuaSchema)
- def _check_str_type(cls: Type[Any], object_path: str = ""):
- if cls == str:
- raise TypeError(f"{object_path}: 'str' type not allowed")
- elif is_optional(cls):
- inner: Type[Any] = get_optional_inner_type(cls)
- _check_str_type(inner, object_path)
- elif is_union(cls):
- variants = get_generic_type_arguments(cls)
- for v in variants:
- _check_str_type(v, object_path)
- elif is_dict(cls):
- key_type, val_type = get_generic_type_arguments(cls)
- _check_str_type(key_type, object_path)
- _check_str_type(val_type, object_path)
- elif is_list(cls):
- inner_type = get_generic_type_argument(cls)
- _check_str_type(inner_type, object_path)
-
- elif inspect.isclass(cls) and issubclass(cls, BaseSchema):
- annot = get_annotations(cls)
- for name, python_type in annot.items():
- # ignore lua section
- if python_type != LuaSchema:
- _check_str_type(python_type, f"{object_path}/{name}")
-
- _check_str_type(KresConfig)
-
-
-def test_config_defaults():
- config = KresConfig()
-
- # DNS64 default
- assert config.dns64.enable is False
-
-
-def test_dnssec_false():
- config = KresConfig({"dnssec": {"enable": False}})
-
- assert config.dnssec.enable is False
-
-
-def test_dnssec_default_true():
- config = KresConfig()
-
- # DNSSEC defaults
- assert config.dnssec.enable is True
- assert config.dnssec.sentinel is True
- assert config.dnssec.signal_query is True
- assert config.dnssec.trust_anchors is None
- assert config.dnssec.trust_anchors_files is None
- assert config.dnssec.negative_trust_anchors is None
-
-
-def test_dns64_prefix_default():
- config = KresConfig({"dns64": {"enable": True}})
-
- assert config.dns64.enable is True
- assert str(config.dns64.prefix) == "64:ff9b::/96"
-
-
-def test_config_json_schema():
- dct = KresConfig.json_schema()
-
- def recser(obj: Any, path: str = "") -> None:
- if not isinstance(obj, dict):
- return
- else:
- obj = cast(Dict[Any, Any], obj)
- for key in obj:
- recser(obj[key], path=f"{path}/{key}")
- try:
- _ = json.dumps(obj)
- except BaseException as e:
- raise Exception(f"failed to serialize '{path}': {e}") from e
-
- recser(dct)
+++ /dev/null
-import pytest
-from pytest import raises
-
-from knot_resolver.datamodel.forward_schema import ForwardSchema
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-
-
-@pytest.mark.parametrize("port,auth", [(5335, False), (53, True)])
-def test_forward_valid(port: int, auth: bool):
- assert ForwardSchema(
- {"subtree": ".", "options": {"authoritative": auth, "dnssec": True}, "servers": [f"127.0.0.1", "::1"]}
- )
- assert ForwardSchema(
- {"subtree": ".", "options": {"authoritative": auth, "dnssec": False}, "servers": [f"127.0.0.1@{port}", "::1"]}
- )
-
- assert ForwardSchema(
- {
- "subtree": ".",
- "options": {"authoritative": auth, "dnssec": False},
- "servers": [{"address": [f"127.0.0.1@{port}", "::1"]}],
- }
- )
-
- assert ForwardSchema(
- {
- "subtree": ".",
- "options": {"authoritative": auth, "dnssec": False},
- "servers": [{"address": [f"127.0.0.1", "::1"]}],
- }
- )
-
-
-@pytest.mark.parametrize(
- "port,auth,tls",
- [(5335, True, False), (53, True, True)],
-)
-def test_forward_invalid(port: int, auth: bool, tls: bool):
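- # combinations expected to fail validation: authoritative forwarding on a non-default port, or authoritative forwarding over TLS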
- if not tls:
- with raises(DataValidationError):
- ForwardSchema(
- {
- "subtree": ".",
- "options": {"authoritative": auth, "dnssec": False},
- "servers": [f"127.0.0.1@{port}", "::1"],
- }
- )
-
- with raises(DataValidationError):
- ForwardSchema(
- {
- "subtree": ".",
- "options": {"authoritative": auth, "dnssec": False},
- "servers": [{"address": [f"127.0.0.1{port}", f"::1{port}"], "transport": "tls" if tls else None}],
- }
- )
+++ /dev/null
-from typing import Any
-
-import pytest
-from pytest import raises
-
-from knot_resolver.datamodel.local_data_schema import RuleSchema
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-
-
-@pytest.mark.parametrize(
- "val",
- [
- {"name": ["sub2.example.org"], "subtree": "empty", "tags": ["t01"]},
- {"name": ["sub3.example.org", "sub5.example.net."], "subtree": "nxdomain", "ttl": "1h"},
- {"name": ["sub4.example.org"], "subtree": "redirect"},
- {"name": ["sub5.example.org"], "address": ["127.0.0.1"]},
- {"name": ["sub6.example.org"], "subtree": "redirect", "address": ["127.0.0.1"]},
- {"file": "/etc/hosts", "ttl": "20m", "nodata": True},
- {"records": "", "ttl": "20m", "nodata": True},
- ],
-)
-def test_subtree_valid(val: Any):
- RuleSchema(val)
-
-
-@pytest.mark.parametrize(
- "val",
- [
- {"subtree": "empty"},
- {"name": ["sub2.example.org"], "file": "/etc/hosts"},
- {"name": ["sub4.example.org"], "address": ["127.0.0.1"], "subtree": "nxdomain"},
- {"name": ["sub4.example.org"], "subtree": "redirect", "file": "/etc/hosts"},
- ],
-)
-def test_subtree_invalid(val: Any):
- with raises(DataValidationError):
- RuleSchema(val)
+++ /dev/null
-from pytest import raises
-
-from knot_resolver.datamodel.lua_schema import LuaSchema
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-
-
-def test_invalid():
- with raises(DataValidationError):
- LuaSchema({"script": "-- lua script", "script-file": "path/to/file"})
+++ /dev/null
-from typing import Any, Dict, Optional
-
-import pytest
-
-from knot_resolver.datamodel.management_schema import ManagementSchema
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-
-
-@pytest.mark.parametrize("val", [{"interface": "::1@53"}, {"unix-socket": "/tmp/socket"}])
-def test_management_valid(val: Dict[str, Any]):
- o = ManagementSchema(val)
- if o.interface:
- assert str(o.interface) == val["interface"]
- if o.unix_socket:
- assert str(o.unix_socket) == val["unix-socket"]
-
-
-@pytest.mark.parametrize("val", [None, {"interface": "::1@53", "unix-socket": "/tmp/socket"}])
-def test_management_invalid(val: Optional[Dict[str, Any]]):
- with pytest.raises(DataValidationError):
- ManagementSchema(val)
+++ /dev/null
-from typing import Any, Dict, Optional
-
-import pytest
-from pytest import raises
-
-from knot_resolver.constants import WATCHDOG_LIB
-from knot_resolver.datamodel.network_schema import ListenSchema, NetworkSchema, TLSSchema
-from knot_resolver.datamodel.types import InterfaceOptionalPort, PortNumber
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-
-
-def test_listen_defaults():
- o = NetworkSchema()
-
- assert len(o.listen) == 2
- # {"ip-address": "127.0.0.1"}
- assert o.listen[0].interface.to_std() == [InterfaceOptionalPort("127.0.0.1")]
- assert o.listen[0].port == PortNumber(53)
- assert o.listen[0].kind == "dns"
- assert o.listen[0].freebind is False
- # {"ip-address": "::1", "freebind": True}
- assert o.listen[1].interface.to_std() == [InterfaceOptionalPort("::1")]
- assert o.listen[1].port == PortNumber(53)
- assert o.listen[1].kind == "dns"
- assert o.listen[1].freebind is True
-
-
-@pytest.mark.parametrize(
- "listen,port",
- [
- ({"unix-socket": ["/tmp/kresd-socket"]}, None),
- ({"interface": ["::1"]}, 53),
- ({"interface": ["::1"], "kind": "dot"}, 853),
- ({"interface": ["::1"], "kind": "doh-legacy"}, 443),
- ({"interface": ["::1"], "kind": "doh2"}, 443),
- ({"interface": ["::1"], "kind": "doq"}, 853),
- ],
-)
-def test_listen_port_defaults(listen: Dict[str, Any], port: Optional[int]):
- assert ListenSchema(listen).port == (PortNumber(port) if port else None)
-
-
-@pytest.mark.parametrize(
- "listen",
- [
- {"unix-socket": "/tmp/kresd-socket"},
- {"unix-socket": ["/tmp/kresd-socket", "/tmp/kresd-socket2"]},
- {"interface": "::1"},
- {"interface": "::1@5335"},
- {"interface": "::1", "port": 5335},
- {"interface": ["127.0.0.1", "::1"]},
- {"interface": ["127.0.0.1@5335", "::1@5335"]},
- {"interface": ["127.0.0.1", "::1"], "port": 5335},
- {"interface": "lo"},
- {"interface": "lo@5335"},
- {"interface": "lo", "port": 5335},
- {"interface": ["lo", "eth0"]},
- {"interface": ["lo@5335", "eth0@5335"]},
- {"interface": ["lo", "eth0"], "port": 5335},
- ],
-)
-def test_listen_valid(listen: Dict[str, Any]):
- assert ListenSchema(listen)
-
-
-@pytest.mark.parametrize(
- "listen",
- [
- {"unix-socket": "/tmp/kresd-socket", "port": "53"},
- {"interface": "::1", "unix-socket": "/tmp/kresd-socket"},
- {"interface": "::1@5335", "port": 5335},
- {"interface": ["127.0.0.1", "::1@5335"]},
- {"interface": ["127.0.0.1@5335", "::1@5335"], "port": 5335},
- {"interface": "lo@5335", "port": 5335},
- {"interface": ["lo", "eth0@5335"]},
- {"interface": ["lo@5335", "eth0@5335"], "port": 5335},
- ],
-)
-def test_listen_invalid(listen: Dict[str, Any]):
- with raises(DataValidationError):
- ListenSchema(listen)
-
-
-@pytest.mark.parametrize(
- "tls",
- [
- {"watchdog": "auto"},
- {"watchdog": True},
- {"watchdog": False},
- ],
-)
-def test_tls_watchdog(tls: Dict[str, Any]):
- expected: bool = WATCHDOG_LIB if tls["watchdog"] == "auto" else tls["watchdog"]
- assert TLSSchema(tls).watchdog == expected
+++ /dev/null
-import random
-import sys
-from typing import List, Optional
-
-import pytest
-from pytest import raises
-
-from knot_resolver import KresBaseError
-from knot_resolver.datamodel.types.base_types import FloatRangeBase, IntRangeBase, StringLengthBase
-
-
-@pytest.mark.parametrize("min,max", [(0, None), (None, 0), (1, 65535), (-65535, -1)])
-def test_int_range_base(min: Optional[int], max: Optional[int]):
- class Test(IntRangeBase):
- if min is not None:
- _min = min
- if max is not None:
- _max = max
-
- # the bounds themselves are valid (the range is inclusive)
- if min is not None:
- assert int(Test(min)) == min
- if max is not None:
- assert int(Test(max)) == max
-
- rmin = min if min is not None else -sys.maxsize - 1
- rmax = max if max is not None else sys.maxsize
-
- n = 100
- vals: List[int] = [random.randint(rmin, rmax) for _ in range(n)]
- assert all(str(Test(val)) == f"{val}" for val in vals)
-
- # values outside the configured bounds must be rejected
- invals: List[int] = []
- invals.extend([random.randint(rmax + 1, sys.maxsize) for _ in range(n // 2)] if max is not None else [])
- invals.extend([random.randint(-sys.maxsize - 1, rmin - 1) for _ in range(n // 2)] if min is not None else [])
-
- for inval in invals:
- with raises(KresBaseError):
- Test(inval)
-
-
-@pytest.mark.parametrize("min,max", [(0.0, None), (None, 0.0), (1.0, 65535.0), (-65535.0, -1.0)])
-def test_float_range_base(min: Optional[float], max: Optional[float]):
- class Test(FloatRangeBase):
- if min is not None:
- _min = min
- if max is not None:
- _max = max
-
- if min is not None:
- assert float(Test(min)) == min
- if max is not None:
- assert float(Test(max)) == max
-
- # sys.float_info.min is the smallest positive float, so use -sys.float_info.max as the unbounded lower limit
- rmin = min if min is not None else -sys.float_info.max
- rmax = max if max is not None else sys.float_info.max
-
- n = 100
- vals: List[float] = [random.uniform(rmin, rmax) for _ in range(n)]
- assert all(str(Test(val)) == f"{val}" for val in vals)
-
- invals: List[float] = []
- invals.extend([random.uniform(rmax + 1.0, sys.float_info.max) for _ in range(n // 2)] if max is not None else [])
- invals.extend([random.uniform(-sys.float_info.max, rmin - 1.0) for _ in range(n // 2)] if min is not None else [])
-
- for inval in invals:
- with raises(KresBaseError):
- Test(inval)
-
-
-@pytest.mark.parametrize("min,max", [(10, None), (None, 10), (2, 32)])
-def test_str_bytes_length_base(min: Optional[int], max: Optional[int]):
- class Test(StringLengthBase):
- if min:
- _min_bytes = min
- if max:
- _max_bytes = max
-
- if min:
- assert len(str(Test("x" * min)).encode("utf-8")) == min
- if max:
- assert len(str(Test("x" * max)).encode("utf-8")) == max
-
- n = 100
- rmin = 1 if not min else min
- rmax = 1024 if not max else max
- vals: List[str] = ["x" * random.randint(rmin, rmax) for _ in range(n)]
- assert all(str(Test(val)) == f"{val}" for val in vals)
-
- invals: List[str] = []
- invals.extend(["x" * random.randint(rmax + 1, 2048) for _ in range(n // 2)] if max else [])
- invals.extend(["x" * random.randint(1, rmin - 1) for _ in range(n // 2)] if min and rmin > 1 else [])
-
- for inval in invals:
- with raises(KresBaseError):
- Test(inval)
+++ /dev/null
-import ipaddress
-import random
-import string
-from typing import Any
-
-import pytest
-from pytest import raises
-
-from knot_resolver.datamodel.types import (
- Dir,
- DomainName,
- EscapedStr,
- InterfaceName,
- InterfaceOptionalPort,
- InterfacePort,
- IPAddressEM,
- IPAddressOptionalPort,
- IPAddressPort,
- IPNetwork,
- IPv4Address,
- IPv6Address,
- IPv6Network96,
- PinSha256,
- PortNumber,
- SizeUnit,
- TimeUnit,
-)
-from knot_resolver.utils.modeling import BaseSchema
-
-
-def _rand_domain(label_chars: int, levels: int = 1) -> str:
- return "".join(
- ["".join(random.choices(string.ascii_letters + string.digits, k=label_chars)) + "." for i in range(levels)]
- )
-
-
-@pytest.mark.parametrize("val", [1, 65_535, 5335, 5000])
-def test_port_number_valid(val: int):
- assert int(PortNumber(val)) == val
-
-
-@pytest.mark.parametrize("val", [0, 65_636, -1, "53"])
-def test_port_number_invalid(val: Any):
- with raises(ValueError):
- PortNumber(val)
-
-
-@pytest.mark.parametrize("val", ["5368709120B", "5242880K", "5120M", "5G"])
-def test_size_unit_valid(val: str):
- o = SizeUnit(val)
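- # all parametrized inputs encode the same size: 5 GiB = 5 * 1024**3 = 5368709120 bytes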
- assert int(o) == 5368709120
- assert str(o) == val
- assert o.bytes() == 5368709120
-
-
-@pytest.mark.parametrize("val", ["-5B", 5, -5242880, "45745mB"])
-def test_size_unit_invalid(val: Any):
- with raises(ValueError):
- SizeUnit(val)
-
-
-@pytest.mark.parametrize("val", ["1d", "24h", "1440m", "86400s", "86400000ms"])
-def test_time_unit_valid(val: str):
- o = TimeUnit(val)
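- # all parametrized inputs encode one day; int() yields the value in microseconds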
- assert int(o) == 86400000000
- assert str(o) == val
- assert o.seconds() == 86400
- assert o.millis() == 86400000
- assert o.micros() == 86400000000
-
-
-@pytest.mark.parametrize("val", ["-1", "-24h", "1440mm", 6575, -1440])
-def test_time_unit_invalid(val: Any):
- with raises(ValueError):
- TimeUnit("-1")
-
-
-def test_parsing_units():
- class TestSchema(BaseSchema):
- size: SizeUnit
- time: TimeUnit
-
- o = TestSchema({"size": "3K", "time": "10m"})
- assert int(o.size) == int(SizeUnit("3072B"))
- assert int(o.time) == int(TimeUnit("600s"))
- assert o.size.bytes() == 3072
- assert o.time.seconds() == 10 * 60
-
-
-def test_checked_path():
- class TestSchema(BaseSchema):
- p: Dir
-
- assert str(TestSchema({"p": "/tmp"}).p) == "/tmp"
-
-
-@pytest.mark.parametrize(
- "val",
- [
- "d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM=",
- "E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=",
- ],
-)
-def test_pin_sha256_valid(val: str):
- o = PinSha256(val)
- assert str(o) == val
-
-
-@pytest.mark.parametrize(
- "val",
- [
- "d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM==",
- "E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g",
- "!E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=",
- "d6qzRu9zOE",
- ],
-)
-def test_pin_sha256_invalid(val: str):
- with raises(ValueError):
- PinSha256(val)
-
-
-@pytest.mark.parametrize(
- "val,exp",
- [
- ("", r""),
- (2000, "2000"),
- ("string", r"string"),
- ("\t\n\v", r"\t\n\v"),
- ("\a\b\f\n\r\t\v\\", r"\a\b\f\n\r\t\v\\"),
- # fmt: off
- ("''", r"\'\'"),
- ('""', r"\"\""),
- ("''", r"\'\'"),
- ('""', r"\"\""),
- ('\\"\\"', r"\\\"\\\""),
- ("\\'\\'", r"\\\'\\\'"),
- # fmt: on
- ],
-)
-def test_escaped_str_valid(val: Any, exp: str):
- assert str(EscapedStr(val)) == exp
-
-
-@pytest.mark.parametrize("val", [1.1, False])
-def test_escaped_str_invalid(val: Any):
- with raises(ValueError):
- EscapedStr(val)
-
-
-@pytest.mark.parametrize(
- "val",
- [
- ".",
- "example.com",
- "_8443._https.example.com.",
- "this.is.example.com.",
- "test.example.com",
- "test-example.com",
- "bücher.com.",
- "příklad.cz",
- _rand_domain(63),
- _rand_domain(1, 127),
- ],
-)
-def test_domain_name_valid(val: str):
- o = DomainName(val)
- assert str(o) == val
- assert o == DomainName(val)
- assert o.punycode() == (val.encode("idna").decode("utf-8") if val != "." else ".")
-
-
-@pytest.mark.parametrize(
- "val",
- [
- "test.example..com.",
- "-example.com",
- "-test.example.net",
- "test-.example.net",
- "test.-example.net",
- ".example.net",
- _rand_domain(64),
- _rand_domain(1, 128),
- ],
-)
-def test_domain_name_invalid(val: str):
- with raises(ValueError):
- DomainName(val)
-
-
-@pytest.mark.parametrize("val", ["lo", "eth0", "wlo1", "web_ifgrp", "e8-2"])
-def test_interface_name_valid(val: str):
- assert str(InterfaceName(val)) == val
-
-
-@pytest.mark.parametrize("val", ["_lo", "-wlo1", "lo_", "wlo1-", "e8--2", "web__ifgrp"])
-def test_interface_name_invalid(val: Any):
- with raises(ValueError):
- InterfaceName(val)
-
-
-@pytest.mark.parametrize("val", ["lo@5335", "2001:db8::1000@5001"])
-def test_interface_port_valid(val: str):
- o = InterfacePort(val)
- assert str(o) == val
- assert o == InterfacePort(val)
- assert str(o.if_name if o.if_name else o.addr) == val.split("@", 1)[0]
- assert o.port == PortNumber(int(val.split("@", 1)[1]))
-
-
-@pytest.mark.parametrize("val", ["lo", "2001:db8::1000", "53"])
-def test_interface_port_invalid(val: Any):
- with raises(ValueError):
- InterfacePort(val)
-
-
-@pytest.mark.parametrize("val", ["lo", "123.4.5.6", "lo@5335", "2001:db8::1000@5001"])
-def test_interface_optional_port_valid(val: str):
- o = InterfaceOptionalPort(val)
- assert str(o) == val
- assert o == InterfaceOptionalPort(val)
- assert str(o.if_name if o.if_name else o.addr) == (val.split("@", 1)[0] if "@" in val else val)
- assert o.port == (PortNumber(int(val.split("@", 1)[1])) if "@" in val else None)
-
-
-@pytest.mark.parametrize("val", ["lo@", "@53"])
-def test_interface_optional_port_invalid(val: Any):
- with raises(ValueError):
- InterfaceOptionalPort(val)
-
-
-@pytest.mark.parametrize("val", ["123.4.5.6@5335", "2001:db8::1000@53"])
-def test_ip_address_port_valid(val: str):
- o = IPAddressPort(val)
- assert str(o) == val
- assert o == IPAddressPort(val)
- assert str(o.addr) == val.split("@", 1)[0]
- assert o.port == PortNumber(int(val.split("@", 1)[1]))
-
-
-@pytest.mark.parametrize(
- "val", ["123.4.5.6", "2001:db8::1000", "123.4.5.6.7@5000", "2001:db8::10000@5001", "123.4.5.6@"]
-)
-def test_ip_address_port_invalid(val: Any):
- with raises(ValueError):
- IPAddressPort(val)
-
-
-@pytest.mark.parametrize("val", ["123.4.5.6", "123.4.5.6@5335", "2001:db8::1000", "2001:db8::1000@53"])
-def test_ip_address_optional_port_valid(val: str):
- o = IPAddressOptionalPort(val)
- assert str(o) == val
- assert o == IPAddressOptionalPort(val)
- assert str(o.addr) == (val.split("@", 1)[0] if "@" in val else val)
- assert o.port == (PortNumber(int(val.split("@", 1)[1])) if "@" in val else None)
-
-
-@pytest.mark.parametrize("val", ["123.4.5.6.7", "2001:db8::10000", "123.4.5.6@", "@55"])
-def test_ip_address_optional_port_invalid(val: Any):
- with raises(ValueError):
- IPAddressOptionalPort(val)
-
-
-@pytest.mark.parametrize("val", ["123.4.5.6", "192.168.0.1"])
-def test_ipv4_address_valid(val: str):
- o = IPv4Address(val)
- assert str(o) == val
- assert o == IPv4Address(val)
-
-
-@pytest.mark.parametrize("val", ["123456", "2001:db8::1000"])
-def test_ipv4_address_invalid(val: Any):
- with raises(ValueError):
- IPv4Address(val)
-
-
-@pytest.mark.parametrize("val", ["2001:db8::1000", "2001:db8:85a3::8a2e:370:7334"])
-def test_ipv6_address_valid(val: str):
- o = IPv6Address(val)
- assert str(o) == val
- assert o == IPv6Address(val)
-
-
-@pytest.mark.parametrize("val", ["123.4.5.6", "2001::db8::1000"])
-def test_ipv6_address_invalid(val: Any):
- with raises(ValueError):
- IPv6Address(val)
-
-
-@pytest.mark.parametrize("val", ["10.11.12.0/24", "64:ff9b::/96"])
-def test_ip_network_valid(val: str):
- o = IPNetwork(val)
- assert str(o) == val
- assert o.to_std().prefixlen == int(val.split("/", 1)[1])
- assert o.to_std() == ipaddress.ip_network(val)
-
-
-@pytest.mark.parametrize("val", ["10.11.12.13/8", "10.11.12.5/128"])
-def test_ip_network_invalid(val: str):
- with raises(ValueError):
- IPNetwork(val)
-
-
-@pytest.mark.parametrize("val", ["fe80::/96", "64:ff9b::/96"])
-def test_ipv6_96_network_valid(val: str):
- assert str(IPv6Network96(val)) == val
-
-
-@pytest.mark.parametrize("val", ["fe80::/95", "10.11.12.3/96", "64:ff9b::1/96"])
-def test_ipv6_96_network_invalid(val: Any):
- with raises(ValueError):
- IPv6Network96(val)
-
-
-@pytest.mark.parametrize("val", ["10.10.10.5!", "::1!"])
-def test_ip_address_em_valid(val: str):
- assert str(IPAddressEM(val)) == val
-
-
-@pytest.mark.parametrize("val", ["10.10.10.5", "::1", "10.10.10.5!!", "::1!!"])
-def test_ip_address_em_invalid(val: Any):
- with raises(ValueError):
- IPAddressEM(val)
+++ /dev/null
-from typing import Any, List, Optional, Union
-
-import pytest
-from pytest import raises
-
-from knot_resolver.datamodel.types import ListOrItem
-from knot_resolver.utils.modeling import BaseSchema
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-from knot_resolver.utils.modeling.types import get_generic_type_wrapper_argument
-
-
-@pytest.mark.parametrize("val", [str, int])
-def test_list_or_item_inner_type(val: Any):
- assert get_generic_type_wrapper_argument(ListOrItem[val]) == Union[List[val], val]
-
-
-@pytest.mark.parametrize(
- "typ,val",
- [
- (int, [1, 65_535, 5335, 5000]),
- (int, 65_535),
- (str, ["string1", "string2"]),
- (str, "string1"),
- ],
-)
-def test_list_or_item_valid(typ: Any, val: Any):
- class ListOrItemSchema(BaseSchema):
- test: ListOrItem[typ]
-
- o = ListOrItemSchema({"test": val})
- assert o.test.serialize() == val
- assert o.test.to_std() == (val if isinstance(val, list) else [val])
-
- for i, item in enumerate(o.test):
- assert item == (val[i] if isinstance(val, list) else val)
-
-
-@pytest.mark.parametrize(
- "typ,val",
- [
- (str, [True, False, True, False]),
- (str, False),
- (bool, [1, 65_535, 5335, 5000]),
- (bool, 65_535),
- (int, "string1"),
- (int, ["string1", "string2"]),
- ],
-)
-def test_list_or_item_invalid(typ: Any, val: Any):
- class ListOrItemSchema(BaseSchema):
- test: ListOrItem[typ]
-
- with raises(DataValidationError):
- ListOrItemSchema({"test": val})
-
-
-def test_list_or_item_empty():
- with raises(ValueError):
- ListOrItem([])
+++ /dev/null
-import pytest
-
-from knot_resolver.datamodel.config_schema import KresConfig
-from knot_resolver.manager.config_store import ConfigStore, only_on_real_changes_update
-
-
-@pytest.mark.asyncio # type: ignore
-async def test_only_once():
- count = 0
-
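- # the decorated callback must run only when the watched value (logging.level) actually changes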
- @only_on_real_changes_update(lambda config: config.logging.level)
- async def change_callback(config: KresConfig, force: bool = False) -> None:
- nonlocal count
- count += 1
-
- config = KresConfig()
- store = ConfigStore(config)
-
- await store.register_on_change_callback(change_callback)
- assert count == 1
-
- config = KresConfig()
- config.logging.level = "crit"
- await store.update(config)
- assert count == 2
-
- config = KresConfig()
- config.lua.script_only = True
- config.lua.script = "meaningless value"
- await store.update(config)
- assert count == 2
+++ /dev/null
-import toml
-
-from knot_resolver import __version__
-
-
-def test_version():
- with open("pyproject.toml", "r") as f:
- pyproject = toml.load(f)
-
- version = pyproject["tool"]["poetry"]["version"]
- assert __version__ == version
+++ /dev/null
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
-
-import pytest
-from pytest import raises
-
-from knot_resolver.utils.modeling import ConfigSchema, parse_json, parse_yaml
-from knot_resolver.utils.modeling.exceptions import DataDescriptionError, DataValidationError
-
-
-class _TestBool(ConfigSchema):
- v: bool
-
-
-class _TestInt(ConfigSchema):
- v: int
-
-
-class _TestStr(ConfigSchema):
- v: str
-
-
-class _TestLiteral(ConfigSchema):
- v: Literal[Literal["lit1"], Literal["lit2"]]
-
-
-@pytest.mark.parametrize("val", ["lit1", "lit2"])
-def test_parsing_literal_valid(val: str):
- assert _TestLiteral(parse_yaml(f"v: {val}")).v == val
-
-
-@pytest.mark.parametrize("val", ["invalid", "false", 1, "null"])
-def test_parsing_literal_invalid(val: str):
- with raises(DataValidationError):
- _TestLiteral(parse_yaml(f"v: {val}"))
-
-
-@pytest.mark.parametrize("val,exp", [("false", False), ("true", True), ("False", False), ("True", True)])
-def test_parsing_bool_valid(val: str, exp: bool):
- assert _TestBool(parse_yaml(f"v: {val}")).v == exp
-
-
-@pytest.mark.parametrize("val", ["0", "1", "5", "'true'", "'false'", "5.5"]) # int, str, float
-def test_parsing_bool_invalid(val: str):
- with raises(DataValidationError):
- _TestBool(parse_yaml(f"v: {val}"))
-
-
-@pytest.mark.parametrize("val,exp", [("0", 0), ("5335", 5335), ("-5001", -5001)])
-def test_parsing_int_valid(val: str, exp: int):
- assert _TestInt(parse_yaml(f"v: {val}")).v == exp
-
-
-@pytest.mark.parametrize("val", ["false", "'5'", "5.5"]) # bool, str, float
-def test_parsing_int_invalid(val: str):
- with raises(DataValidationError):
- _TestInt(parse_yaml(f"v: {val}"))
-
-
-# int and float are allowed inputs for string
-@pytest.mark.parametrize("val,exp", [("test", "test"), (65, "65"), (5.5, "5.5")])
-def test_parsing_str_valid(val: Any, exp: str):
- assert _TestStr(parse_yaml(f"v: {val}")).v == exp
-
-
-def test_parsing_str_invalid():
- with raises(DataValidationError):
- _TestStr(parse_yaml("v: false")) # bool
-
-
-def test_parsing_list_empty():
- class ListSchema(ConfigSchema):
- empty: List[Any]
-
- with raises(DataValidationError):
- ListSchema(parse_yaml("empty: []"))
-
-
-@pytest.mark.parametrize("typ,val", [(_TestInt, 5), (_TestBool, False), (_TestStr, "test")])
-def test_parsing_nested(typ: Type[ConfigSchema], val: Any):
- class UpperSchema(ConfigSchema):
- l: typ
-
- yaml = f"""
-l:
- v: {val}
-"""
-
- o = UpperSchema(parse_yaml(yaml))
- assert o.l.v == val
-
-
-def test_parsing_simple_compound_types():
- class TestSchema(ConfigSchema):
- l: List[int]
- d: Dict[str, str]
- t: Tuple[str, int]
- o: Optional[int]
-
- yaml = """
-l:
- - 1
- - 2
- - 3
- - 4
- - 5
-d:
- something: else
- w: all
-t:
- - test
- - 5
-"""
-
- o = TestSchema(parse_yaml(yaml))
- assert o.l == [1, 2, 3, 4, 5]
- assert o.d == {"something": "else", "w": "all"}
- assert o.t == ("test", 5)
- assert o.o is None
-
-
-def test_parsing_nested_compound_types():
- class TestSchema(ConfigSchema):
- i: int
- o: Optional[Dict[str, str]]
-
- yaml1 = "i: 5"
- yaml2 = f"""
-{yaml1}
-o:
- key1: str1
- key2: str2
- """
-
- o = TestSchema(parse_yaml(yaml1))
- assert o.i == 5
- assert o.o is None
-
- o = TestSchema(parse_yaml(yaml2))
- assert o.i == 5
- assert o.o == {"key1": "str1", "key2": "str2"}
-
-
-def test_dash_conversion():
- class TestSchema(ConfigSchema):
- awesome_field: Dict[str, str]
-
- yaml = """
-awesome-field:
- awesome-key: awesome-value
-"""
-
- o = TestSchema(parse_yaml(yaml))
- assert o.awesome_field["awesome-key"] == "awesome-value"
-
-
-def test_eq():
- class B(ConfigSchema):
- a: _TestInt
- field: str
-
- b1 = B({"a": {"v": 6}, "field": "val"})
- b2 = B({"a": {"v": 6}, "field": "val"})
- b_diff = B({"a": {"v": 7}, "field": "val"})
-
- assert b1 == b2
- assert b2 != b_diff
- assert b1 != b_diff
- assert b_diff == b_diff
-
-
-def test_docstring_parsing_valid():
- class NormalDescription(ConfigSchema):
- """
- Does nothing special
- Really
- """
-
- desc = NormalDescription.json_schema()
- assert desc["description"] == "Does nothing special\nReally"
-
- class FieldsDescription(ConfigSchema):
- """
- This is an awesome test class
- ---
- field: This field does nothing interesting
- value: Neither does this
- """
-
- field: str
- value: int
-
- schema = FieldsDescription.json_schema()
- assert schema["description"] == "This is an awesome test class"
- assert schema["properties"]["field"]["description"] == "This field does nothing interesting"
- assert schema["properties"]["value"]["description"] == "Neither does this"
-
- class NoDescription(ConfigSchema):
- nothing: str
-
- _ = NoDescription.json_schema()
-
-
-def test_docstring_parsing_invalid():
- class AdditionalItem(ConfigSchema):
- """
- This class is wrong
- ---
- field: nope
- nothing: really nothing
- """
-
- nothing: str
-
- with raises(DataDescriptionError):
- _ = AdditionalItem.json_schema()
-
- class WrongDescription(ConfigSchema):
- """
- This class is wrong
- ---
- other: description
- """
-
- nothing: str
-
- with raises(DataDescriptionError):
- _ = WrongDescription.json_schema()
+++ /dev/null
-from knot_resolver.utils.etag import structural_etag
-
-
-def test_etag():
- empty1 = {}
- empty2 = {}
-
- assert structural_etag(empty1) == structural_etag(empty2)
-
- something1 = {"something": 1}
- something2 = {"something": 2}
- assert structural_etag(empty1) != structural_etag(something1)
- assert structural_etag(something1) != structural_etag(something2)
+++ /dev/null
-from pytest import raises
-
-from knot_resolver.utils.modeling.json_pointer import json_ptr_resolve
-
-# example adopted from https://www.sitepoint.com/json-server-example/
-TEST = {
- "clients": [
- {
- "id": "59761c23b30d971669fb42ff",
- "isActive": True,
- "age": 36,
- "name": "Dunlap Hubbard",
- "gender": "male",
- "company": "CEDWARD",
- "email": "dunlaphubbard@cedward.com",
- "phone": "+1 (890) 543-2508",
- "address": "169 Rutledge Street, Konterra, Northern Mariana Islands, 8551",
- },
- {
- "id": "59761c233d8d0f92a6b0570d",
- "isActive": True,
- "age": 24,
- "name": "Kirsten Sellers",
- "gender": "female",
- "company": "EMERGENT",
- "email": "kirstensellers@emergent.com",
- "phone": "+1 (831) 564-2190",
- "address": "886 Gallatin Place, Fannett, Arkansas, 4656",
- },
- {
- "id": "59761c23fcb6254b1a06dad5",
- "isActive": True,
- "age": 30,
- "name": "Acosta Robbins",
- "gender": "male",
- "company": "ORGANICA",
- "email": "acostarobbins@organica.com",
- "phone": "+1 (882) 441-3367",
- "address": "697 Linden Boulevard, Sattley, Idaho, 1035",
- },
- ]
-}
-
-
-def test_json_ptr():
- parent, res, token = json_ptr_resolve(TEST, "")
- assert parent is None
- assert res is TEST
-
- parent, res, token = json_ptr_resolve(TEST, "/")
- assert parent is TEST
- assert res is None
- assert token == ""
-
- parent, res, token = json_ptr_resolve(TEST, "/clients/2/gender")
- assert parent is TEST["clients"][2]
- assert res == "male"
- assert token == "gender"
-
- with raises(ValueError):
- _ = json_ptr_resolve(TEST, "//")
-
- with raises(SyntaxError):
- _ = json_ptr_resolve(TEST, "invalid/ptr")
-
- with raises(ValueError):
- _ = json_ptr_resolve(TEST, "/clients/2/gender/invalid")
-
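- # RFC 6901 escaping: "~0" decodes to "~", so the pointer "/~01" resolves to the key "~1"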
- parent, res, token = json_ptr_resolve(TEST, "/~01")
- assert parent is TEST
- assert res is None
- assert token == "~1"
+++ /dev/null
-import copy
-from typing import Any, Dict
-
-import pytest
-from pytest import raises
-
-from knot_resolver.utils.modeling.exceptions import DataValidationError
-from knot_resolver.utils.modeling.parsing import data_combine
-
-# default data
-data_default = {"key1": {"inner11": False}}
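- # data_combine() merges nested mappings; redefining an already existing key should raise DataValidationError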
-
-
-@pytest.mark.parametrize(
- "val,res",
- [
- ({"key2": "value"}, {"key1": {"inner11": False}, "key2": "value"}),
- ({"key2": {"inner21": True}}, {"key1": {"inner11": False}, "key2": {"inner21": True}}),
- ({"key1": {"inner12": 5}}, {"key1": {"inner11": False, "inner12": 5}}),
- ],
-)
-def test_data_combine_valid(val: Dict[Any, Any], res: Dict[Any, Any]) -> None:
- data = copy.deepcopy(data_default)
- assert data_combine(data, val) == res
-
-
-@pytest.mark.parametrize("val", [{"key1": "value"}, {"key1": {"inner11": False}}])
-def test_data_combine_invalid(val: Dict[Any, Any]) -> None:
- data = copy.deepcopy(data_default)
- with raises(DataValidationError):
- data_combine(data, val)
+++ /dev/null
-from pytest import raises
-
-from knot_resolver.utils.modeling.query import query
-
-
-def test_example_from_spec():
- # source of the example: https://jsonpatch.com/
- original = {"baz": "qux", "foo": "bar"}
- patch = [
- {"op": "replace", "path": "/baz", "value": "boo"},
- {"op": "add", "path": "/hello", "value": ["world"]},
- {"op": "remove", "path": "/foo"},
- ]
- expected = {"baz": "boo", "hello": ["world"]}
-
- result, _ = query(original, "patch", "", patch)
-
- assert result == expected
+++ /dev/null
-from knot_resolver.utils.modeling.renaming import renamed
-
-
-def test_all():
- ref = {
- "awesome-customers": [{"name": "John", "home-address": "London"}, {"name": "Bob", "home-address": "Prague"}],
- "storage": {"bobby-pin": 5, "can-opener": 0, "laptop": 1},
- }
-
- rnm = renamed(ref)
- assert rnm["awesome_customers"][0]["home_address"] == "London"
- assert rnm["awesome_customers"][1:][0]["home_address"] == "Prague"
- assert set(rnm["storage"].items()) == set((("can_opener", 0), ("bobby_pin", 5), ("laptop", 1)))
- assert set(rnm["storage"].keys()) == set(("bobby_pin", "can_opener", "laptop"))
-
-
-def test_nested_init():
- val = renamed(renamed({"ke-y": "val-ue"}))
- assert val["ke_y"] == "val-ue"
-
-
-def test_original():
- obj = renamed({"ke-y": "val-ue"}).original()
- assert obj["ke-y"] == "val-ue"
+++ /dev/null
-from typing import Any, Dict, List, Literal, Tuple, Union
-
-import pytest
-
-from knot_resolver.utils.modeling import BaseSchema
-from knot_resolver.utils.modeling.types import is_list, is_literal
-
-types = [
- bool,
- int,
- str,
- Dict[Any, Any],
- Tuple[Any, Any],
- Union[str, int],
- BaseSchema,
-]
-literal_types = [Literal[5], Literal["test"], Literal[False]]
-
-
-@pytest.mark.parametrize("val", types)
-def test_is_list_true(val: Any):
- assert is_list(List[val])
-
-
-@pytest.mark.parametrize("val", types)
-def test_is_list_false(val: Any):
- assert not is_list(val)
-
-
-@pytest.mark.parametrize("val", literal_types)
-def test_is_literal_true(val: Any):
- assert is_literal(Literal[val])
-
-
-@pytest.mark.parametrize("val", types)
-def test_is_literal_false(val: Any):
- assert not is_literal(val)
+++ /dev/null
-from knot_resolver.utils.functional import all_matches, contains_element_matching, foldl
-
-
-def test_foldl():
- lst = list(range(10))
-
- assert foldl(lambda x, y: x + y, 0, lst) == sum(range(10))
- assert foldl(lambda x, y: x + y, 55, lst) == sum(range(10)) + 55
-
-
-def test_contains_element_matching():
- lst = list(range(10))
-
- assert contains_element_matching(lambda e: e == 5, lst)
- assert not contains_element_matching(lambda e: e == 11, lst)
-
-
-def test_matches_all():
- lst = list(range(10))
-
- assert all_matches(lambda x: x >= 0, lst)
- assert not all_matches(lambda x: x % 2 == 0, lst)