From: Daniele Varrazzo Date: Thu, 26 Oct 2023 21:16:01 +0000 (+0200) Subject: feat: explicitly iterate on multiple hosts on connections X-Git-Tag: 3.2.0~133^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c1f7a4aada68c646d94581584dbc8bb2b07a7ddc;p=thirdparty%2Fpsycopg.git feat: explicitly iterate on multiple hosts on connections The libpq async connection path doesn't iterate on the attempts, so we need to do it ourselves. --- diff --git a/docs/api/dns.rst b/docs/api/dns.rst index e80f4d594..b109c2716 100644 --- a/docs/api/dns.rst +++ b/docs/api/dns.rst @@ -92,12 +92,6 @@ server before performing a connection. .. warning:: This is an experimental method. - .. versionchanged:: 3.1 - Unlike the sync counterpart, perform non-blocking address - resolution and populate the ``hostaddr`` connection parameter, - unless the user has provided one themselves. See - `resolve_hostaddr_async()` for details. - .. function:: resolve_hostaddr_async(params) :async: diff --git a/docs/news.rst b/docs/news.rst index b4e4fd97c..ded054906 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -38,6 +38,7 @@ Psycopg 3.1.13 (unreleased) `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`). - Handle gracefully EINTR on signals instead of raising `InterruptedError`, consistently with :pep:`475` guideline (:ticket:`#667`). +- Fix support for connection strings with multiple hosts (:ticket:`#674`). Current release diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py index cf391998b..19e3d1b84 100644 --- a/psycopg/psycopg/_connection_base.py +++ b/psycopg/psycopg/_connection_base.py @@ -97,6 +97,11 @@ class BaseConnection(Generic[Row]): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus + # Default timeout for a connection attempt. + # Arbitrary timeout: it is what the libpq applied on my computer. + # Your mileage won't vary. 
+ _DEFAULT_CONNECT_TIMEOUT = 130 + def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn self._autocommit = False diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 7d2418d81..c3bd099c2 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging from types import TracebackType -from typing import Any, Generator, Iterator, Dict, List, Optional +from typing import Any, Generator, Iterator, List, Optional from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING from contextlib import contextmanager @@ -23,7 +23,7 @@ from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .conninfo import make_conninfo, conninfo_to_dict +from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts from ._pipeline import Pipeline from ._encodings import pgconn_encoding from .generators import notifies @@ -119,14 +119,19 @@ class Connection(BaseConnection[Row]): """ params = cls._get_connection_params(conninfo, **kwargs) - conninfo = make_conninfo(**params) + timeout = int(params["connect_timeout"]) + rv = None + for attempt in conninfo_attempts(params): + try: + conninfo = make_conninfo(**attempt) + rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) + break + except e._NO_TRACEBACK as ex: + last_ex = ex - try: - rv = cls._wait_conn( - cls._connect_gen(conninfo), timeout=params["connect_timeout"] - ) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) + if not rv: + assert last_ex + raise last_ex.with_traceback(None) rv._autocommit = bool(autocommit) if row_factory: @@ -165,7 +170,7 @@ class Connection(BaseConnection[Row]): self.close() @classmethod - def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]: + def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict: 
"""Manipulate connection parameters before connecting. :param conninfo: Connection string as received by `~Connection.connect()`. @@ -179,7 +184,10 @@ class Connection(BaseConnection[Row]): if "connect_timeout" in params: params["connect_timeout"] = int(params["connect_timeout"]) else: - params["connect_timeout"] = None + # The sync connect function will stop on the default socket timeout + # Because in async connection mode we need to enforce the timeout + # ourselves, we need a finite value. + params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT return params diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 868c22ee7..be7fca262 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging from types import TracebackType -from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional +from typing import Any, AsyncGenerator, AsyncIterator, List, Optional from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING from contextlib import asynccontextmanager @@ -20,7 +20,7 @@ from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .conninfo import make_conninfo, conninfo_to_dict +from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding from .generators import notifies @@ -33,7 +33,6 @@ if True: # ASYNC import sys import asyncio from asyncio import Lock - from .conninfo import resolve_hostaddr_async else: from threading import Lock @@ -135,14 +134,19 @@ class AsyncConnection(BaseConnection[Row]): ) params = await cls._get_connection_params(conninfo, **kwargs) - conninfo = make_conninfo(**params) + timeout = int(params["connect_timeout"]) + rv = None + async for attempt in 
conninfo_attempts_async(params): + try: + conninfo = make_conninfo(**attempt) + rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) + break + except e._NO_TRACEBACK as ex: + last_ex = ex - try: - rv = await cls._wait_conn( - cls._connect_gen(conninfo), timeout=params["connect_timeout"] - ) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) + if not rv: + assert last_ex + raise last_ex.with_traceback(None) rv._autocommit = bool(autocommit) if row_factory: @@ -181,9 +185,7 @@ class AsyncConnection(BaseConnection[Row]): await self.close() @classmethod - async def _get_connection_params( - cls, conninfo: str, **kwargs: Any - ) -> Dict[str, Any]: + async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict: """Manipulate connection parameters before connecting. :param conninfo: Connection string as received by `~Connection.connect()`. @@ -197,11 +199,10 @@ class AsyncConnection(BaseConnection[Row]): if "connect_timeout" in params: params["connect_timeout"] = int(params["connect_timeout"]) else: - params["connect_timeout"] = None - - if True: # ASYNC - # Resolve host addresses in non-blocking way - params = await resolve_hostaddr_async(params) + # The sync connect function will stop on the default socket timeout + # Because in async connection mode we need to enforce the timeout + # ourselves, we need a finite value. + params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT return params diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index efbb2be31..4f633fff4 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -330,6 +330,49 @@ async def resolve_hostaddr_async(params: ConnDict) -> ConnDict: return out +def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: + """Split a set of connection params into the single attempts to perform. + + A connection param can perform more than one attempt if more than one ``host`` + is provided. 
+ + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + for attempt in _split_attempts(_inject_defaults(params)): + yield attempt + + +async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: + """Split a set of connection params into the single attempts to perform. + + A connection param can perform more than one attempt if more than one ``host`` + is provided. + + Also perform async resolution of the hostname into hostaddr in order to + avoid blocking. Because a host can resolve to more than one address, this + can lead to yielding more attempts too. Raise `OperationalError` if no host + could be resolved. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + yielded = False + last_exc = None + for attempt in _split_attempts(_inject_defaults(params)): + try: + async for a2 in _split_attempts_and_resolve(attempt): + yielded = True + yield a2 + except OSError as ex: + last_exc = ex + + if not yielded: + assert last_exc + # We couldn't resolve anything + raise e.OperationalError(str(last_exc)) + + def _inject_defaults(params: ConnDict) -> ConnDict: """ Add defaults to a dictionary of parameters. 
diff --git a/tests/_test_connection.py b/tests/_test_connection.py index 296a7f7f4..16dd8d011 100644 --- a/tests/_test_connection.py +++ b/tests/_test_connection.py @@ -7,6 +7,10 @@ from dataclasses import dataclass import pytest import psycopg +from psycopg.conninfo import conninfo_to_dict +from psycopg._connection_base import BaseConnection + +DEFAULT_TIMEOUT = BaseConnection._DEFAULT_CONNECT_TIMEOUT @pytest.fixture @@ -75,7 +79,7 @@ conninfo_params_timeout = [ ( "", {"dbname": "mydb", "connect_timeout": None}, - ({"dbname": "mydb"}, None), + ({"dbname": "mydb"}, DEFAULT_TIMEOUT), ), ( "", @@ -85,7 +89,7 @@ conninfo_params_timeout = [ ( "dbname=postgres", {}, - ({"dbname": "postgres"}, None), + ({"dbname": "postgres"}, DEFAULT_TIMEOUT), ), ( "dbname=postgres connect_timeout=2", @@ -98,3 +102,18 @@ conninfo_params_timeout = [ ({"dbname": "postgres", "connect_timeout": "10"}, 10), ), ] + + +def drop_default_args_from_conninfo(conninfo): + params = conninfo_to_dict(conninfo) + + def removeif(key, value): + if params.get(key) == value: + params.pop(key) + + removeif("host", "") + removeif("hostaddr", "") + removeif("port", "5432") + removeif("connect_timeout", str(DEFAULT_TIMEOUT)) + + return params diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py index e50f5ec05..1d566b5e5 100644 --- a/tests/fix_proxy.py +++ b/tests/fix_proxy.py @@ -60,7 +60,7 @@ class Proxy: # Get server params host = cdict.get("host") or os.environ.get("PGHOST") self.server_host = host if host and not host.startswith("/") else "localhost" - self.server_port = cdict.get("port", "5432") + self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432") # Get client params self.client_host = "localhost" diff --git a/tests/test_connection.py b/tests/test_connection.py index 1ead1ba9d..e381d692a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -17,7 +17,7 @@ from .utils import gc_collect from .acompat import is_async, skip_sync, skip_async from ._test_cursor import 
my_row_factory from ._test_connection import tx_params, tx_params_isolation, tx_values_map -from ._test_connection import conninfo_params_timeout +from ._test_connection import conninfo_params_timeout, drop_default_args_from_conninfo from ._test_connection import testctx # noqa: F401 # fixture from .test_adapt import make_bin_dumper, make_dumper @@ -399,26 +399,27 @@ def test_autocommit_unknown(conn): (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), ( - ("dbname=foo port=5432",), + ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, - "dbname=qux user=joe port=5432", + "dbname=qux user=joe port=5433", ), (("dbname=foo",), {"user": None}, "dbname=foo"), ], ) def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want): - the_conninfo: str + got_conninfo: str def fake_connect(conninfo): - nonlocal the_conninfo - the_conninfo = conninfo + nonlocal got_conninfo + got_conninfo = conninfo return pgconn yield setpgenv({}) monkeypatch.setattr(psycopg.generators, "connect", fake_connect) conn = conn_cls.connect(*args, **kwargs) - assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + got_params = drop_default_args_from_conninfo(got_conninfo) + assert got_params == conninfo_to_dict(want) conn.close() @@ -790,7 +791,7 @@ def test_get_connection_params(conn_cls, dsn, kwargs, exp, setpgenv): setpgenv({}) params = conn_cls._get_connection_params(dsn, **kwargs) conninfo = make_conninfo(**params) - assert conninfo_to_dict(conninfo) == exp[0] + assert drop_default_args_from_conninfo(conninfo) == exp[0] assert params["connect_timeout"] == exp[1] diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 36a6c9f9f..77a7b62fd 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -14,7 +14,7 @@ from .utils import gc_collect from .acompat import is_async, skip_sync, skip_async from ._test_cursor import my_row_factory from 
._test_connection import tx_params, tx_params_isolation, tx_values_map -from ._test_connection import conninfo_params_timeout +from ._test_connection import conninfo_params_timeout, drop_default_args_from_conninfo from ._test_connection import testctx # noqa: F401 # fixture from .test_adapt import make_bin_dumper, make_dumper @@ -397,9 +397,9 @@ async def test_autocommit_unknown(aconn): (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), ( - ("dbname=foo port=5432",), + ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, - "dbname=qux user=joe port=5432", + "dbname=qux user=joe port=5433", ), (("dbname=foo",), {"user": None}, "dbname=foo"), ], @@ -407,18 +407,19 @@ async def test_autocommit_unknown(aconn): async def test_connect_args( aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want ): - the_conninfo: str + got_conninfo: str def fake_connect(conninfo): - nonlocal the_conninfo - the_conninfo = conninfo + nonlocal got_conninfo + got_conninfo = conninfo return pgconn yield setpgenv({}) monkeypatch.setattr(psycopg.generators, "connect", fake_connect) conn = await aconn_cls.connect(*args, **kwargs) - assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + got_params = drop_default_args_from_conninfo(got_conninfo) + assert got_params == conninfo_to_dict(want) await conn.close() @@ -798,7 +799,7 @@ async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv): setpgenv({}) params = await aconn_cls._get_connection_params(dsn, **kwargs) conninfo = make_conninfo(**params) - assert conninfo_to_dict(conninfo) == exp[0] + assert drop_default_args_from_conninfo(conninfo) == exp[0] assert params["connect_timeout"] == exp[1] diff --git a/tests/test_dns.py b/tests/test_dns.py index b1e889115..a118786e5 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -4,6 +4,7 @@ import psycopg from psycopg.conninfo import conninfo_to_dict from .test_conninfo import fake_resolve # 
noqa: F401 # fixture +from ._test_connection import drop_default_args_from_conninfo @pytest.mark.usefixtures("fake_resolve") @@ -21,7 +22,7 @@ async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch): assert len(got) == 1 want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert conninfo_to_dict(got[0]) == want + assert drop_default_args_from_conninfo(got[0]) == want @pytest.mark.dns @@ -33,7 +34,6 @@ async def test_resolve_hostaddr_async_warning(recwarn): params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined] params ) - assert conninfo_to_dict(conninfo) == params assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message) diff --git a/tests/test_module.py b/tests/test_module.py index c6b3e08e3..7f667d08a 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,10 +1,13 @@ import pytest from psycopg._cmodule import _psycopg +from psycopg.conninfo import conninfo_to_dict + +from ._test_connection import drop_default_args_from_conninfo @pytest.mark.parametrize( - "args, kwargs, want_conninfo", + "args, kwargs, want", [ ((), {}, ""), (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"), @@ -12,7 +15,7 @@ from psycopg._cmodule import _psycopg ((), {"user": "foo", "dbname": None}, "user=foo"), ], ) -def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): +def test_connect(monkeypatch, dsn, args, kwargs, want): # Check the main args passing from psycopg.connect to the conn generator # Details of the params manipulation are in test_conninfo. 
import psycopg.connection @@ -29,7 +32,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): monkeypatch.setattr(psycopg.generators, "connect", mock_connect) conn = psycopg.connect(*args, **kwargs) - assert got_conninfo == want_conninfo + assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want) conn.close() diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 69d4e8d8a..8e51466d6 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -7,6 +7,7 @@ from psycopg.conninfo import conninfo_to_dict from . import dbapi20 from . import dbapi20_tpc +from ._test_connection import drop_default_args_from_conninfo @pytest.fixture(scope="class") @@ -125,25 +126,25 @@ def test_time_from_ticks(ticks, want): (("host=foo user=bar",), {}, "host=foo user=bar"), (("host=foo",), {"user": "baz"}, "host=foo user=baz"), ( - ("host=foo port=5432",), + ("host=foo port=5433",), {"host": "qux", "user": "joe"}, - "host=qux user=joe port=5432", + "host=qux user=joe port=5433", ), (("host=foo",), {"user": None}, "host=foo"), ], ) def test_connect_args(monkeypatch, pgconn, args, kwargs, want): - the_conninfo: str + got_conninfo: str def fake_connect(conninfo): - nonlocal the_conninfo - the_conninfo = conninfo + nonlocal got_conninfo + got_conninfo = conninfo return pgconn yield monkeypatch.setattr(psycopg.generators, "connect", fake_connect) conn = psycopg.connect(*args, **kwargs) - assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want) conn.close() diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 741c21861..934df53a4 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -308,6 +308,7 @@ class RenameAsyncToSync(ast.NodeTransformer): "aspawn": "spawn", "asynccontextmanager": "contextmanager", "connection_async": "connection", + "conninfo_attempts_async": "conninfo_attempts", 
"current_task_name": "current_thread_name", "cursor_async": "cursor", "ensure_table_async": "ensure_table",