]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
kresctl: file-access-less URI detection for sockets and nicer messages
authorOto Šťáva <oto.stava@nic.cz>
Tue, 19 Sep 2023 14:43:18 +0000 (16:43 +0200)
committerOto Šťáva <oto.stava@nic.cz>
Mon, 25 Sep 2023 10:58:17 +0000 (12:58 +0200)
manager/knot_resolver_manager/cli/cmd/config.py
manager/knot_resolver_manager/cli/cmd/metrics.py
manager/knot_resolver_manager/cli/cmd/reload.py
manager/knot_resolver_manager/cli/cmd/schema.py
manager/knot_resolver_manager/cli/cmd/stop.py
manager/knot_resolver_manager/cli/command.py
manager/knot_resolver_manager/utils/requests.py

index a788f38d3863f5033895d5909981e828b6b4a386..8a1f4d4cfe443eb6bad2cc2c6b963a2095838706 100644 (file)
@@ -227,7 +227,7 @@ class ConfigCommand(Command):
             sys.exit()
 
         new_config = None
-        url = f"{args.socket}/v1/config{self.path}"
+        path = f"v1/config{self.path}"
         method = operation_to_method(self.operation)
 
         if self.operation == Operations.SET:
@@ -241,7 +241,7 @@ class ConfigCommand(Command):
                 # use STDIN also when file is not specified
                 new_config = input("Type new configuration: ")
 
-        response = request(method, url, json_dump(new_config) if new_config else None)
+        response = request(args.socket, method, path, json_dump(new_config) if new_config else None)
 
         if response.status != 200:
             print(response, file=sys.stderr)
index c924ef5e72bea1df98c862cc3d759eba335faca6..64237ce4989dbbcfe4b0579403fd02e0e7d78166 100644 (file)
@@ -31,8 +31,7 @@ class MetricsCommand(Command):
         return {}
 
     def run(self, args: CommandArgs) -> None:
-        url = f"{args.socket}/metrics"
-        response = request("GET", url)
+        response = request(args.socket, "GET", "metrics")
 
         if response.status == 200:
             if self.file:
index e0d288962c46e534a9c8623cb70569b537fad80c..89782f4ee19168b9beda69feae3425a95da19c07 100644 (file)
@@ -29,7 +29,7 @@ class ReloadCommand(Command):
         return {}
 
     def run(self, args: CommandArgs) -> None:
-        response = request("POST", f"{args.socket}/reload")
+        response = request(args.socket, "POST", "reload")
 
         if response.status != 200:
             print(response, file=sys.stderr)
index 790da7fdbb975ea0d7222c8085527005cc3531af..253699466b4eabc12be94f3ef3d156b980ae1175 100644 (file)
@@ -40,7 +40,7 @@ class SchemaCommand(Command):
 
     def run(self, args: CommandArgs) -> None:
         if self.live:
-            response = request("GET", f"{args.socket}/schema")
+            response = request(args.socket, "GET", "schema")
             if response.status != 200:
                 print(response, file=sys.stderr)
                 sys.exit(1)
index f3539def8947951002dd5ea94eefbfb07b7cc7df..a3f463542cfe73fdf344775339aad12e75787d75 100644 (file)
@@ -21,7 +21,7 @@ class StopCommand(Command):
         return stop, StopCommand
 
     def run(self, args: CommandArgs) -> None:
-        response = request("POST", f"{args.socket}/stop")
+        response = request(args.socket, "POST", "stop")
 
         if response.status != 200:
             print(response, file=sys.stderr)
index 0b10a2f1d8660e2c4dac4b84baf122c354926373..f8917417594336bcd2127b20174d3042b62912e4 100644 (file)
@@ -7,9 +7,10 @@ from urllib.parse import quote
 
 from knot_resolver_manager.constants import API_SOCK_ENV_VAR, CONFIG_FILE_ENV_VAR, DEFAULT_MANAGER_CONFIG_FILE
 from knot_resolver_manager.datamodel.config_schema import DEFAULT_MANAGER_API_SOCK
-from knot_resolver_manager.datamodel.types import FilePath, IPAddressPort
+from knot_resolver_manager.datamodel.types import IPAddressPort
 from knot_resolver_manager.utils.modeling import parsing
 from knot_resolver_manager.utils.modeling.exceptions import DataValidationError
+from knot_resolver_manager.utils.requests import SocketDesc
 
 T = TypeVar("T", bound=Type["Command"])
 
@@ -37,18 +38,24 @@ def install_commands_parsers(parser: argparse.ArgumentParser) -> None:
         subparser.set_defaults(command=typ, subparser=subparser)
 
 
-def get_socket_from_config(config: Path, optional_file: bool) -> Optional[str]:
+def get_socket_from_config(config: Path, optional_file: bool) -> Optional[SocketDesc]:
     try:
         with open(config, "r") as f:
             data = parsing.try_to_parse(f.read())
         mkey = "management"
         if mkey in data:
             management = data[mkey]
-            if "unix_socket" in management:
-                return str(FilePath(management["unix_socket"], object_path=f"/{mkey}/unix-socket"))
+            if "unix-socket" in management:
+                return SocketDesc(
+                    f'http+unix://{quote(management["unix-socket"], safe="")}/',
+                    f'Key "/management/unix-socket" in "{config}" file',
+                )
             elif "interface" in management:
                 ip = IPAddressPort(management["interface"], object_path=f"/{mkey}/interface")
-                return f"http://{ip.addr}:{ip.port}"
+                return SocketDesc(
+                    f"http://{ip.addr}:{ip.port}",
+                    f'Key "/management/interface" in "{config}" file',
+                )
         return None
     except ValueError as e:
         raise DataValidationError(*e.args)
@@ -58,15 +65,15 @@ def get_socket_from_config(config: Path, optional_file: bool) -> Optional[str]:
         return None
 
 
-def determine_socket(namespace: argparse.Namespace) -> str:
+def determine_socket(namespace: argparse.Namespace) -> SocketDesc:
     # 1) socket from 'kresctl --socket' argument
     if len(namespace.socket) > 0:
-        return namespace.socket[0]
+        return SocketDesc(namespace.socket[0], "--socket argument")
 
     config_path = os.getenv(CONFIG_FILE_ENV_VAR)
     socket_env = os.getenv(API_SOCK_ENV_VAR)
 
-    socket: Optional[str] = None
+    socket: Optional[SocketDesc] = None
     # 2) socket from config file ('kresctl --config' argument)
     if len(namespace.config) > 0:
         socket = get_socket_from_config(namespace.config[0], False)
@@ -75,7 +82,7 @@ def determine_socket(namespace: argparse.Namespace) -> str:
         socket = get_socket_from_config(Path(config_path), False)
     # 4) socket from environment variable
     elif socket_env:
-        socket = socket_env
+        socket = SocketDesc(socket_env, f'Environment variable "{API_SOCK_ENV_VAR}"')
     # 5) socket from config file (default config file constant)
     else:
         socket = get_socket_from_config(DEFAULT_MANAGER_CONFIG_FILE, True)
@@ -83,7 +90,7 @@ def determine_socket(namespace: argparse.Namespace) -> str:
     if socket:
         return socket
     # 6) socket default
-    return DEFAULT_MANAGER_API_SOCK
+    return SocketDesc(DEFAULT_MANAGER_API_SOCK, f'Default value "{DEFAULT_MANAGER_API_SOCK}"')
 
 
 class CommandArgs:
@@ -93,12 +100,7 @@ class CommandArgs:
         self.subparser: argparse.ArgumentParser = namespace.subparser
         self.command: Type["Command"] = namespace.command
 
-        self.socket: str = determine_socket(namespace)
-
-        if Path(self.socket).exists():
-            self.socket = f'http+unix://{quote(self.socket, safe="")}/'
-        if self.socket.endswith("/"):
-            self.socket = self.socket[:-1]
+        self.socket: SocketDesc = determine_socket(namespace)
 
 
 class Command(ABC):
index db29f85e024638e4f7aa378657a401fd62e406f2..edf2fef119c199a60a211b87a09eb7f64d1faa24 100644 (file)
@@ -3,11 +3,26 @@ import sys
 from http.client import HTTPConnection
 from typing import Any, Optional, Union
 from urllib.error import HTTPError, URLError
+from urllib.parse import quote
 from urllib.request import AbstractHTTPHandler, Request, build_opener, install_opener, urlopen
 
 from typing_extensions import Literal
 
 
+class SocketDesc:
+    def __init__(self, socket_def: str, source: str):
+        self.source = source
+        if ":" in socket_def:
+            # `socket_def` contains a schema, probably already URI-formatted, use directly
+            self.uri = socket_def
+        else:
+            # `socket_def` is probably a path, convert to URI
+            self.uri = f'http+unix://{quote(socket_def, safe="")}'
+
+        while self.uri.endswith("/"):
+            self.uri = self.uri[:-1]
+
+
 class Response:
     def __init__(self, status: int, body: str) -> None:
         self.status = status
@@ -17,12 +32,28 @@ class Response:
         return f"status: {self.status}\nbody:\n{self.body}"
 
 
+def _print_conn_error(error_desc: str, url: str, socket_source: str) -> None:
+    msg = f"""
+        {error_desc}
+        \tURL: {url}
+        \tSourced from: {socket_source}
+        Is the URL correct?
+        \tUnix socket would start with http+unix:// and URL encoded path.
+        \tInet sockets would start with http:// and domain or ip
+    """
+    print(msg, file=sys.stderr)
+
+
 def request(
+    socket_desc: SocketDesc,
     method: Literal["GET", "POST", "HEAD", "PUT", "DELETE"],
-    url: str,
+    path: str,
     body: Optional[str] = None,
     content_type: str = "application/json",
 ) -> Response:
+    while path.startswith("/"):
+        path = path[1:]
+    url = f"{socket_desc.uri}/{path}"
     req = Request(
         url,
         method=method,
@@ -38,14 +69,9 @@ def request(
         return Response(err.code, err.read().decode("utf8"))
     except URLError as err:
         if err.errno == 111 or isinstance(err.reason, ConnectionRefusedError):
-            msg = f"""
-                Connection refused.
-                \tURL: {url}
-                Is the URL correct?
-                \tUnix socket would start with http+unix:// and URL encoded path.
-                \tInet sockets would start with http:// and domain or ip
-            """
-            print(msg, file=sys.stderr)
+            _print_conn_error("Connection refused.", url, socket_desc.source)
+        elif err.errno == 2 or isinstance(err.reason, FileNotFoundError):
+            _print_conn_error("No such file or directory.", url, socket_desc.source)
         else:
             print(f"{err}: url={url}", file=sys.stderr)
         sys.exit(1)