From bfb68f10a27833e32928ae19fc90cadeda10a0e4 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 12 Dec 2023 20:33:20 +0100 Subject: [PATCH] fix: don't add defaults to connection strings A default such as empty string for host may may shadow values defined in a service file. Fix #694. --- docs/news.rst | 6 +- psycopg/psycopg/_dns.py | 12 ++-- psycopg/psycopg/conninfo.py | 135 ++++++++++++++++++++---------------- tests/test_connection.py | 7 -- tests/test_conninfo.py | 87 ++++++++++++----------- 5 files changed, 130 insertions(+), 117 deletions(-) diff --git a/docs/news.rst b/docs/news.rst index 17d5638eb..55f0ea199 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -13,8 +13,10 @@ Future releases Psycopg 3.1.15 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- Fix async connection to hosts resolving to multiple IP addresses - (:ticket:`#695`). +- Fix use of ``service`` in connection string (regression in 3.1.13, + :ticket:`#694`). +- Fix async connection to hosts resolving to multiple IP addresses (regression + in 3.1.13, :ticket:`#695`). Current release diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index ae0a71429..eb06d1cd8 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -52,13 +52,15 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: hostaddrs: list[str] = [] ports: list[str] = [] - for attempt in conninfo._split_attempts(conninfo._inject_defaults(params)): + for attempt in conninfo._split_attempts(params): try: async for a2 in conninfo._split_attempts_and_resolve(attempt): - hosts.append(a2["host"]) - hostaddrs.append(a2["hostaddr"]) - if "port" in params: - ports.append(a2["port"]) + if a2.get("host") is not None: + hosts.append(a2["host"]) + if a2.get("hostaddr") is not None: + hostaddrs.append(a2["hostaddr"]) + if a2.get("port") is not None: + ports.append(str(a2["port"])) except OSError as ex: last_exc = ex diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 6c48da734..5f56eb387 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -16,12 +16,12 @@ from pathlib import Path from datetime import tzinfo from functools import lru_cache from ipaddress import ip_address +from dataclasses import dataclass from typing_extensions import TypeAlias from . import pq from . import errors as e from ._tz import get_tzinfo -from ._compat import cache from ._encodings import pgconn_encoding ConnDict: TypeAlias = "dict[str, Any]" @@ -283,7 +283,7 @@ class ConnectionInfo: def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: - """Split a set of connection params on the single attempts to perforn. + """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` is provided. @@ -291,16 +291,20 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ + # TODO: we should actually resolve the hosts ourselves. + # If an host resolves to more than one ip, the libpq will make more than + # one attempt and wouldn't get to try the following ones, as before + # fixing #674. if params.get("load_balance_hosts", "disable") == "random": - attempts = list(_split_attempts(_inject_defaults(params))) + attempts = list(_split_attempts(params)) shuffle(attempts) yield from attempts else: - yield from _split_attempts(_inject_defaults(params)) + yield from _split_attempts(params) async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: - """Split a set of connection params on the single attempts to perforn. + """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` is provided. @@ -313,9 +317,11 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ + # TODO: the function should resolve all hosts and shuffle the results + # to replicate the same libpq algorithm. yielded = False last_exc = None - for attempt in _split_attempts(_inject_defaults(params)): + for attempt in _split_attempts(params): try: async for a2 in _split_attempts_and_resolve(attempt): yielded = True @@ -329,45 +335,13 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: raise e.OperationalError(str(last_exc)) -def _inject_defaults(params: ConnDict) -> ConnDict: - """ - Add defaults to a dictionary of parameters. - - This avoids the need to look up for env vars at various stages during - processing. - - Note that a port is always specified. 5432 likely comes from here. - - The `host`, `hostaddr`, `port` will be always set to a string. - """ - defaults = _conn_defaults() - out = params.copy() - - def inject(name: str, envvar: str) -> None: - value = out.get(name) - if not value: - out[name] = os.environ.get(envvar, defaults[name]) - else: - out[name] = str(value) - - inject("host", "PGHOST") - inject("hostaddr", "PGHOSTADDR") - inject("port", "PGPORT") - - return out - - def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: """ Split connection parameters with a sequence of hosts into separate attempts. - - Assume that `host`, `hostaddr`, `port` are always present and a string (as - emitted from `_inject_defaults()`). """ def split_val(key: str) -> list[str]: - # Assume all keys are present and strings. - val: str = params[key] + val = _get_param(params, key) return val.split(",") if val else [] hosts = split_val("host") @@ -386,14 +360,15 @@ def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: raise e.OperationalError( f"could not match {len(ports)} port numbers to {len(hosts)} hosts" ) - elif len(ports) == 1: - ports *= nhosts - # A single attempt to make + # A single attempt to make. Don't mangle the conninfo string. if nhosts <= 1: yield params return + if len(ports) == 1: + ports *= nhosts + # Now all lists are either empty or have the same length for i in range(nhosts): attempt = params.copy() @@ -412,24 +387,22 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic :param params: The input parameters, for instance as returned by `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most - a single entry for host, hostaddr, port and doesn't check for env vars - because it is designed to further process the input of _split_attempts() + a single entry for host, hostaddr because it is designed to further + process the input of _split_attempts(). If a ``host`` param is present but not ``hostname``, resolve the host - addresses dynamically. + addresses asynchronously. The function may change the input ``host``, ``hostname``, ``port`` to allow connecting without further DNS lookups. - - Raise `~psycopg.OperationalError` if resolution fails. """ - host = params["host"] + host = _get_param(params, "host") if not host or host.startswith("/") or host[1:2] == ":": # Local path, or no host to resolve yield params return - hostaddr = params["hostaddr"] + hostaddr = _get_param(params, "hostaddr") if hostaddr: # Already resolved yield params @@ -443,25 +416,69 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic loop = asyncio.get_running_loop() - port = params["port"] + port = _get_param(params, "port") + if not port: + portdef = _get_param_def("port") + if portdef: + port = portdef.compiled + + assert port and "," not in port # assume a libpq default and no multi ans = await loop.getaddrinfo( - host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM ) for item in ans: yield {**params, "hostaddr": item[4][0]} -@cache -def _conn_defaults() -> dict[str, str]: +def _get_param(params: ConnDict, name: str) -> str | None: + """ + Return a value from a connection string. + + The value may be also specified in a PG* env var. + """ + if name in params: + return str(params[name]) + + # TODO: check if in service + + paramdef = _get_param_def(name) + if not paramdef: + return None + + env = os.environ.get(paramdef.envvar) + if env is not None: + return env + + return None + + +@dataclass +class ParamDef: + """ + Information about defaults and env vars for connection params + """ + + keyword: str + envvar: str + compiled: str | None + + +def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None: """ - Return a dictionary of defaults for connection strings parameters. + Return the ParamDef of a connection string parameter. """ - defs = pq.Conninfo.get_defaults() - return { - d.keyword.decode(): d.compiled.decode() if d.compiled is not None else "" - for d in defs - } + if not _cache: + defs = pq.Conninfo.get_defaults() + for d in defs: + cd = ParamDef( + keyword=d.keyword.decode(), + envvar=d.envvar.decode() if d.envvar else "", + compiled=d.compiled.decode() if d.compiled is not None else None, + ) + _cache[cd.keyword] = cd + + return _cache.get(keyword) @lru_cache() diff --git a/tests/test_connection.py b/tests/test_connection.py index 7b823beaa..754acec3a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -876,13 +876,6 @@ def drop_default_args_from_conninfo(conninfo): if params.get(key) == value: params.pop(key) - removeif("host", "") - removeif("hostaddr", "") - removeif("port", "5432") - if "," in params.get("host", ""): - nhosts = len(params["host"].split(",")) - removeif("port", ",".join(["5432"] * nhosts)) - removeif("hostaddr", "," * (nhosts - 1)) removeif("connect_timeout", str(DEFAULT_TIMEOUT)) return params diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 83174026f..825419926 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -1,7 +1,6 @@ import socket import asyncio import datetime as dt -from functools import reduce import pytest @@ -13,7 +12,6 @@ from psycopg._encodings import pg2pyenc from .utils import alist from .fix_crdb import crdb_encoding -from .test_connection import drop_default_args_from_conninfo snowman = "\u2603" @@ -322,26 +320,27 @@ class TestConnectionInfo: @pytest.mark.parametrize( "conninfo, want, env", [ - ("", "", None), - ("host='' user=bar", "host='' user=bar", None), + ("", [""], None), + ("service=foo", ["service=foo"], None), + ("host='' user=bar", ["host='' user=bar"], None), ( "host=127.0.0.1 user=bar", - "host=127.0.0.1 user=bar", + ["host=127.0.0.1 user=bar"], None, ), ( "host=1.1.1.1,2.2.2.2 user=bar", - "host=1.1.1.1,2.2.2.2 user=bar", + ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"], None, ), ( "host=1.1.1.1,2.2.2.2 port=5432", - "host=1.1.1.1,2.2.2.2 port=5432,5432", + ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"], None, ), ( "host=foo.com port=5432", - "host=foo.com port=5432 hostaddr=1.2.3.4", + ["host=foo.com port=5432"], {"PGHOSTADDR": "1.2.3.4"}, ), ], @@ -351,38 +350,47 @@ def test_conninfo_attempts(setpgenv, conninfo, want, env): setpgenv(env) params = conninfo_to_dict(conninfo) attempts = list(conninfo_attempts(params)) - params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts)) - assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params + want = list(map(conninfo_to_dict, want)) + assert want == attempts @pytest.mark.parametrize( "conninfo, want, env", [ - ("", "", None), - ("host='' user=bar", "host='' user=bar", None), + ("", [""], None), + ("host='' user=bar", ["host='' user=bar"], None), ( "host=127.0.0.1 user=bar", - "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], None, ), ( "host=1.1.1.1,2.2.2.2 user=bar", - "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + [ + "host=1.1.1.1 user=bar hostaddr=1.1.1.1", + "host=2.2.2.2 user=bar hostaddr=2.2.2.2", + ], None, ), ( "host=1.1.1.1,2.2.2.2 port=5432", - "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], None, ), ( "port=5432", - "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], {"PGHOST": "1.1.1.1,2.2.2.2"}, ), ( "host=foo.com port=5432", - "host=foo.com port=5432 hostaddr=1.2.3.4", + ["host=foo.com port=5432"], {"PGHOSTADDR": "1.2.3.4"}, ), ], @@ -394,8 +402,8 @@ async def test_conninfo_attempts_async_no_resolve( setpgenv(env) params = conninfo_to_dict(conninfo) attempts = await alist(conninfo_attempts_async(params)) - params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts)) - assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params + want = list(map(conninfo_to_dict, want)) + assert want == attempts @pytest.mark.parametrize( @@ -403,42 +411,48 @@ async def test_conninfo_attempts_async_no_resolve( [ ( "host=foo.com,qux.com", - "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], None, ), ( "host=foo.com,qux.com port=5433", - "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433,5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5433", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], None, ), ( "host=foo.com,qux.com port=5432,5433", - "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], None, ), ( "host=foo.com,nosuchhost.com", - "host=foo.com hostaddr=1.1.1.1", + ["host=foo.com hostaddr=1.1.1.1"], None, ), ( "host=foo.com, port=5432,5433", - "host=foo.com, hostaddr=1.1.1.1, port=5432,5433", + ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], None, ), ( "host=nosuchhost.com,foo.com", - "host=foo.com hostaddr=1.1.1.1", + ["host=foo.com hostaddr=1.1.1.1"], None, ), ( "host=foo.com,qux.com", - "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], {}, ), ( "host=dup.com", - "host=dup.com,dup.com hostaddr=3.3.3.3,3.3.3.4", + ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], None, ), ], @@ -447,8 +461,8 @@ async def test_conninfo_attempts_async_no_resolve( async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve): params = conninfo_to_dict(conninfo) attempts = await alist(conninfo_attempts_async(params)) - params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts)) - assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params + want = list(map(conninfo_to_dict, want)) + assert want == attempts @pytest.mark.parametrize( @@ -534,18 +548,3 @@ async def fail_resolve(monkeypatch): pytest.fail(f"shouldn't try to resolve {host}") monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo) - - -def merge_conninfos(a1, a2): - """ - merge conninfo attempts into a multi-host conninfo. - """ - assert set(a1) == set(a2) - rv = {} - for k in a1: - if k in ("host", "hostaddr", "port"): - rv[k] = f"{a1[k]},{a2[k]}" - else: - assert a1[k] == a2[k] - rv[k] = a1[k] - return rv -- 2.47.2