From: Daniele Varrazzo Date: Thu, 26 Oct 2023 16:28:17 +0000 (+0200) Subject: refactor: introduce support function to split connection attempts X-Git-Tag: 3.1.13~2^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4598a296fbdf2456544056bdf605b6fe7c558e1a;p=thirdparty%2Fpsycopg.git refactor: introduce support function to split connection attempts Refactor the `resolve_hostaddr_async()` function to make use of such facilities. --- diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 38d1f7dab..efbb2be31 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -4,21 +4,27 @@ Functions to manipulate conninfo strings # Copyright (C) 2020 The Psycopg Team +from __future__ import annotations + import os import re import socket import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Iterator, AsyncIterator from pathlib import Path from datetime import tzinfo from functools import lru_cache from ipaddress import ip_address +from typing_extensions import TypeAlias from . import pq from . import errors as e from ._tz import get_tzinfo +from ._compat import cache from ._encodings import pgconn_encoding +ConnDict: TypeAlias = "dict[str, Any]" + def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: """ @@ -61,7 +67,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: return conninfo -def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]: +def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict: """ Convert the `!conninfo` string into a dictionary of parameters. @@ -84,7 +90,7 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]: return rv -def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]: +def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]: """ Verify that `!conninfo` is a valid connection string. 
@@ -167,7 +173,7 @@ class ConnectionInfo: """ return self._get_pgconn_attr("options") - def get_parameters(self) -> Dict[str, str]: + def get_parameters(self) -> dict[str, str]: """Return the connection parameters values. Return all the parameters set to a non-default value, which might come @@ -228,7 +234,7 @@ class ConnectionInfo: """ return pq.PipelineStatus(self.pgconn.pipeline_status) - def parameter_status(self, param_name: str) -> Optional[str]: + def parameter_status(self, param_name: str) -> str | None: """ Return a parameter setting of the connection. @@ -275,7 +281,7 @@ class ConnectionInfo: return value.decode(self.encoding) -async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: +async def resolve_hostaddr_async(params: ConnDict) -> ConnDict: """ Perform async DNS lookup of the hosts and return a new params dict. @@ -292,82 +298,175 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: Raise `~psycopg.OperationalError` if connection is not possible (e.g. no host resolve, inconsistent lists length). """ - hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) - if hostaddr_arg: - # Already resolved - return params - - host_arg: str = params.get("host", os.environ.get("PGHOST", "")) - if not host_arg: - # Nothing to resolve - return params - - hosts_in = host_arg.split(",") - port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) - ports_in = port_arg.split(",") if port_arg else [] - default_port = "5432" - - if len(ports_in) == 1: - # If only one port is specified, the libpq will apply it to all - # the hosts, so don't mangle it. - default_port = ports_in.pop() - - elif len(ports_in) > 1: - if len(ports_in) != len(hosts_in): - # ProgrammingError would have been more appropriate, but this is - # what the raise if the libpq fails connect in the same case. 
- raise e.OperationalError( - f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" - ) - ports_out = [] - - hosts_out = [] - hostaddr_out = [] - loop = asyncio.get_running_loop() - for i, host in enumerate(hosts_in): - if not host or host.startswith("/") or host[1:2] == ":": - # Local path - hosts_out.append(host) - hostaddr_out.append("") - if ports_in: - ports_out.append(ports_in[i]) - continue - - # If the host is already an ip address don't try to resolve it - if is_ip_address(host): - hosts_out.append(host) - hostaddr_out.append(host) - if ports_in: - ports_out.append(ports_in[i]) - continue + hosts: list[str] = [] + hostaddrs: list[str] = [] + ports: list[str] = [] + for attempt in _split_attempts(_inject_defaults(params)): try: - port = ports_in[i] if ports_in else default_port - ans = await loop.getaddrinfo( - host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM - ) + async for a2 in _split_attempts_and_resolve(attempt): + hosts.append(a2["host"]) + hostaddrs.append(a2["hostaddr"]) + if "port" in params: + ports.append(a2["port"]) except OSError as ex: last_exc = ex - else: - for item in ans: - hosts_out.append(host) - hostaddr_out.append(item[4][0]) - if ports_in: - ports_out.append(ports_in[i]) - - # Throw an exception if no host could be resolved - if not hosts_out: + + if params.get("host") and not hosts: + # We couldn't resolve anything raise e.OperationalError(str(last_exc)) out = params.copy() - out["host"] = ",".join(hosts_out) - out["hostaddr"] = ",".join(hostaddr_out) - if ports_in: - out["port"] = ",".join(ports_out) + shosts = ",".join(hosts) + if shosts: + out["host"] = shosts + shostaddrs = ",".join(hostaddrs) + if shostaddrs: + out["hostaddr"] = shostaddrs + sports = ",".join(ports) + if ports: + out["port"] = sports return out +def _inject_defaults(params: ConnDict) -> ConnDict: + """ + Add defaults to a dictionary of parameters. 
+ + This avoids the need to look up for env vars at various stages during + processing. + + Note that a port is always specified. 5432 likely comes from here. + + The `host`, `hostaddr`, `port` will be always set to a string. + """ + defaults = _conn_defaults() + out = params.copy() + + def inject(name: str, envvar: str) -> None: + value = out.get(name) + if not value: + out[name] = os.environ.get(envvar, defaults[name]) + else: + out[name] = str(value) + + inject("host", "PGHOST") + inject("hostaddr", "PGHOSTADDR") + inject("port", "PGPORT") + + return out + + +def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: + """ + Split connection parameters with a sequence of hosts into separate attempts. + + Assume that `host`, `hostaddr`, `port` are always present and a string (as + emitted from `_inject_defaults()`). + """ + + def split_val(key: str) -> list[str]: + # Assume all keys are present and strings. + val: str = params[key] + return val.split(",") if val else [] + + hosts = split_val("host") + hostaddrs = split_val("hostaddr") + ports = split_val("port") + + if hosts and hostaddrs and len(hosts) != len(hostaddrs): + raise e.OperationalError( + f"could not match {len(hosts)} host names" + f" with {len(hostaddrs)} hostaddr values" + ) + + nhosts = max(len(hosts), len(hostaddrs)) + + if 1 < len(ports) != nhosts: + raise e.OperationalError( + f"could not match {len(ports)} port numbers to {len(hosts)} hosts" + ) + elif len(ports) == 1: + ports *= nhosts + + # A single attempt to make + if nhosts <= 1: + yield params + return + + # Now all lists are either empty or have the same length + for i in range(nhosts): + attempt = params.copy() + if hosts: + attempt["host"] = hosts[i] + if hostaddrs: + attempt["hostaddr"] = hostaddrs[i] + if ports: + attempt["port"] = ports[i] + yield attempt + + +async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]: + """ + Perform async DNS lookup of the hosts and return a new params dict. 
+ + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most + a single entry for host, hostaddr, port and doesn't check for env vars + because it is designed to further process the input of _split_attempts() + + If a ``host`` param is present but not ``hostaddr``, resolve the host + addresses dynamically. + + The function may change the input ``host``, ``hostaddr``, ``port`` to allow + connecting without further DNS lookups. + + Raise `~psycopg.OperationalError` if resolution fails. + """ + host = params["host"] + if not host or host.startswith("/") or host[1:2] == ":": + # Local path, or no host to resolve + yield params + return + + hostaddr = params["hostaddr"] + if hostaddr: + # Already resolved + yield params + return + + if is_ip_address(host): + # If the host is already an ip address don't try to resolve it + params["hostaddr"] = host + yield params + return + + loop = asyncio.get_running_loop() + + port = params["port"] + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + + attempt = params.copy() + for item in ans: + attempt["hostaddr"] = item[4][0] + yield attempt + + +@cache +def _conn_defaults() -> dict[str, str]: + """ + Return a dictionary of defaults for connection string parameters. 
+ """ + defs = pq.Conninfo.get_defaults() + return { + d.keyword.decode(): d.compiled.decode() if d.compiled is not None else "" + for d in defs + } + + @lru_cache() def is_ip_address(s: str) -> bool: """Return True if the string represent a valid ip address.""" diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 56a944ff1..e037b0539 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -333,17 +333,17 @@ class TestConnectionInfo: ), ( "host=1.1.1.1,2.2.2.2 port=5432", - "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2", None, ), ( "port=5432", - "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2", {"PGHOST": "1.1.1.1,2.2.2.2"}, ), ( "host=foo.com port=5432", - "host=foo.com port=5432", + "host=foo.com port=5432 hostaddr=1.2.3.4", {"PGHOSTADDR": "1.2.3.4"}, ), ], @@ -368,7 +368,7 @@ async def test_resolve_hostaddr_async_no_resolve( ), ( "host=foo.com,qux.com port=5433", - "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433,5433", None, ), (