]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: explicitly iterate on multiple hosts on connections
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Oct 2023 21:16:01 +0000 (23:16 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 23:34:24 +0000 (00:34 +0100)
The libpq async connection path doesn't iterate on the attempts, so we
need to do it ourselves.

docs/news.rst
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/fix_proxy.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_dns.py
tests/test_module.py
tests/test_psycopg_dbapi20.py

index 7ce52eff1edc30390c0eacab8c270afb0c4dc09a..cc456e5b1a72e6d65484c1ddfe256e88c9ff9e21 100644 (file)
@@ -18,6 +18,7 @@ Psycopg 3.1.13 (unreleased)
   `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`).
 - Handle gracefully EINTR on signals instead of raising `InterruptedError`,
   consistently with :pep:`475` guideline (:ticket:`#667`).
+- Fix support for connection strings with multiple hosts (:ticket:`#674`).
 
 
 Current release
index 5f3437321cea0f3adaf6c619e5dd0b3331365db8..e90e30adb0d461f5f6fb30d81db9ee887afded66 100644 (file)
@@ -7,7 +7,7 @@ psycopg connection objects
 import logging
 import threading
 from types import TracebackType
-from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
+from typing import Any, Callable, cast, Generator, Generic, Iterator
 from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
 from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
@@ -31,6 +31,7 @@ from .cursor import Cursor
 from ._compat import LiteralString
 from .pq.misc import connection_summary
 from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from .conninfo import conninfo_attempts, ConnDict
 from ._pipeline import BasePipeline, Pipeline
 from .generators import notifies, connect, execute
 from ._encodings import pgconn_encoding
@@ -106,6 +107,11 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
+    # Default timeout for connection a attempt.
+    # Arbitrary timeout, what applied by the libpq on my computer.
+    # Your mileage won't vary.
+    _DEFAULT_CONNECT_TIMEOUT = 130
+
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn
         self._autocommit = False
@@ -724,14 +730,19 @@ class Connection(BaseConnection[Row]):
         Connect to a database server and return a new `Connection` instance.
         """
         params = cls._get_connection_params(conninfo, **kwargs)
-        conninfo = make_conninfo(**params)
+        timeout = int(params["connect_timeout"])
+        rv = None
+        for attempt in conninfo_attempts(params):
+            try:
+                conninfo = make_conninfo(**attempt)
+                rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+                break
+            except e._NO_TRACEBACK as ex:
+                last_ex = ex
 
-        try:
-            rv = cls._wait_conn(
-                cls._connect_gen(conninfo), timeout=params["connect_timeout"]
-            )
-        except e._NO_TRACEBACK as ex:
-            raise ex.with_traceback(None)
+        if not rv:
+            assert last_ex
+            raise last_ex.with_traceback(None)
 
         rv._autocommit = bool(autocommit)
         if row_factory:
@@ -774,7 +785,7 @@ class Connection(BaseConnection[Row]):
             self.close()
 
     @classmethod
-    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
+    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
         """Manipulate connection parameters before connecting.
 
         :param conninfo: Connection string as received by `~Connection.connect()`.
@@ -788,7 +799,10 @@ class Connection(BaseConnection[Row]):
         if "connect_timeout" in params:
             params["connect_timeout"] = int(params["connect_timeout"])
         else:
-            params["connect_timeout"] = None
+            # The sync connect function will stop on the default socket timeout
+            # Because in async connection mode we need to enforce the timeout
+            # ourselves, we need a finite value.
+            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
 
         return params
 
index 5ab0522b026b22b01eef47e0b9821b945ba6159f..6766a224532fd00bf5d8ece7255760419fc56157 100644 (file)
@@ -8,7 +8,7 @@ import sys
 import asyncio
 import logging
 from types import TracebackType
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
 from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
 from contextlib import asynccontextmanager
 
@@ -20,7 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .connection import BaseConnection, CursorRow, Notify
@@ -118,14 +118,19 @@ class AsyncConnection(BaseConnection[Row]):
                 )
 
         params = await cls._get_connection_params(conninfo, **kwargs)
-        conninfo = make_conninfo(**params)
+        timeout = int(params["connect_timeout"])
+        rv = None
+        async for attempt in conninfo_attempts_async(params):
+            try:
+                conninfo = make_conninfo(**attempt)
+                rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+                break
+            except e._NO_TRACEBACK as ex:
+                last_ex = ex
 
-        try:
-            rv = await cls._wait_conn(
-                cls._connect_gen(conninfo), timeout=params["connect_timeout"]
-            )
-        except e._NO_TRACEBACK as ex:
-            raise ex.with_traceback(None)
+        if not rv:
+            assert last_ex
+            raise last_ex.with_traceback(None)
 
         rv._autocommit = bool(autocommit)
         if row_factory:
@@ -168,28 +173,18 @@ class AsyncConnection(BaseConnection[Row]):
             await self.close()
 
     @classmethod
-    async def _get_connection_params(
-        cls, conninfo: str, **kwargs: Any
-    ) -> Dict[str, Any]:
-        """Manipulate connection parameters before connecting.
-
-        .. versionchanged:: 3.1
-            Unlike the sync counterpart, perform non-blocking address
-            resolution and populate the ``hostaddr`` connection parameter,
-            unless the user has provided one themselves. See
-            `~psycopg._dns.resolve_hostaddr_async()` for details.
-
-        """
+    async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
+        """Manipulate connection parameters before connecting."""
         params = conninfo_to_dict(conninfo, **kwargs)
 
         # Make sure there is an usable connect_timeout
         if "connect_timeout" in params:
             params["connect_timeout"] = int(params["connect_timeout"])
         else:
-            params["connect_timeout"] = None
-
-        # Resolve host addresses in non-blocking way
-        params = await resolve_hostaddr_async(params)
+            # The sync connect function will stop on the default socket timeout
+            # Because in async connection mode we need to enforce the timeout
+            # ourselves, we need a finite value.
+            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
 
         return params
 
index efbb2be315e99a59ebbafb965192d92a19876bd2..4f633fff4a09c9c5bb72758af0c5d0ae7c637344 100644 (file)
@@ -330,6 +330,49 @@ async def resolve_hostaddr_async(params: ConnDict) -> ConnDict:
     return out
 
 
+def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
+    """Split a set of connection params on the single attempts to perforn.
+
+    A connection param can perform more than one attempt more than one ``host``
+    is provided.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    for attempt in _split_attempts(_inject_defaults(params)):
+        yield attempt
+
+
+async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
+    """Split a set of connection params on the single attempts to perforn.
+
+    A connection param can perform more than one attempt more than one ``host``
+    is provided.
+
+    Also perform async resolution of the hostname into hostaddr in order to
+    avoid blocking. Because a host can resolve to more than one address, this
+    can lead to yield more attempts too. Raise `OperationalError` if no host
+    could be resolved.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    yielded = False
+    last_exc = None
+    for attempt in _split_attempts(_inject_defaults(params)):
+        try:
+            async for a2 in _split_attempts_and_resolve(attempt):
+                yielded = True
+                yield a2
+        except OSError as ex:
+            last_exc = ex
+
+    if not yielded:
+        assert last_exc
+        # We couldn't resolve anything
+        raise e.OperationalError(str(last_exc))
+
+
 def _inject_defaults(params: ConnDict) -> ConnDict:
     """
     Add defaults to a dictionary of parameters.
index e50f5ec05f28b460cc7c9c349f055db763a64ad3..1d566b5e5dc178aefc104415e0e8932b90f27779 100644 (file)
@@ -60,7 +60,7 @@ class Proxy:
         # Get server params
         host = cdict.get("host") or os.environ.get("PGHOST")
         self.server_host = host if host and not host.startswith("/") else "localhost"
-        self.server_port = cdict.get("port", "5432")
+        self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
 
         # Get client params
         self.client_host = "localhost"
index 7314f6f31a93b6e676ac485728caf1d7eaabe03d..ddfff5311f82f14e3b7b5033c3c85dea9b058ee6 100644 (file)
@@ -15,6 +15,8 @@ from .utils import gc_collect
 from .test_cursor import my_row_factory
 from .test_adapt import make_bin_dumper, make_dumper
 
+DEFAULT_TIMEOUT = psycopg.Connection._DEFAULT_CONNECT_TIMEOUT
+
 
 def test_connect(conn_cls, dsn):
     conn = conn_cls.connect(dsn)
@@ -359,25 +361,26 @@ def test_autocommit_unknown(conn):
         (("host=foo user=bar",), {}, "host=foo user=bar"),
         (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
         (
-            ("host=foo port=5432",),
-            {"host": "qux", "user": "joe"},
-            "host=qux user=joe port=5432",
+            ("dbname=foo port=5433",),
+            {"dbname": "qux", "user": "joe"},
+            "dbname=qux user=joe port=5433",
         ),
         (("host=foo",), {"user": None}, "host=foo"),
     ],
 )
-def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want):
-    the_conninfo: str
+def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want):
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
     conn = conn_cls.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    got_params = drop_default_args_from_conninfo(got_conninfo)
+    assert got_params == conninfo_to_dict(want)
     conn.close()
 
 
@@ -761,7 +764,7 @@ conninfo_params_timeout = [
     (
         "",
         {"dbname": "mydb", "connect_timeout": None},
-        ({"dbname": "mydb"}, None),
+        ({"dbname": "mydb"}, DEFAULT_TIMEOUT),
     ),
     (
         "",
@@ -771,7 +774,7 @@ conninfo_params_timeout = [
     (
         "dbname=postgres",
         {},
-        ({"dbname": "postgres"}, None),
+        ({"dbname": "postgres"}, DEFAULT_TIMEOUT),
     ),
     (
         "dbname=postgres connect_timeout=2",
@@ -790,8 +793,8 @@ conninfo_params_timeout = [
 def test_get_connection_params(conn_cls, dsn, kwargs, exp):
     params = conn_cls._get_connection_params(dsn, **kwargs)
     conninfo = make_conninfo(**params)
-    assert conninfo_to_dict(conninfo) == exp[0]
-    assert params.get("connect_timeout") == exp[1]
+    assert drop_default_args_from_conninfo(conninfo) == exp[0]
+    assert params["connect_timeout"] == exp[1]
 
 
 def test_connect_context(conn_cls, dsn):
@@ -824,3 +827,18 @@ def test_connect_context_copy(conn_cls, dsn, conn):
 def test_cancel_closed(conn):
     conn.close()
     conn.cancel()
+
+
+def drop_default_args_from_conninfo(conninfo):
+    params = conninfo_to_dict(conninfo)
+
+    def removeif(key, value):
+        if params.get(key) == value:
+            params.pop(key)
+
+    removeif("host", "")
+    removeif("hostaddr", "")
+    removeif("port", "5432")
+    removeif("connect_timeout", str(DEFAULT_TIMEOUT))
+
+    return params
index 61277872f0da48967dc68f727fdd86ba67fe3e6b..87d8a4ee6cf9cc17998dea21d50381375aa246b2 100644 (file)
@@ -12,7 +12,7 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo
 from .utils import gc_collect
 from .test_cursor import my_row_factory
 from .test_connection import tx_params, tx_params_isolation, tx_values_map
-from .test_connection import conninfo_params_timeout
+from .test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
 from .test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 from .test_conninfo import fake_resolve  # noqa: F401
@@ -357,9 +357,9 @@ async def test_autocommit_unknown(aconn):
         (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
         (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
         (
-            ("dbname=foo port=5432",),
+            ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
-            "dbname=qux user=joe port=5432",
+            "dbname=qux user=joe port=5433",
         ),
         (("dbname=foo",), {"user": None}, "dbname=foo"),
     ],
@@ -367,18 +367,19 @@ async def test_autocommit_unknown(aconn):
 async def test_connect_args(
     aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
 ):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     setpgenv({})
     monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
     conn = await aconn_cls.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    got_params = drop_default_args_from_conninfo(got_conninfo)
+    assert got_params == conninfo_to_dict(want)
     await conn.close()
 
 
@@ -724,7 +725,7 @@ async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = await aconn_cls._get_connection_params(dsn, **kwargs)
     conninfo = make_conninfo(**params)
-    assert conninfo_to_dict(conninfo) == exp[0]
+    assert drop_default_args_from_conninfo(conninfo) == exp[0]
     assert params["connect_timeout"] == exp[1]
 
 
@@ -774,4 +775,4 @@ async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):  # noqa: F811
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert conninfo_to_dict(got[0]) == want
+    assert drop_default_args_from_conninfo(got[0]) == want
index 2eb5569df944309ac2ae0814145365ab11adfa56..efbf6f5030825258133d6333a62a092861429dec 100644 (file)
@@ -3,9 +3,29 @@ import pytest
 import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
-pytestmark = [pytest.mark.dns]
+from .test_conninfo import fake_resolve  # noqa: F401  # fixture
+from .test_connection import drop_default_args_from_conninfo
 
 
+@pytest.mark.usefixtures("fake_resolve")
+async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
+    got = []
+
+    def fake_connect_gen(conninfo, **kwargs):
+        got.append(conninfo)
+        1 / 0
+
+    monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
+
+    with pytest.raises(ZeroDivisionError):
+        await aconn_cls.connect("host=foo.com")
+
+    assert len(got) == 1
+    want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+    assert drop_default_args_from_conninfo(got[0]) == want
+
+
+@pytest.mark.dns
 @pytest.mark.anyio
 async def test_resolve_hostaddr_async_warning(recwarn):
     import_dnspython()
@@ -14,7 +34,6 @@ async def test_resolve_hostaddr_async_warning(recwarn):
     params = await psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
         params
     )
-    assert conninfo_to_dict(conninfo) == params
     assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
 
 
index 794ef0f89ec6ca2db34de070ff23a6561b7a314b..030b75808aaf2e2613e3467e136d865f5ad9489f 100644 (file)
@@ -1,10 +1,13 @@
 import pytest
 
 from psycopg._cmodule import _psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from .test_connection import drop_default_args_from_conninfo
 
 
 @pytest.mark.parametrize(
-    "args, kwargs, want_conninfo",
+    "args, kwargs, want",
     [
         ((), {}, ""),
         (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"),
@@ -12,7 +15,7 @@ from psycopg._cmodule import _psycopg
         ((), {"user": "foo", "dbname": None}, "user=foo"),
     ],
 )
-def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
+def test_connect(monkeypatch, dsn, args, kwargs, want):
     # Check the main args passing from psycopg.connect to the conn generator
     # Details of the params manipulation are in test_conninfo.
     import psycopg.connection
@@ -29,7 +32,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
     monkeypatch.setattr(psycopg.connection, "connect", mock_connect)
 
     conn = psycopg.connect(*args, **kwargs)
-    assert got_conninfo == want_conninfo
+    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index 82a5d730c7cc21b2e0f205ff588bb153e3337ad3..bc8b1cc95bb244b319c50817557172f0512bcce8 100644 (file)
@@ -7,6 +7,7 @@ from psycopg.conninfo import conninfo_to_dict
 
 from . import dbapi20
 from . import dbapi20_tpc
+from .test_connection import drop_default_args_from_conninfo
 
 
 @pytest.fixture(scope="class")
@@ -125,25 +126,25 @@ def test_time_from_ticks(ticks, want):
         (("host=foo user=bar",), {}, "host=foo user=bar"),
         (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
         (
-            ("host=foo port=5432",),
+            ("host=foo port=5433",),
             {"host": "qux", "user": "joe"},
-            "host=qux user=joe port=5432",
+            "host=qux user=joe port=5433",
         ),
         (("host=foo",), {"user": None}, "host=foo"),
     ],
 )
 def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
     conn = psycopg.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
     conn.close()