git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: explicitly iterate on multiple hosts on connections
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Oct 2023 21:16:01 +0000 (23:16 +0200)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 19:54:13 +0000 (19:54 +0000)
The libpq async connection path doesn't iterate on the attempts, so we
need to do it ourselves.

14 files changed:
docs/api/dns.rst
docs/news.rst
psycopg/psycopg/_connection_base.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/_test_connection.py
tests/fix_proxy.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_dns.py
tests/test_module.py
tests/test_psycopg_dbapi20.py
tools/async_to_sync.py

index e80f4d5943698470080194a4a87eff73466e0cd7..b109c2716a8c27421edace9eb5806120864b0830 100644 (file)
@@ -92,12 +92,6 @@ server before performing a connection.
     .. warning::
         This is an experimental method.
 
-    .. versionchanged:: 3.1
-        Unlike the sync counterpart, perform non-blocking address
-        resolution and populate the ``hostaddr`` connection parameter,
-        unless the user has provided one themselves. See
-        `resolve_hostaddr_async()` for details.
-
 
 .. function:: resolve_hostaddr_async(params)
     :async:
index b4e4fd97c76e5dce604dc56bd9ec31f828d05975..ded0549065206dc165e9528e264849c9da9d53cf 100644 (file)
@@ -38,6 +38,7 @@ Psycopg 3.1.13 (unreleased)
   `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`).
 - Handle gracefully EINTR on signals instead of raising `InterruptedError`,
   consistently with :pep:`475` guideline (:ticket:`#667`).
+- Fix support for connection strings with multiple hosts (:ticket:`#674`).
 
 
 Current release
index cf391998bd39ae16bb0900a67e65c7c9f2f5d869..19e3d1b848f3299cd5fdf72861ec442465064928 100644 (file)
@@ -97,6 +97,11 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
+    # Default timeout for a connection attempt.
+    # Arbitrary value: the one applied by the libpq on my computer.
+    # Your mileage won't vary.
+    _DEFAULT_CONNECT_TIMEOUT = 130
+
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn
         self._autocommit = False
index 7d2418d8163b7f33cf1507a5e1e4acc759f1901c..c3bd099c2554d346aa1d00bc7bbcad84010059ef 100644 (file)
@@ -11,7 +11,7 @@ from __future__ import annotations
 
 import logging
 from types import TracebackType
-from typing import Any, Generator, Iterator, Dict, List, Optional
+from typing import Any, Generator, Iterator, List, Optional
 from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
 from contextlib import contextmanager
 
@@ -23,7 +23,7 @@ from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -119,14 +119,19 @@ class Connection(BaseConnection[Row]):
         """
 
         params = cls._get_connection_params(conninfo, **kwargs)
-        conninfo = make_conninfo(**params)
+        timeout = int(params["connect_timeout"])
+        rv = None
+        for attempt in conninfo_attempts(params):
+            try:
+                conninfo = make_conninfo(**attempt)
+                rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+                break
+            except e._NO_TRACEBACK as ex:
+                last_ex = ex
 
-        try:
-            rv = cls._wait_conn(
-                cls._connect_gen(conninfo), timeout=params["connect_timeout"]
-            )
-        except e._NO_TRACEBACK as ex:
-            raise ex.with_traceback(None)
+        if not rv:
+            assert last_ex
+            raise last_ex.with_traceback(None)
 
         rv._autocommit = bool(autocommit)
         if row_factory:
@@ -165,7 +170,7 @@ class Connection(BaseConnection[Row]):
             self.close()
 
     @classmethod
-    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
+    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
         """Manipulate connection parameters before connecting.
 
         :param conninfo: Connection string as received by `~Connection.connect()`.
@@ -179,7 +184,10 @@ class Connection(BaseConnection[Row]):
         if "connect_timeout" in params:
             params["connect_timeout"] = int(params["connect_timeout"])
         else:
-            params["connect_timeout"] = None
+            # The sync connect function will stop on the default socket timeout.
+            # Because in async connection mode we need to enforce the timeout
+            # ourselves, we need a finite value.
+            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
 
         return params
 
index 868c22ee75cd154113e72ef336f25d65fca97110..be7fca262cf4d35818b3f89f2dc185c9eb49c8ab 100644 (file)
@@ -8,7 +8,7 @@ from __future__ import annotations
 
 import logging
 from types import TracebackType
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
 from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
 from contextlib import asynccontextmanager
 
@@ -20,7 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -33,7 +33,6 @@ if True:  # ASYNC
     import sys
     import asyncio
     from asyncio import Lock
-    from .conninfo import resolve_hostaddr_async
 else:
     from threading import Lock
 
@@ -135,14 +134,19 @@ class AsyncConnection(BaseConnection[Row]):
                     )
 
         params = await cls._get_connection_params(conninfo, **kwargs)
-        conninfo = make_conninfo(**params)
+        timeout = int(params["connect_timeout"])
+        rv = None
+        async for attempt in conninfo_attempts_async(params):
+            try:
+                conninfo = make_conninfo(**attempt)
+                rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+                break
+            except e._NO_TRACEBACK as ex:
+                last_ex = ex
 
-        try:
-            rv = await cls._wait_conn(
-                cls._connect_gen(conninfo), timeout=params["connect_timeout"]
-            )
-        except e._NO_TRACEBACK as ex:
-            raise ex.with_traceback(None)
+        if not rv:
+            assert last_ex
+            raise last_ex.with_traceback(None)
 
         rv._autocommit = bool(autocommit)
         if row_factory:
@@ -181,9 +185,7 @@ class AsyncConnection(BaseConnection[Row]):
             await self.close()
 
     @classmethod
-    async def _get_connection_params(
-        cls, conninfo: str, **kwargs: Any
-    ) -> Dict[str, Any]:
+    async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
         """Manipulate connection parameters before connecting.
 
         :param conninfo: Connection string as received by `~Connection.connect()`.
@@ -197,11 +199,10 @@ class AsyncConnection(BaseConnection[Row]):
         if "connect_timeout" in params:
             params["connect_timeout"] = int(params["connect_timeout"])
         else:
-            params["connect_timeout"] = None
-
-        if True:  # ASYNC
-            # Resolve host addresses in non-blocking way
-            params = await resolve_hostaddr_async(params)
+            # The sync connect function will stop on the default socket timeout.
+            # Because in async connection mode we need to enforce the timeout
+            # ourselves, we need a finite value.
+            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
 
         return params
 
index efbb2be315e99a59ebbafb965192d92a19876bd2..4f633fff4a09c9c5bb72758af0c5d0ae7c637344 100644 (file)
@@ -330,6 +330,49 @@ async def resolve_hostaddr_async(params: ConnDict) -> ConnDict:
     return out
 
 
+def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
+    """Split a set of connection params on the single attempts to perform.
+
+    A connection param can perform more than one attempt if more than one
+    ``host`` is provided.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    for attempt in _split_attempts(_inject_defaults(params)):
+        yield attempt
+
+
+async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
+    """Split a set of connection params on the single attempts to perform.
+
+    A connection param can perform more than one attempt if more than one
+    ``host`` is provided.
+
+    Also perform async resolution of the hostname into hostaddr in order to
+    avoid blocking. Because a host can resolve to more than one address, this
+    can lead to yielding more attempts too. Raise `OperationalError` if no host
+    could be resolved.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    yielded = False
+    last_exc = None
+    for attempt in _split_attempts(_inject_defaults(params)):
+        try:
+            async for a2 in _split_attempts_and_resolve(attempt):
+                yielded = True
+                yield a2
+        except OSError as ex:
+            last_exc = ex
+
+    if not yielded:
+        assert last_exc
+        # We couldn't resolve anything
+        raise e.OperationalError(str(last_exc))
+
+
 def _inject_defaults(params: ConnDict) -> ConnDict:
     """
     Add defaults to a dictionary of parameters.
index 296a7f7f4b95a58246d30fca4da4b8610e68c09f..16dd8d011a99e73f9683f345962611e2611529b3 100644 (file)
@@ -7,6 +7,10 @@ from dataclasses import dataclass
 
 import pytest
 import psycopg
+from psycopg.conninfo import conninfo_to_dict
+from psycopg._connection_base import BaseConnection
+
+DEFAULT_TIMEOUT = BaseConnection._DEFAULT_CONNECT_TIMEOUT
 
 
 @pytest.fixture
@@ -75,7 +79,7 @@ conninfo_params_timeout = [
     (
         "",
         {"dbname": "mydb", "connect_timeout": None},
-        ({"dbname": "mydb"}, None),
+        ({"dbname": "mydb"}, DEFAULT_TIMEOUT),
     ),
     (
         "",
@@ -85,7 +89,7 @@ conninfo_params_timeout = [
     (
         "dbname=postgres",
         {},
-        ({"dbname": "postgres"}, None),
+        ({"dbname": "postgres"}, DEFAULT_TIMEOUT),
     ),
     (
         "dbname=postgres connect_timeout=2",
@@ -98,3 +102,18 @@ conninfo_params_timeout = [
         ({"dbname": "postgres", "connect_timeout": "10"}, 10),
     ),
 ]
+
+
+def drop_default_args_from_conninfo(conninfo):
+    params = conninfo_to_dict(conninfo)
+
+    def removeif(key, value):
+        if params.get(key) == value:
+            params.pop(key)
+
+    removeif("host", "")
+    removeif("hostaddr", "")
+    removeif("port", "5432")
+    removeif("connect_timeout", str(DEFAULT_TIMEOUT))
+
+    return params
index e50f5ec05f28b460cc7c9c349f055db763a64ad3..1d566b5e5dc178aefc104415e0e8932b90f27779 100644 (file)
@@ -60,7 +60,7 @@ class Proxy:
         # Get server params
         host = cdict.get("host") or os.environ.get("PGHOST")
         self.server_host = host if host and not host.startswith("/") else "localhost"
-        self.server_port = cdict.get("port", "5432")
+        self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
 
         # Get client params
         self.client_host = "localhost"
index 1ead1ba9de39c6032ad5e5f1dca911aace36fc06..e381d692a6247db9d618cf85edcd19d9a29d955a 100644 (file)
@@ -17,7 +17,7 @@ from .utils import gc_collect
 from .acompat import is_async, skip_sync, skip_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
-from ._test_connection import conninfo_params_timeout
+from ._test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
 from ._test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 
@@ -399,26 +399,27 @@ def test_autocommit_unknown(conn):
         (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
         (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
         (
-            ("dbname=foo port=5432",),
+            ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
-            "dbname=qux user=joe port=5432",
+            "dbname=qux user=joe port=5433",
         ),
         (("dbname=foo",), {"user": None}, "dbname=foo"),
     ],
 )
 def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = conn_cls.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    got_params = drop_default_args_from_conninfo(got_conninfo)
+    assert got_params == conninfo_to_dict(want)
     conn.close()
 
 
@@ -790,7 +791,7 @@ def test_get_connection_params(conn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = conn_cls._get_connection_params(dsn, **kwargs)
     conninfo = make_conninfo(**params)
-    assert conninfo_to_dict(conninfo) == exp[0]
+    assert drop_default_args_from_conninfo(conninfo) == exp[0]
     assert params["connect_timeout"] == exp[1]
 
 
index 36a6c9f9feb354985f9f0857bebf3a7ca6355389..77a7b62fd2bb0c3bdf848c24d76d66633b773ce6 100644 (file)
@@ -14,7 +14,7 @@ from .utils import gc_collect
 from .acompat import is_async, skip_sync, skip_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
-from ._test_connection import conninfo_params_timeout
+from ._test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
 from ._test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 
@@ -397,9 +397,9 @@ async def test_autocommit_unknown(aconn):
         (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
         (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
         (
-            ("dbname=foo port=5432",),
+            ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
-            "dbname=qux user=joe port=5432",
+            "dbname=qux user=joe port=5433",
         ),
         (("dbname=foo",), {"user": None}, "dbname=foo"),
     ],
@@ -407,18 +407,19 @@ async def test_autocommit_unknown(aconn):
 async def test_connect_args(
     aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
 ):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = await aconn_cls.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    got_params = drop_default_args_from_conninfo(got_conninfo)
+    assert got_params == conninfo_to_dict(want)
     await conn.close()
 
 
@@ -798,7 +799,7 @@ async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = await aconn_cls._get_connection_params(dsn, **kwargs)
     conninfo = make_conninfo(**params)
-    assert conninfo_to_dict(conninfo) == exp[0]
+    assert drop_default_args_from_conninfo(conninfo) == exp[0]
     assert params["connect_timeout"] == exp[1]
 
 
index b1e8891155f523f3730d1078d2808c3e6a495544..a118786e527be88fa4cae4eba2b20513e82ccecd 100644 (file)
@@ -4,6 +4,7 @@ import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
 from .test_conninfo import fake_resolve  # noqa: F401  # fixture
+from ._test_connection import drop_default_args_from_conninfo
 
 
 @pytest.mark.usefixtures("fake_resolve")
@@ -21,7 +22,7 @@ async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert conninfo_to_dict(got[0]) == want
+    assert drop_default_args_from_conninfo(got[0]) == want
 
 
 @pytest.mark.dns
@@ -33,7 +34,6 @@ async def test_resolve_hostaddr_async_warning(recwarn):
     params = await psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
         params
     )
-    assert conninfo_to_dict(conninfo) == params
     assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
 
 
index c6b3e08e312cfd16950166af255fb62943f087db..7f667d08a50d2345e4717aee7af5519dd908399d 100644 (file)
@@ -1,10 +1,13 @@
 import pytest
 
 from psycopg._cmodule import _psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from ._test_connection import drop_default_args_from_conninfo
 
 
 @pytest.mark.parametrize(
-    "args, kwargs, want_conninfo",
+    "args, kwargs, want",
     [
         ((), {}, ""),
         (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"),
@@ -12,7 +15,7 @@ from psycopg._cmodule import _psycopg
         ((), {"user": "foo", "dbname": None}, "user=foo"),
     ],
 )
-def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
+def test_connect(monkeypatch, dsn, args, kwargs, want):
     # Check the main args passing from psycopg.connect to the conn generator
     # Details of the params manipulation are in test_conninfo.
     import psycopg.connection
@@ -29,7 +32,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
     monkeypatch.setattr(psycopg.generators, "connect", mock_connect)
 
     conn = psycopg.connect(*args, **kwargs)
-    assert got_conninfo == want_conninfo
+    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index 69d4e8d8aa1004d0c50358fe905fd16895e6e027..8e51466d682580f805f7af6324731382934b9724 100644 (file)
@@ -7,6 +7,7 @@ from psycopg.conninfo import conninfo_to_dict
 
 from . import dbapi20
 from . import dbapi20_tpc
+from ._test_connection import drop_default_args_from_conninfo
 
 
 @pytest.fixture(scope="class")
@@ -125,25 +126,25 @@ def test_time_from_ticks(ticks, want):
         (("host=foo user=bar",), {}, "host=foo user=bar"),
         (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
         (
-            ("host=foo port=5432",),
+            ("host=foo port=5433",),
             {"host": "qux", "user": "joe"},
-            "host=qux user=joe port=5432",
+            "host=qux user=joe port=5433",
         ),
         (("host=foo",), {"user": None}, "host=foo"),
     ],
 )
 def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = psycopg.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index 741c21861f216ca0935d9ba9d3af58390c3309d4..934df53a4588acd23fb2ce3ba62a7fec4c303027 100755 (executable)
@@ -308,6 +308,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "aspawn": "spawn",
         "asynccontextmanager": "contextmanager",
         "connection_async": "connection",
+        "conninfo_attempts_async": "conninfo_attempts",
         "current_task_name": "current_thread_name",
         "cursor_async": "cursor",
         "ensure_table_async": "ensure_table",