]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
kresctl: improve default connection behaviour
authorOto Šťáva <oto.stava@nic.cz>
Fri, 15 Sep 2023 13:42:35 +0000 (15:42 +0200)
committerOto Šťáva <oto.stava@nic.cz>
Mon, 25 Sep 2023 10:58:17 +0000 (12:58 +0200)
It now searches `/etc/knot-resolver/config.yml` for `management`
configuration first, when no `--config` or `--socket` is specified.

doc/manager-client.rst
manager/knot_resolver_manager/cli/command.py

index 704c1667bc1a8c3e8ea2a3e232eeb04475e8e368..ddb1f695bb610780f412aae5bfea576f36d8e9d7 100644 (file)
@@ -30,9 +30,13 @@ the :option:`--config <-c <config>, --config <config>>` or
 :option:`--socket <-s <socket>, --socket <socket>>` option to tell
 ``kresctl`` where to look for the API.
 
-By default, ``kresctl`` connects to the ``/var/run/knot-resolver/manager.sock``
-Unix-domain socket, or, when specified, reads the ``KRES_MANAGER_CONFIG``
-environment variable to retrieve a path to a configuration file.
+If the ``management`` key is not present in the configuration file, ``kresctl``
+attempts to connect to the ``/var/run/knot-resolver/manager.sock`` Unix-domain
+socket, which is the Manager's default communication channel.
+
+By default, ``kresctl`` tries to find the correct communication channel in
+``/etc/knot-resolver/config.yaml``, or, if present, the file specified by the
+``KRES_MANAGER_CONFIG`` environment variable.
 
 .. option:: -s <socket>, --socket <socket>
 
index 2c532e8db03712c6b568c6d0e5c6c1c4b1ea7533..c8752d6737fb93fea7a7a28632f16374f4436b1f 100644 (file)
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Type, TypeVar
 from urllib.parse import quote
 
-from knot_resolver_manager.constants import CONFIG_FILE_ENV_VAR
+from knot_resolver_manager.constants import DEFAULT_MANAGER_CONFIG_FILE, CONFIG_FILE_ENV_VAR
 from knot_resolver_manager.utils.modeling import parsing
 
 T = TypeVar("T", bound=Type["Command"])
@@ -37,6 +37,49 @@ def install_commands_parsers(parser: argparse.ArgumentParser) -> None:
         subparser.set_defaults(command=typ, subparser=subparser)
 
 
+def get_socket_from_config(config: Path, optional_file: bool) -> Optional[str]:
+    try:
+        with open(config, "r") as f:
+            data = parsing.try_to_parse(f.read())
+        if "management" in data:
+            management = data["management"]
+            if "unix_socket" in management:
+                return management["unix_socket"]
+            elif "interface" in management:
+                split = management["interface"].split("@")
+                host = split[0]
+                port = split[1] if len(split) >= 2 else 80
+                return f"http://{host}:{port}"
+        return None
+    except OSError as e:
+        if not optional_file:
+            raise e
+        return None
+
+
+def determine_socket(namespace: argparse.Namespace) -> str:
+    if len(namespace.socket) > 0:
+        return namespace.socket[0]
+
+    socket: Optional[str] = None
+    if len(namespace.config) > 0:
+        socket = get_socket_from_config(namespace.config[0], False)
+        if socket is not None:
+            return socket
+    else:
+        config_env = os.getenv(CONFIG_FILE_ENV_VAR)
+        if config_env is not None:
+            socket = get_socket_from_config(Path(config_env), False)
+            if socket is not None:
+                return socket
+        else:
+            socket = get_socket_from_config(DEFAULT_MANAGER_CONFIG_FILE, True)
+            if socket is not None:
+                return socket
+
+    return DEFAULT_SOCKET
+
+
 class CommandArgs:
     def __init__(self, namespace: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
         self.namespace = namespace
@@ -44,25 +87,7 @@ class CommandArgs:
         self.subparser: argparse.ArgumentParser = namespace.subparser
         self.command: Type["Command"] = namespace.command
 
-        config_env = os.getenv(CONFIG_FILE_ENV_VAR)
-        if len(namespace.socket) == 0 and len(namespace.config) == 0 and config_env is not None:
-            namespace.config = [config_env]
-
-        self.socket: str = DEFAULT_SOCKET
-        if len(namespace.socket) > 0:
-            self.socket = namespace.socket[0]
-        elif len(namespace.config) > 0:
-            with open(namespace.config[0], "r") as f:
-                config = parsing.try_to_parse(f.read())
-            if "management" in config:
-                management = config["management"]
-                if "unix_socket" in management:
-                    self.socket = management["unix_socket"]
-                elif "interface" in management:
-                    split = management["interface"].split("@")
-                    host = split[0]
-                    port = split[1] if len(split) >= 2 else 80
-                    self.socket = f"http://{host}:{port}"
+        self.socket: str = determine_socket(namespace)
 
         if Path(self.socket).exists():
             self.socket = f'http+unix://{quote(self.socket, safe="")}/'