Psycopg 3.1.15 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
-- Fix async connection to hosts resolving to multiple IP addresses
- (:ticket:`#695`).
+- Fix use of ``service`` in connection string (regression in 3.1.13,
+ :ticket:`#694`).
+- Fix async connection to hosts resolving to multiple IP addresses (regression
+ in 3.1.13, :ticket:`#695`).
Current release
hostaddrs: list[str] = []
ports: list[str] = []
- for attempt in conninfo._split_attempts(conninfo._inject_defaults(params)):
+ for attempt in conninfo._split_attempts(params):
try:
async for a2 in conninfo._split_attempts_and_resolve(attempt):
- hosts.append(a2["host"])
- hostaddrs.append(a2["hostaddr"])
- if "port" in params:
- ports.append(a2["port"])
+ if a2.get("host") is not None:
+ hosts.append(a2["host"])
+ if a2.get("hostaddr") is not None:
+ hostaddrs.append(a2["hostaddr"])
+ if a2.get("port") is not None:
+ ports.append(str(a2["port"]))
except OSError as ex:
last_exc = ex
from datetime import tzinfo
from functools import lru_cache
from ipaddress import ip_address
+from dataclasses import dataclass
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from ._tz import get_tzinfo
-from ._compat import cache
from ._encodings import pgconn_encoding
ConnDict: TypeAlias = "dict[str, Any]"
def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
- """Split a set of connection params on the single attempts to perforn.
+ """Split a set of connection params on the single attempts to perform.
A connection param can perform more than one attempt more than one ``host``
is provided.
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
+ # TODO: we should actually resolve the hosts ourselves.
+ # If an host resolves to more than one ip, the libpq will make more than
+ # one attempt and wouldn't get to try the following ones, as before
+ # fixing #674.
if params.get("load_balance_hosts", "disable") == "random":
- attempts = list(_split_attempts(_inject_defaults(params)))
+ attempts = list(_split_attempts(params))
shuffle(attempts)
yield from attempts
else:
- yield from _split_attempts(_inject_defaults(params))
+ yield from _split_attempts(params)
async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
- """Split a set of connection params on the single attempts to perforn.
+ """Split a set of connection params on the single attempts to perform.
A connection param can perform more than one attempt more than one ``host``
is provided.
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
+ # TODO: the function should resolve all hosts and shuffle the results
+ # to replicate the same libpq algorithm.
yielded = False
last_exc = None
- for attempt in _split_attempts(_inject_defaults(params)):
+ for attempt in _split_attempts(params):
try:
async for a2 in _split_attempts_and_resolve(attempt):
yielded = True
raise e.OperationalError(str(last_exc))
-def _inject_defaults(params: ConnDict) -> ConnDict:
- """
- Add defaults to a dictionary of parameters.
-
- This avoids the need to look up for env vars at various stages during
- processing.
-
- Note that a port is always specified. 5432 likely comes from here.
-
- The `host`, `hostaddr`, `port` will be always set to a string.
- """
- defaults = _conn_defaults()
- out = params.copy()
-
- def inject(name: str, envvar: str) -> None:
- value = out.get(name)
- if not value:
- out[name] = os.environ.get(envvar, defaults[name])
- else:
- out[name] = str(value)
-
- inject("host", "PGHOST")
- inject("hostaddr", "PGHOSTADDR")
- inject("port", "PGPORT")
-
- return out
-
-
def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
"""
Split connection parameters with a sequence of hosts into separate attempts.
-
- Assume that `host`, `hostaddr`, `port` are always present and a string (as
- emitted from `_inject_defaults()`).
"""
def split_val(key: str) -> list[str]:
- # Assume all keys are present and strings.
- val: str = params[key]
+ val = _get_param(params, key)
return val.split(",") if val else []
hosts = split_val("host")
raise e.OperationalError(
f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
)
- elif len(ports) == 1:
- ports *= nhosts
- # A single attempt to make
+ # A single attempt to make. Don't mangle the conninfo string.
if nhosts <= 1:
yield params
return
+ if len(ports) == 1:
+ ports *= nhosts
+
# Now all lists are either empty or have the same length
for i in range(nhosts):
attempt = params.copy()
:param params: The input parameters, for instance as returned by
`~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
- a single entry for host, hostaddr, port and doesn't check for env vars
- because it is designed to further process the input of _split_attempts()
+ a single entry for host, hostaddr because it is designed to further
+ process the input of _split_attempts().
If a ``host`` param is present but not ``hostname``, resolve the host
- addresses dynamically.
+ addresses asynchronously.
The function may change the input ``host``, ``hostname``, ``port`` to allow
connecting without further DNS lookups.
-
- Raise `~psycopg.OperationalError` if resolution fails.
"""
- host = params["host"]
+ host = _get_param(params, "host")
if not host or host.startswith("/") or host[1:2] == ":":
# Local path, or no host to resolve
yield params
return
- hostaddr = params["hostaddr"]
+ hostaddr = _get_param(params, "hostaddr")
if hostaddr:
# Already resolved
yield params
loop = asyncio.get_running_loop()
- port = params["port"]
+ port = _get_param(params, "port")
+ if not port:
+ portdef = _get_param_def("port")
+ if portdef:
+ port = portdef.compiled
+
+ assert port and "," not in port # assume a libpq default and no multi
ans = await loop.getaddrinfo(
- host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
)
for item in ans:
yield {**params, "hostaddr": item[4][0]}
-@cache
-def _conn_defaults() -> dict[str, str]:
+def _get_param(params: ConnDict, name: str) -> str | None:
+ """
+ Return a value from a connection string.
+
+ The value may be also specified in a PG* env var.
+ """
+ if name in params:
+ return str(params[name])
+
+ # TODO: check if in service
+
+ paramdef = _get_param_def(name)
+ if not paramdef:
+ return None
+
+ env = os.environ.get(paramdef.envvar)
+ if env is not None:
+ return env
+
+ return None
+
+
+@dataclass
+class ParamDef:
+ """
+ Information about defaults and env vars for connection params
+ """
+
+ keyword: str
+ envvar: str
+ compiled: str | None
+
+
+def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
"""
- Return a dictionary of defaults for connection strings parameters.
+ Return the ParamDef of a connection string parameter.
"""
- defs = pq.Conninfo.get_defaults()
- return {
- d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
- for d in defs
- }
+ if not _cache:
+ defs = pq.Conninfo.get_defaults()
+ for d in defs:
+ cd = ParamDef(
+ keyword=d.keyword.decode(),
+ envvar=d.envvar.decode() if d.envvar else "",
+ compiled=d.compiled.decode() if d.compiled is not None else None,
+ )
+ _cache[cd.keyword] = cd
+
+ return _cache.get(keyword)
@lru_cache()
if params.get(key) == value:
params.pop(key)
- removeif("host", "")
- removeif("hostaddr", "")
- removeif("port", "5432")
- if "," in params.get("host", ""):
- nhosts = len(params["host"].split(","))
- removeif("port", ",".join(["5432"] * nhosts))
- removeif("hostaddr", "," * (nhosts - 1))
removeif("connect_timeout", str(DEFAULT_TIMEOUT))
return params
import socket
import asyncio
import datetime as dt
-from functools import reduce
import pytest
from .utils import alist
from .fix_crdb import crdb_encoding
-from .test_connection import drop_default_args_from_conninfo
snowman = "\u2603"
@pytest.mark.parametrize(
"conninfo, want, env",
[
- ("", "", None),
- ("host='' user=bar", "host='' user=bar", None),
+ ("", [""], None),
+ ("service=foo", ["service=foo"], None),
+ ("host='' user=bar", ["host='' user=bar"], None),
(
"host=127.0.0.1 user=bar",
- "host=127.0.0.1 user=bar",
+ ["host=127.0.0.1 user=bar"],
None,
),
(
"host=1.1.1.1,2.2.2.2 user=bar",
- "host=1.1.1.1,2.2.2.2 user=bar",
+ ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"],
None,
),
(
"host=1.1.1.1,2.2.2.2 port=5432",
- "host=1.1.1.1,2.2.2.2 port=5432,5432",
+ ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"],
None,
),
(
"host=foo.com port=5432",
- "host=foo.com port=5432 hostaddr=1.2.3.4",
+ ["host=foo.com port=5432"],
{"PGHOSTADDR": "1.2.3.4"},
),
],
setpgenv(env)
params = conninfo_to_dict(conninfo)
attempts = list(conninfo_attempts(params))
- params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
- assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
@pytest.mark.parametrize(
"conninfo, want, env",
[
- ("", "", None),
- ("host='' user=bar", "host='' user=bar", None),
+ ("", [""], None),
+ ("host='' user=bar", ["host='' user=bar"], None),
(
"host=127.0.0.1 user=bar",
- "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
None,
),
(
"host=1.1.1.1,2.2.2.2 user=bar",
- "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ [
+ "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+ "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+ ],
None,
),
(
"host=1.1.1.1,2.2.2.2 port=5432",
- "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+ ],
None,
),
(
"port=5432",
- "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+ ],
{"PGHOST": "1.1.1.1,2.2.2.2"},
),
(
"host=foo.com port=5432",
- "host=foo.com port=5432 hostaddr=1.2.3.4",
+ ["host=foo.com port=5432"],
{"PGHOSTADDR": "1.2.3.4"},
),
],
setpgenv(env)
params = conninfo_to_dict(conninfo)
attempts = await alist(conninfo_attempts_async(params))
- params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
- assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
@pytest.mark.parametrize(
[
(
"host=foo.com,qux.com",
- "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
None,
),
(
"host=foo.com,qux.com port=5433",
- "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433,5433",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5433",
+ "host=qux.com hostaddr=2.2.2.2 port=5433",
+ ],
None,
),
(
"host=foo.com,qux.com port=5432,5433",
- "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5432",
+ "host=qux.com hostaddr=2.2.2.2 port=5433",
+ ],
None,
),
(
"host=foo.com,nosuchhost.com",
- "host=foo.com hostaddr=1.1.1.1",
+ ["host=foo.com hostaddr=1.1.1.1"],
None,
),
(
"host=foo.com, port=5432,5433",
- "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+ ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
None,
),
(
"host=nosuchhost.com,foo.com",
- "host=foo.com hostaddr=1.1.1.1",
+ ["host=foo.com hostaddr=1.1.1.1"],
None,
),
(
"host=foo.com,qux.com",
- "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
{},
),
(
"host=dup.com",
- "host=dup.com,dup.com hostaddr=3.3.3.3,3.3.3.4",
+ ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
None,
),
],
async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
params = conninfo_to_dict(conninfo)
attempts = await alist(conninfo_attempts_async(params))
- params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
- assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
@pytest.mark.parametrize(
pytest.fail(f"shouldn't try to resolve {host}")
monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
-
-
-def merge_conninfos(a1, a2):
- """
- merge conninfo attempts into a multi-host conninfo.
- """
- assert set(a1) == set(a2)
- rv = {}
- for k in a1:
- if k in ("host", "hostaddr", "port"):
- rv[k] = f"{a1[k]},{a2[k]}"
- else:
- assert a1[k] == a2[k]
- rv[k] = a1[k]
- return rv