# Copyright (C) 2020 The Psycopg Team
+from __future__ import annotations
+
import os
import re
import socket
import asyncio
-from typing import Any, Dict, List, Optional
+from typing import Any, Iterator, AsyncIterator
from pathlib import Path
from datetime import tzinfo
from functools import lru_cache
from ipaddress import ip_address
+from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from ._tz import get_tzinfo
+from ._compat import cache
from ._encodings import pgconn_encoding
+ConnDict: TypeAlias = "dict[str, Any]"
+
def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
"""
return conninfo
-def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
"""
Convert the `!conninfo` string into a dictionary of parameters.
return rv
-def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
+def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
"""
Verify that `!conninfo` is a valid connection string.
"""
return self._get_pgconn_attr("options")
- def get_parameters(self) -> Dict[str, str]:
+ def get_parameters(self) -> dict[str, str]:
"""Return the connection parameters values.
Return all the parameters set to a non-default value, which might come
"""
return pq.PipelineStatus(self.pgconn.pipeline_status)
- def parameter_status(self, param_name: str) -> Optional[str]:
+ def parameter_status(self, param_name: str) -> str | None:
"""
Return a parameter setting of the connection.
return value.decode(self.encoding)
-async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+async def resolve_hostaddr_async(params: ConnDict) -> ConnDict:
"""
Perform async DNS lookup of the hosts and return a new params dict.
Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
host resolve, inconsistent lists length).
"""
- hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
- if hostaddr_arg:
- # Already resolved
- return params
-
- host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
- if not host_arg:
- # Nothing to resolve
- return params
-
- hosts_in = host_arg.split(",")
- port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
- ports_in = port_arg.split(",") if port_arg else []
- default_port = "5432"
-
- if len(ports_in) == 1:
- # If only one port is specified, the libpq will apply it to all
- # the hosts, so don't mangle it.
- default_port = ports_in.pop()
-
- elif len(ports_in) > 1:
- if len(ports_in) != len(hosts_in):
- # ProgrammingError would have been more appropriate, but this is
- # what the raise if the libpq fails connect in the same case.
- raise e.OperationalError(
- f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
- )
- ports_out = []
-
- hosts_out = []
- hostaddr_out = []
- loop = asyncio.get_running_loop()
- for i, host in enumerate(hosts_in):
- if not host or host.startswith("/") or host[1:2] == ":":
- # Local path
- hosts_out.append(host)
- hostaddr_out.append("")
- if ports_in:
- ports_out.append(ports_in[i])
- continue
-
- # If the host is already an ip address don't try to resolve it
- if is_ip_address(host):
- hosts_out.append(host)
- hostaddr_out.append(host)
- if ports_in:
- ports_out.append(ports_in[i])
- continue
+ hosts: list[str] = []
+ hostaddrs: list[str] = []
+ ports: list[str] = []
+ for attempt in _split_attempts(_inject_defaults(params)):
try:
- port = ports_in[i] if ports_in else default_port
- ans = await loop.getaddrinfo(
- host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
- )
+ async for a2 in _split_attempts_and_resolve(attempt):
+ hosts.append(a2["host"])
+ hostaddrs.append(a2["hostaddr"])
+ if "port" in params:
+ ports.append(a2["port"])
except OSError as ex:
last_exc = ex
- else:
- for item in ans:
- hosts_out.append(host)
- hostaddr_out.append(item[4][0])
- if ports_in:
- ports_out.append(ports_in[i])
-
- # Throw an exception if no host could be resolved
- if not hosts_out:
+
+ if params.get("host") and not hosts:
+ # We couldn't resolve anything
raise e.OperationalError(str(last_exc))
out = params.copy()
- out["host"] = ",".join(hosts_out)
- out["hostaddr"] = ",".join(hostaddr_out)
- if ports_in:
- out["port"] = ",".join(ports_out)
+ shosts = ",".join(hosts)
+ if shosts:
+ out["host"] = shosts
+ shostaddrs = ",".join(hostaddrs)
+ if shostaddrs:
+ out["hostaddr"] = shostaddrs
+ sports = ",".join(ports)
+ if ports:
+ out["port"] = sports
return out
+def _inject_defaults(params: ConnDict) -> ConnDict:
+ """
+ Add defaults to a dictionary of parameters.
+
+ This avoids the need to look up for env vars at various stages during
+ processing.
+
+ Note that a port is always specified. 5432 likely comes from here.
+
+ The `host`, `hostaddr`, `port` will be always set to a string.
+ """
+ defaults = _conn_defaults()
+ out = params.copy()
+
+ def inject(name: str, envvar: str) -> None:
+ value = out.get(name)
+ if not value:
+ out[name] = os.environ.get(envvar, defaults[name])
+ else:
+ out[name] = str(value)
+
+ inject("host", "PGHOST")
+ inject("hostaddr", "PGHOSTADDR")
+ inject("port", "PGPORT")
+
+ return out
+
+
+def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
+ """
+ Split connection parameters with a sequence of hosts into separate attempts.
+
+ Assume that `host`, `hostaddr`, `port` are always present and a string (as
+ emitted from `_inject_defaults()`).
+ """
+
+ def split_val(key: str) -> list[str]:
+ # Assume all keys are present and strings.
+ val: str = params[key]
+ return val.split(",") if val else []
+
+ hosts = split_val("host")
+ hostaddrs = split_val("hostaddr")
+ ports = split_val("port")
+
+ if hosts and hostaddrs and len(hosts) != len(hostaddrs):
+ raise e.OperationalError(
+ f"could not match {len(hosts)} host names"
+ f" with {len(hostaddrs)} hostaddr values"
+ )
+
+ nhosts = max(len(hosts), len(hostaddrs))
+
+ if 1 < len(ports) != nhosts:
+ raise e.OperationalError(
+ f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
+ )
+ elif len(ports) == 1:
+ ports *= nhosts
+
+ # A single attempt to make
+ if nhosts <= 1:
+ yield params
+ return
+
+ # Now all lists are either empty or have the same length
+ for i in range(nhosts):
+ attempt = params.copy()
+ if hosts:
+ attempt["host"] = hosts[i]
+ if hostaddrs:
+ attempt["hostaddr"] = hostaddrs[i]
+ if ports:
+ attempt["port"] = ports[i]
+ yield attempt
+
+
+async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+ a single entry for host, hostaddr, port and doesn't check for env vars
+ because it is designed to further process the input of _split_attempts()
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses dynamically.
+
+ The function may change the input ``host``, ``hostname``, ``port`` to allow
+ connecting without further DNS lookups.
+
+ Raise `~psycopg.OperationalError` if resolution fails.
+ """
+ host = params["host"]
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path, or no host to resolve
+ yield params
+ return
+
+ hostaddr = params["hostaddr"]
+ if hostaddr:
+ # Already resolved
+ yield params
+ return
+
+ if is_ip_address(host):
+ # If the host is already an ip address don't try to resolve it
+ params["hostaddr"] = host
+ yield params
+ return
+
+ loop = asyncio.get_running_loop()
+
+ port = params["port"]
+ ans = await loop.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+
+ attempt = params.copy()
+ for item in ans:
+ attempt["hostaddr"] = item[4][0]
+ yield attempt
+
+
+@cache
+def _conn_defaults() -> dict[str, str]:
+ """
+ Return a dictionary of defaults for connection strings parameters.
+ """
+ defs = pq.Conninfo.get_defaults()
+ return {
+ d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
+ for d in defs
+ }
+
+
@lru_cache()
def is_ip_address(s: str) -> bool:
"""Return True if the string represent a valid ip address."""