]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: introduce support function to split connection attempts
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Oct 2023 16:28:17 +0000 (18:28 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 23:04:55 +0000 (00:04 +0100)
Refactor the `resolve_hostaddr_async()` function to make use of such
facilities.

psycopg/psycopg/conninfo.py
tests/test_conninfo.py

index 38d1f7dabc43d16a43e18b990156ca6b22f60c4c..efbb2be315e99a59ebbafb965192d92a19876bd2 100644 (file)
@@ -4,21 +4,27 @@ Functions to manipulate conninfo strings
 
 # Copyright (C) 2020 The Psycopg Team
 
+from __future__ import annotations
+
 import os
 import re
 import socket
 import asyncio
-from typing import Any, Dict, List, Optional
+from typing import Any, Iterator, AsyncIterator
 from pathlib import Path
 from datetime import tzinfo
 from functools import lru_cache
 from ipaddress import ip_address
+from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from ._tz import get_tzinfo
+from ._compat import cache
 from ._encodings import pgconn_encoding
 
+ConnDict: TypeAlias = "dict[str, Any]"
+
 
 def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     """
@@ -61,7 +67,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     return conninfo
 
 
-def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
     """
     Convert the `!conninfo` string into a dictionary of parameters.
 
@@ -84,7 +90,7 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
     return rv
 
 
-def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
+def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
     """
     Verify that `!conninfo` is a valid connection string.
 
@@ -167,7 +173,7 @@ class ConnectionInfo:
         """
         return self._get_pgconn_attr("options")
 
-    def get_parameters(self) -> Dict[str, str]:
+    def get_parameters(self) -> dict[str, str]:
         """Return the connection parameters values.
 
         Return all the parameters set to a non-default value, which might come
@@ -228,7 +234,7 @@ class ConnectionInfo:
         """
         return pq.PipelineStatus(self.pgconn.pipeline_status)
 
-    def parameter_status(self, param_name: str) -> Optional[str]:
+    def parameter_status(self, param_name: str) -> str | None:
         """
         Return a parameter setting of the connection.
 
@@ -275,7 +281,7 @@ class ConnectionInfo:
         return value.decode(self.encoding)
 
 
-async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+async def resolve_hostaddr_async(params: ConnDict) -> ConnDict:
     """
     Perform async DNS lookup of the hosts and return a new params dict.
 
@@ -292,82 +298,175 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
     host resolve, inconsistent lists length).
     """
-    hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
-    if hostaddr_arg:
-        # Already resolved
-        return params
-
-    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
-    if not host_arg:
-        # Nothing to resolve
-        return params
-
-    hosts_in = host_arg.split(",")
-    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
-    ports_in = port_arg.split(",") if port_arg else []
-    default_port = "5432"
-
-    if len(ports_in) == 1:
-        # If only one port is specified, the libpq will apply it to all
-        # the hosts, so don't mangle it.
-        default_port = ports_in.pop()
-
-    elif len(ports_in) > 1:
-        if len(ports_in) != len(hosts_in):
-            # ProgrammingError would have been more appropriate, but this is
-            # what the raise if the libpq fails connect in the same case.
-            raise e.OperationalError(
-                f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
-            )
-        ports_out = []
-
-    hosts_out = []
-    hostaddr_out = []
-    loop = asyncio.get_running_loop()
-    for i, host in enumerate(hosts_in):
-        if not host or host.startswith("/") or host[1:2] == ":":
-            # Local path
-            hosts_out.append(host)
-            hostaddr_out.append("")
-            if ports_in:
-                ports_out.append(ports_in[i])
-            continue
-
-        # If the host is already an ip address don't try to resolve it
-        if is_ip_address(host):
-            hosts_out.append(host)
-            hostaddr_out.append(host)
-            if ports_in:
-                ports_out.append(ports_in[i])
-            continue
+    hosts: list[str] = []
+    hostaddrs: list[str] = []
+    ports: list[str] = []
 
+    for attempt in _split_attempts(_inject_defaults(params)):
         try:
-            port = ports_in[i] if ports_in else default_port
-            ans = await loop.getaddrinfo(
-                host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
-            )
+            async for a2 in _split_attempts_and_resolve(attempt):
+                hosts.append(a2["host"])
+                hostaddrs.append(a2["hostaddr"])
+                if "port" in params:
+                    ports.append(a2["port"])
         except OSError as ex:
             last_exc = ex
-        else:
-            for item in ans:
-                hosts_out.append(host)
-                hostaddr_out.append(item[4][0])
-                if ports_in:
-                    ports_out.append(ports_in[i])
-
-    # Throw an exception if no host could be resolved
-    if not hosts_out:
+
+    if params.get("host") and not hosts:
+        # We couldn't resolve anything
         raise e.OperationalError(str(last_exc))
 
     out = params.copy()
-    out["host"] = ",".join(hosts_out)
-    out["hostaddr"] = ",".join(hostaddr_out)
-    if ports_in:
-        out["port"] = ",".join(ports_out)
+    shosts = ",".join(hosts)
+    if shosts:
+        out["host"] = shosts
+    shostaddrs = ",".join(hostaddrs)
+    if shostaddrs:
+        out["hostaddr"] = shostaddrs
+    sports = ",".join(ports)
+    if ports:
+        out["port"] = sports
 
     return out
 
 
+def _inject_defaults(params: ConnDict) -> ConnDict:
+    """
+    Add defaults to a dictionary of parameters.
+
+    This avoids the need to look up for env vars at various stages during
+    processing.
+
+    Note that a port is always specified. 5432 likely comes from here.
+
+    The `host`, `hostaddr`, `port` will be always set to a string.
+    """
+    defaults = _conn_defaults()
+    out = params.copy()
+
+    def inject(name: str, envvar: str) -> None:
+        value = out.get(name)
+        if not value:
+            out[name] = os.environ.get(envvar, defaults[name])
+        else:
+            out[name] = str(value)
+
+    inject("host", "PGHOST")
+    inject("hostaddr", "PGHOSTADDR")
+    inject("port", "PGPORT")
+
+    return out
+
+
+def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
+    """
+    Split connection parameters with a sequence of hosts into separate attempts.
+
+    Assume that `host`, `hostaddr`, `port` are always present and a string (as
+    emitted from `_inject_defaults()`).
+    """
+
+    def split_val(key: str) -> list[str]:
+        # Assume all keys are present and strings.
+        val: str = params[key]
+        return val.split(",") if val else []
+
+    hosts = split_val("host")
+    hostaddrs = split_val("hostaddr")
+    ports = split_val("port")
+
+    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
+        raise e.OperationalError(
+            f"could not match {len(hosts)} host names"
+            f" with {len(hostaddrs)} hostaddr values"
+        )
+
+    nhosts = max(len(hosts), len(hostaddrs))
+
+    if 1 < len(ports) != nhosts:
+        raise e.OperationalError(
+            f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
+        )
+    elif len(ports) == 1:
+        ports *= nhosts
+
+    # A single attempt to make
+    if nhosts <= 1:
+        yield params
+        return
+
+    # Now all lists are either empty or have the same length
+    for i in range(nhosts):
+        attempt = params.copy()
+        if hosts:
+            attempt["host"] = hosts[i]
+        if hostaddrs:
+            attempt["hostaddr"] = hostaddrs[i]
+        if ports:
+            attempt["port"] = ports[i]
+        yield attempt
+
+
+async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
+    """
+    Perform async DNS lookup of the hosts and return a new params dict.
+
+    :param params: The input parameters, for instance as returned by
+        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+        a single entry for host, hostaddr, port and doesn't check for env vars
+        because it is designed to further process the input of _split_attempts()
+
+    If a ``host`` param is present but not ``hostname``, resolve the host
+    addresses dynamically.
+
+    The function may change the input ``host``, ``hostname``, ``port`` to allow
+    connecting without further DNS lookups.
+
+    Raise `~psycopg.OperationalError` if resolution fails.
+    """
+    host = params["host"]
+    if not host or host.startswith("/") or host[1:2] == ":":
+        # Local path, or no host to resolve
+        yield params
+        return
+
+    hostaddr = params["hostaddr"]
+    if hostaddr:
+        # Already resolved
+        yield params
+        return
+
+    if is_ip_address(host):
+        # If the host is already an ip address don't try to resolve it
+        params["hostaddr"] = host
+        yield params
+        return
+
+    loop = asyncio.get_running_loop()
+
+    port = params["port"]
+    ans = await loop.getaddrinfo(
+        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+    )
+
+    attempt = params.copy()
+    for item in ans:
+        attempt["hostaddr"] = item[4][0]
+    yield attempt
+
+
+@cache
+def _conn_defaults() -> dict[str, str]:
+    """
+    Return a dictionary of defaults for connection strings parameters.
+    """
+    defs = pq.Conninfo.get_defaults()
+    return {
+        d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
+        for d in defs
+    }
+
+
 @lru_cache()
 def is_ip_address(s: str) -> bool:
     """Return True if the string represent a valid ip address."""
index 56a944ff17b754160fd3eb383f2ce53c29144a53..e037b0539343a9119710d81f76b6e4a84d5b90e4 100644 (file)
@@ -333,17 +333,17 @@ class TestConnectionInfo:
         ),
         (
             "host=1.1.1.1,2.2.2.2 port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2",
             None,
         ),
         (
             "port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2",
             {"PGHOST": "1.1.1.1,2.2.2.2"},
         ),
         (
             "host=foo.com port=5432",
-            "host=foo.com port=5432",
+            "host=foo.com port=5432 hostaddr=1.2.3.4",
             {"PGHOSTADDR": "1.2.3.4"},
         ),
     ],
@@ -368,7 +368,7 @@ async def test_resolve_hostaddr_async_no_resolve(
         ),
         (
             "host=foo.com,qux.com port=5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433,5433",
             None,
         ),
         (