]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: don't add defaults to connection strings
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Dec 2023 19:33:20 +0000 (20:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 00:00:58 +0000 (01:00 +0100)
A default such as empty string for host may may shadow values defined in
a service file.

Fix #694.

docs/news.rst
psycopg/psycopg/_dns.py
psycopg/psycopg/conninfo.py
tests/test_connection.py
tests/test_conninfo.py

index 17d5638ebe97e9872b0b167bb348914296c2c7bd..55f0ea199dec74709720747e3d8711c680adabc9 100644 (file)
@@ -13,8 +13,10 @@ Future releases
 Psycopg 3.1.15 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- Fix async connection to hosts resolving to  multiple IP addresses
-  (:ticket:`#695`).
+- Fix use of ``service`` in connection string (regression in 3.1.13,
+  :ticket:`#694`).
+- Fix async connection to hosts resolving to multiple IP addresses (regression
+  in 3.1.13, :ticket:`#695`).
 
 
 Current release
index ae0a71429a6ec13c03c0ebda378791502e79104b..eb06d1cd8f938f951016884f19a62e5180ba940b 100644 (file)
@@ -52,13 +52,15 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     hostaddrs: list[str] = []
     ports: list[str] = []
 
-    for attempt in conninfo._split_attempts(conninfo._inject_defaults(params)):
+    for attempt in conninfo._split_attempts(params):
         try:
             async for a2 in conninfo._split_attempts_and_resolve(attempt):
-                hosts.append(a2["host"])
-                hostaddrs.append(a2["hostaddr"])
-                if "port" in params:
-                    ports.append(a2["port"])
+                if a2.get("host") is not None:
+                    hosts.append(a2["host"])
+                if a2.get("hostaddr") is not None:
+                    hostaddrs.append(a2["hostaddr"])
+                if a2.get("port") is not None:
+                    ports.append(str(a2["port"]))
         except OSError as ex:
             last_exc = ex
 
index 6c48da734de127126e49faa61444d6ccc7086914..5f56eb3871fe679fec1ab8ed38efb267c7f95bb4 100644 (file)
@@ -16,12 +16,12 @@ from pathlib import Path
 from datetime import tzinfo
 from functools import lru_cache
 from ipaddress import ip_address
+from dataclasses import dataclass
 from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from ._tz import get_tzinfo
-from ._compat import cache
 from ._encodings import pgconn_encoding
 
 ConnDict: TypeAlias = "dict[str, Any]"
@@ -283,7 +283,7 @@ class ConnectionInfo:
 
 
 def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
-    """Split a set of connection params on the single attempts to perforn.
+    """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
     is provided.
@@ -291,16 +291,20 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
     Because the libpq async function doesn't honour the timeout, we need to
     reimplement the repeated attempts.
     """
+    # TODO: we should actually resolve the hosts ourselves.
+    # If an host resolves to more than one ip, the libpq will make more than
+    # one attempt and wouldn't get to try the following ones, as before
+    # fixing #674.
     if params.get("load_balance_hosts", "disable") == "random":
-        attempts = list(_split_attempts(_inject_defaults(params)))
+        attempts = list(_split_attempts(params))
         shuffle(attempts)
         yield from attempts
     else:
-        yield from _split_attempts(_inject_defaults(params))
+        yield from _split_attempts(params)
 
 
 async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
-    """Split a set of connection params on the single attempts to perforn.
+    """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
     is provided.
@@ -313,9 +317,11 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
     Because the libpq async function doesn't honour the timeout, we need to
     reimplement the repeated attempts.
     """
+    # TODO: the function should resolve all hosts and shuffle the results
+    # to replicate the same libpq algorithm.
     yielded = False
     last_exc = None
-    for attempt in _split_attempts(_inject_defaults(params)):
+    for attempt in _split_attempts(params):
         try:
             async for a2 in _split_attempts_and_resolve(attempt):
                 yielded = True
@@ -329,45 +335,13 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
         raise e.OperationalError(str(last_exc))
 
 
-def _inject_defaults(params: ConnDict) -> ConnDict:
-    """
-    Add defaults to a dictionary of parameters.
-
-    This avoids the need to look up for env vars at various stages during
-    processing.
-
-    Note that a port is always specified. 5432 likely comes from here.
-
-    The `host`, `hostaddr`, `port` will be always set to a string.
-    """
-    defaults = _conn_defaults()
-    out = params.copy()
-
-    def inject(name: str, envvar: str) -> None:
-        value = out.get(name)
-        if not value:
-            out[name] = os.environ.get(envvar, defaults[name])
-        else:
-            out[name] = str(value)
-
-    inject("host", "PGHOST")
-    inject("hostaddr", "PGHOSTADDR")
-    inject("port", "PGPORT")
-
-    return out
-
-
 def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
     """
     Split connection parameters with a sequence of hosts into separate attempts.
-
-    Assume that `host`, `hostaddr`, `port` are always present and a string (as
-    emitted from `_inject_defaults()`).
     """
 
     def split_val(key: str) -> list[str]:
-        # Assume all keys are present and strings.
-        val: str = params[key]
+        val = _get_param(params, key)
         return val.split(",") if val else []
 
     hosts = split_val("host")
@@ -386,14 +360,15 @@ def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
         raise e.OperationalError(
             f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
         )
-    elif len(ports) == 1:
-        ports *= nhosts
 
-    # A single attempt to make
+    # A single attempt to make. Don't mangle the conninfo string.
     if nhosts <= 1:
         yield params
         return
 
+    if len(ports) == 1:
+        ports *= nhosts
+
     # Now all lists are either empty or have the same length
     for i in range(nhosts):
         attempt = params.copy()
@@ -412,24 +387,22 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic
 
     :param params: The input parameters, for instance as returned by
         `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
-        a single entry for host, hostaddr, port and doesn't check for env vars
-        because it is designed to further process the input of _split_attempts()
+        a single entry for host, hostaddr because it is designed to further
+        process the input of _split_attempts().
 
     If a ``host`` param is present but not ``hostname``, resolve the host
-    addresses dynamically.
+    addresses asynchronously.
 
     The function may change the input ``host``, ``hostname``, ``port`` to allow
     connecting without further DNS lookups.
-
-    Raise `~psycopg.OperationalError` if resolution fails.
     """
-    host = params["host"]
+    host = _get_param(params, "host")
     if not host or host.startswith("/") or host[1:2] == ":":
         # Local path, or no host to resolve
         yield params
         return
 
-    hostaddr = params["hostaddr"]
+    hostaddr = _get_param(params, "hostaddr")
     if hostaddr:
         # Already resolved
         yield params
@@ -443,25 +416,69 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic
 
     loop = asyncio.get_running_loop()
 
-    port = params["port"]
+    port = _get_param(params, "port")
+    if not port:
+        portdef = _get_param_def("port")
+        if portdef:
+            port = portdef.compiled
+
+    assert port and "," not in port  # assume a libpq default and no multi
     ans = await loop.getaddrinfo(
-        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+        host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
     )
 
     for item in ans:
         yield {**params, "hostaddr": item[4][0]}
 
 
-@cache
-def _conn_defaults() -> dict[str, str]:
+def _get_param(params: ConnDict, name: str) -> str | None:
+    """
+    Return a value from a connection string.
+
+    The value may be also specified in a PG* env var.
+    """
+    if name in params:
+        return str(params[name])
+
+    # TODO: check if in service
+
+    paramdef = _get_param_def(name)
+    if not paramdef:
+        return None
+
+    env = os.environ.get(paramdef.envvar)
+    if env is not None:
+        return env
+
+    return None
+
+
+@dataclass
+class ParamDef:
+    """
+    Information about defaults and env vars for connection params
+    """
+
+    keyword: str
+    envvar: str
+    compiled: str | None
+
+
+def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
     """
-    Return a dictionary of defaults for connection strings parameters.
+    Return the ParamDef of a connection string parameter.
     """
-    defs = pq.Conninfo.get_defaults()
-    return {
-        d.keyword.decode(): d.compiled.decode() if d.compiled is not None else ""
-        for d in defs
-    }
+    if not _cache:
+        defs = pq.Conninfo.get_defaults()
+        for d in defs:
+            cd = ParamDef(
+                keyword=d.keyword.decode(),
+                envvar=d.envvar.decode() if d.envvar else "",
+                compiled=d.compiled.decode() if d.compiled is not None else None,
+            )
+            _cache[cd.keyword] = cd
+
+    return _cache.get(keyword)
 
 
 @lru_cache()
index 7b823beaa8c9d0fe7ac9be8052d17fca68e4d576..754acec3adb228a7326b434b7fa3a1a948b99ac3 100644 (file)
@@ -876,13 +876,6 @@ def drop_default_args_from_conninfo(conninfo):
         if params.get(key) == value:
             params.pop(key)
 
-    removeif("host", "")
-    removeif("hostaddr", "")
-    removeif("port", "5432")
-    if "," in params.get("host", ""):
-        nhosts = len(params["host"].split(","))
-        removeif("port", ",".join(["5432"] * nhosts))
-        removeif("hostaddr", "," * (nhosts - 1))
     removeif("connect_timeout", str(DEFAULT_TIMEOUT))
 
     return params
index 83174026f0004cf6a198699192248774477d95a6..8254199261c7fa1baf4234420b227b9dfb8a8013 100644 (file)
@@ -1,7 +1,6 @@
 import socket
 import asyncio
 import datetime as dt
-from functools import reduce
 
 import pytest
 
@@ -13,7 +12,6 @@ from psycopg._encodings import pg2pyenc
 
 from .utils import alist
 from .fix_crdb import crdb_encoding
-from .test_connection import drop_default_args_from_conninfo
 
 snowman = "\u2603"
 
@@ -322,26 +320,27 @@ class TestConnectionInfo:
 @pytest.mark.parametrize(
     "conninfo, want, env",
     [
-        ("", "", None),
-        ("host='' user=bar", "host='' user=bar", None),
+        ("", [""], None),
+        ("service=foo", ["service=foo"], None),
+        ("host='' user=bar", ["host='' user=bar"], None),
         (
             "host=127.0.0.1 user=bar",
-            "host=127.0.0.1 user=bar",
+            ["host=127.0.0.1 user=bar"],
             None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 user=bar",
-            "host=1.1.1.1,2.2.2.2 user=bar",
+            ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"],
             None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432,5432",
+            ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"],
             None,
         ),
         (
             "host=foo.com port=5432",
-            "host=foo.com port=5432 hostaddr=1.2.3.4",
+            ["host=foo.com port=5432"],
             {"PGHOSTADDR": "1.2.3.4"},
         ),
     ],
@@ -351,38 +350,47 @@ def test_conninfo_attempts(setpgenv, conninfo, want, env):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
     attempts = list(conninfo_attempts(params))
-    params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
-    assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
 
 
 @pytest.mark.parametrize(
     "conninfo, want, env",
     [
-        ("", "", None),
-        ("host='' user=bar", "host='' user=bar", None),
+        ("", [""], None),
+        ("host='' user=bar", ["host='' user=bar"], None),
         (
             "host=127.0.0.1 user=bar",
-            "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+            ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
             None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 user=bar",
-            "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+            [
+                "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+                "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+            ],
             None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
             None,
         ),
         (
             "port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432,5432 hostaddr=1.1.1.1,2.2.2.2",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
             {"PGHOST": "1.1.1.1,2.2.2.2"},
         ),
         (
             "host=foo.com port=5432",
-            "host=foo.com port=5432 hostaddr=1.2.3.4",
+            ["host=foo.com port=5432"],
             {"PGHOSTADDR": "1.2.3.4"},
         ),
     ],
@@ -394,8 +402,8 @@ async def test_conninfo_attempts_async_no_resolve(
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
     attempts = await alist(conninfo_attempts_async(params))
-    params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
-    assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
 
 
 @pytest.mark.parametrize(
@@ -403,42 +411,48 @@ async def test_conninfo_attempts_async_no_resolve(
     [
         (
             "host=foo.com,qux.com",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
             None,
         ),
         (
             "host=foo.com,qux.com port=5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433,5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5433",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
             None,
         ),
         (
             "host=foo.com,qux.com port=5432,5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
             None,
         ),
         (
             "host=foo.com,nosuchhost.com",
-            "host=foo.com hostaddr=1.1.1.1",
+            ["host=foo.com hostaddr=1.1.1.1"],
             None,
         ),
         (
             "host=foo.com, port=5432,5433",
-            "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+            ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
             None,
         ),
         (
             "host=nosuchhost.com,foo.com",
-            "host=foo.com hostaddr=1.1.1.1",
+            ["host=foo.com hostaddr=1.1.1.1"],
             None,
         ),
         (
             "host=foo.com,qux.com",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
             {},
         ),
         (
             "host=dup.com",
-            "host=dup.com,dup.com hostaddr=3.3.3.3,3.3.3.4",
+            ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
             None,
         ),
     ],
@@ -447,8 +461,8 @@ async def test_conninfo_attempts_async_no_resolve(
 async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
     params = conninfo_to_dict(conninfo)
     attempts = await alist(conninfo_attempts_async(params))
-    params = drop_default_args_from_conninfo(reduce(merge_conninfos, attempts))
-    assert drop_default_args_from_conninfo(conninfo_to_dict(want)) == params
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
 
 
 @pytest.mark.parametrize(
@@ -534,18 +548,3 @@ async def fail_resolve(monkeypatch):
         pytest.fail(f"shouldn't try to resolve {host}")
 
     monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)
-
-
-def merge_conninfos(a1, a2):
-    """
-    merge conninfo attempts into a multi-host conninfo.
-    """
-    assert set(a1) == set(a2)
-    rv = {}
-    for k in a1:
-        if k in ("host", "hostaddr", "port"):
-            rv[k] = f"{a1[k]},{a2[k]}"
-        else:
-            assert a1[k] == a2[k]
-            rv[k] = a1[k]
-    return rv