From e8daf4f1b9aa8507be9f70a7c7b3f73bb13e7a8e Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 26 Oct 2023 23:16:01 +0200 Subject: [PATCH] feat: explicitly iterate on multiple hosts on connections The libpq async connection path doesn't iterate on the attempts, so we need to do it ourselves. --- docs/news.rst | 1 + psycopg/psycopg/connection.py | 34 +++++++++++++++------- psycopg/psycopg/connection_async.py | 45 +++++++++++++---------------- psycopg/psycopg/conninfo.py | 43 +++++++++++++++++++++++++++ tests/fix_proxy.py | 2 +- tests/test_connection.py | 42 +++++++++++++++++++-------- tests/test_connection_async.py | 19 ++++++------ tests/test_dns.py | 23 +++++++++++++-- tests/test_module.py | 9 ++++-- tests/test_psycopg_dbapi20.py | 13 +++++---- 10 files changed, 163 insertions(+), 68 deletions(-) diff --git a/docs/news.rst b/docs/news.rst index 7ce52eff1..cc456e5b1 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -18,6 +18,7 @@ Psycopg 3.1.13 (unreleased) `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`). - Handle gracefully EINTR on signals instead of raising `InterruptedError`, consistently with :pep:`475` guideline (:ticket:`#667`). +- Fix support for connection strings with multiple hosts (:ticket:`#674`). 
Current release diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 5f3437321..e90e30adb 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -7,7 +7,7 @@ psycopg connection objects import logging import threading from types import TracebackType -from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator +from typing import Any, Callable, cast, Generator, Generic, Iterator from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union from typing import overload, TYPE_CHECKING from weakref import ref, ReferenceType @@ -31,6 +31,7 @@ from .cursor import Cursor from ._compat import LiteralString from .pq.misc import connection_summary from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from .conninfo import conninfo_attempts, ConnDict from ._pipeline import BasePipeline, Pipeline from .generators import notifies, connect, execute from ._encodings import pgconn_encoding @@ -106,6 +107,11 @@ class BaseConnection(Generic[Row]): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus + # Default timeout for a connection attempt. + # Arbitrary timeout, the one applied by the libpq on my computer. + # Your mileage won't vary. + _DEFAULT_CONNECT_TIMEOUT = 130 + def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn self._autocommit = False @@ -724,14 +730,19 @@ class Connection(BaseConnection[Row]): Connect to a database server and return a new `Connection` instance. 
""" params = cls._get_connection_params(conninfo, **kwargs) - conninfo = make_conninfo(**params) + timeout = int(params["connect_timeout"]) + rv = None + for attempt in conninfo_attempts(params): + try: + conninfo = make_conninfo(**attempt) + rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) + break + except e._NO_TRACEBACK as ex: + last_ex = ex - try: - rv = cls._wait_conn( - cls._connect_gen(conninfo), timeout=params["connect_timeout"] - ) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) + if not rv: + assert last_ex + raise last_ex.with_traceback(None) rv._autocommit = bool(autocommit) if row_factory: @@ -774,7 +785,7 @@ class Connection(BaseConnection[Row]): self.close() @classmethod - def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]: + def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict: """Manipulate connection parameters before connecting. :param conninfo: Connection string as received by `~Connection.connect()`. @@ -788,7 +799,10 @@ class Connection(BaseConnection[Row]): if "connect_timeout" in params: params["connect_timeout"] = int(params["connect_timeout"]) else: - params["connect_timeout"] = None + # The sync connect function will stop on the default socket timeout + # Because in async connection mode we need to enforce the timeout + # ourselves, we need a finite value. 
+ params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT return params diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 5ab0522b0..6766a2245 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -8,7 +8,7 @@ import sys import asyncio import logging from types import TracebackType -from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional +from typing import Any, AsyncGenerator, AsyncIterator, List, Optional from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING from contextlib import asynccontextmanager @@ -20,7 +20,7 @@ from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async +from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify @@ -118,14 +118,19 @@ class AsyncConnection(BaseConnection[Row]): ) params = await cls._get_connection_params(conninfo, **kwargs) - conninfo = make_conninfo(**params) + timeout = int(params["connect_timeout"]) + rv = None + async for attempt in conninfo_attempts_async(params): + try: + conninfo = make_conninfo(**attempt) + rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) + break + except e._NO_TRACEBACK as ex: + last_ex = ex - try: - rv = await cls._wait_conn( - cls._connect_gen(conninfo), timeout=params["connect_timeout"] - ) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) + if not rv: + assert last_ex + raise last_ex.with_traceback(None) rv._autocommit = bool(autocommit) if row_factory: @@ -168,28 +173,18 @@ class AsyncConnection(BaseConnection[Row]): await self.close() @classmethod - async def _get_connection_params( - cls, conninfo: 
str, **kwargs: Any - ) -> Dict[str, Any]: - """Manipulate connection parameters before connecting. - - .. versionchanged:: 3.1 - Unlike the sync counterpart, perform non-blocking address - resolution and populate the ``hostaddr`` connection parameter, - unless the user has provided one themselves. See - `~psycopg._dns.resolve_hostaddr_async()` for details. - - """ + async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict: + """Manipulate connection parameters before connecting.""" params = conninfo_to_dict(conninfo, **kwargs) # Make sure there is an usable connect_timeout if "connect_timeout" in params: params["connect_timeout"] = int(params["connect_timeout"]) else: - params["connect_timeout"] = None - - # Resolve host addresses in non-blocking way - params = await resolve_hostaddr_async(params) + # The sync connect function will stop on the default socket timeout + # Because in async connection mode we need to enforce the timeout + # ourselves, we need a finite value. + params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT return params diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index efbb2be31..4f633fff4 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -330,6 +330,49 @@ async def resolve_hostaddr_async(params: ConnDict) -> ConnDict: return out +def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: + """Split a set of connection params on the single attempts to perform. + + A connection param can perform more than one attempt if more than one ``host`` + is provided. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + for attempt in _split_attempts(_inject_defaults(params)): + yield attempt + + +async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: + """Split a set of connection params on the single attempts to perform. 
+ + A connection param can perform more than one attempt if more than one ``host`` + is provided. + + Also perform async resolution of the hostname into hostaddr in order to + avoid blocking. Because a host can resolve to more than one address, this + can lead to yielding more attempts too. Raise `OperationalError` if no host + could be resolved. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + yielded = False + last_exc = None + for attempt in _split_attempts(_inject_defaults(params)): + try: + async for a2 in _split_attempts_and_resolve(attempt): + yielded = True + yield a2 + except OSError as ex: + last_exc = ex + + if not yielded: + assert last_exc + # We couldn't resolve anything + raise e.OperationalError(str(last_exc)) + + def _inject_defaults(params: ConnDict) -> ConnDict: """ Add defaults to a dictionary of parameters. diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py index e50f5ec05..1d566b5e5 100644 --- a/tests/fix_proxy.py +++ b/tests/fix_proxy.py @@ -60,7 +60,7 @@ class Proxy: # Get server params host = cdict.get("host") or os.environ.get("PGHOST") self.server_host = host if host and not host.startswith("/") else "localhost" - self.server_port = cdict.get("port", "5432") + self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432") # Get client params self.client_host = "localhost" diff --git a/tests/test_connection.py b/tests/test_connection.py index 7314f6f31..ddfff5311 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -15,6 +15,8 @@ from .utils import gc_collect from .test_cursor import my_row_factory from .test_adapt import make_bin_dumper, make_dumper +DEFAULT_TIMEOUT = psycopg.Connection._DEFAULT_CONNECT_TIMEOUT + def test_connect(conn_cls, dsn): conn = conn_cls.connect(dsn) @@ -359,25 +361,26 @@ def test_autocommit_unknown(conn): (("host=foo user=bar",), {}, "host=foo user=bar"), (("host=foo",), {"user": "baz"}, "host=foo user=baz"), ( - 
("host=foo port=5432",), - {"host": "qux", "user": "joe"}, - "host=qux user=joe port=5432", + ("dbname=foo port=5433",), + {"dbname": "qux", "user": "joe"}, + "dbname=qux user=joe port=5433", ), (("host=foo",), {"user": None}, "host=foo"), ], ) -def test_connect_args(conn_cls, monkeypatch, pgconn, args, kwargs, want): - the_conninfo: str +def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want): + got_conninfo: str def fake_connect(conninfo): - nonlocal the_conninfo - the_conninfo = conninfo + nonlocal got_conninfo + got_conninfo = conninfo return pgconn yield monkeypatch.setattr(psycopg.connection, "connect", fake_connect) conn = conn_cls.connect(*args, **kwargs) - assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + got_params = drop_default_args_from_conninfo(got_conninfo) + assert got_params == conninfo_to_dict(want) conn.close() @@ -761,7 +764,7 @@ conninfo_params_timeout = [ ( "", {"dbname": "mydb", "connect_timeout": None}, - ({"dbname": "mydb"}, None), + ({"dbname": "mydb"}, DEFAULT_TIMEOUT), ), ( "", @@ -771,7 +774,7 @@ conninfo_params_timeout = [ ( "dbname=postgres", {}, - ({"dbname": "postgres"}, None), + ({"dbname": "postgres"}, DEFAULT_TIMEOUT), ), ( "dbname=postgres connect_timeout=2", @@ -790,8 +793,8 @@ conninfo_params_timeout = [ def test_get_connection_params(conn_cls, dsn, kwargs, exp): params = conn_cls._get_connection_params(dsn, **kwargs) conninfo = make_conninfo(**params) - assert conninfo_to_dict(conninfo) == exp[0] - assert params.get("connect_timeout") == exp[1] + assert drop_default_args_from_conninfo(conninfo) == exp[0] + assert params["connect_timeout"] == exp[1] def test_connect_context(conn_cls, dsn): @@ -824,3 +827,18 @@ def test_connect_context_copy(conn_cls, dsn, conn): def test_cancel_closed(conn): conn.close() conn.cancel() + + +def drop_default_args_from_conninfo(conninfo): + params = conninfo_to_dict(conninfo) + + def removeif(key, value): + if params.get(key) == value: + params.pop(key) + 
+ removeif("host", "") + removeif("hostaddr", "") + removeif("port", "5432") + removeif("connect_timeout", str(DEFAULT_TIMEOUT)) + + return params diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 61277872f..87d8a4ee6 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -12,7 +12,7 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo from .utils import gc_collect from .test_cursor import my_row_factory from .test_connection import tx_params, tx_params_isolation, tx_values_map -from .test_connection import conninfo_params_timeout +from .test_connection import conninfo_params_timeout, drop_default_args_from_conninfo from .test_connection import testctx # noqa: F401 # fixture from .test_adapt import make_bin_dumper, make_dumper from .test_conninfo import fake_resolve # noqa: F401 @@ -357,9 +357,9 @@ async def test_autocommit_unknown(aconn): (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), ( - ("dbname=foo port=5432",), + ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, - "dbname=qux user=joe port=5432", + "dbname=qux user=joe port=5433", ), (("dbname=foo",), {"user": None}, "dbname=foo"), ], @@ -367,18 +367,19 @@ async def test_autocommit_unknown(aconn): async def test_connect_args( aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want ): - the_conninfo: str + got_conninfo: str def fake_connect(conninfo): - nonlocal the_conninfo - the_conninfo = conninfo + nonlocal got_conninfo + got_conninfo = conninfo return pgconn yield setpgenv({}) monkeypatch.setattr(psycopg.connection, "connect", fake_connect) conn = await aconn_cls.connect(*args, **kwargs) - assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + got_params = drop_default_args_from_conninfo(got_conninfo) + assert got_params == conninfo_to_dict(want) await conn.close() @@ -724,7 +725,7 @@ async def test_get_connection_params(aconn_cls, dsn, 
kwargs, exp, setpgenv): setpgenv({}) params = await aconn_cls._get_connection_params(dsn, **kwargs) conninfo = make_conninfo(**params) - assert conninfo_to_dict(conninfo) == exp[0] + assert drop_default_args_from_conninfo(conninfo) == exp[0] assert params["connect_timeout"] == exp[1] @@ -774,4 +775,4 @@ async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811 assert len(got) == 1 want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert conninfo_to_dict(got[0]) == want + assert drop_default_args_from_conninfo(got[0]) == want diff --git a/tests/test_dns.py b/tests/test_dns.py index 2eb5569df..efbf6f503 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -3,9 +3,29 @@ import pytest import psycopg from psycopg.conninfo import conninfo_to_dict -pytestmark = [pytest.mark.dns] +from .test_conninfo import fake_resolve # noqa: F401 # fixture +from .test_connection import drop_default_args_from_conninfo +@pytest.mark.usefixtures("fake_resolve") +async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch): + got = [] + + def fake_connect_gen(conninfo, **kwargs): + got.append(conninfo) + 1 / 0 + + monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + await aconn_cls.connect("host=foo.com") + + assert len(got) == 1 + want = {"host": "foo.com", "hostaddr": "1.1.1.1"} + assert drop_default_args_from_conninfo(got[0]) == want + + +@pytest.mark.dns @pytest.mark.anyio async def test_resolve_hostaddr_async_warning(recwarn): import_dnspython() @@ -14,7 +34,6 @@ async def test_resolve_hostaddr_async_warning(recwarn): params = await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined] params ) - assert conninfo_to_dict(conninfo) == params assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message) diff --git a/tests/test_module.py b/tests/test_module.py index 794ef0f89..030b75808 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,10 +1,13 @@ import pytest 
from psycopg._cmodule import _psycopg +from psycopg.conninfo import conninfo_to_dict + +from .test_connection import drop_default_args_from_conninfo @pytest.mark.parametrize( - "args, kwargs, want_conninfo", + "args, kwargs, want", [ ((), {}, ""), (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"), @@ -12,7 +15,7 @@ from psycopg._cmodule import _psycopg ((), {"user": "foo", "dbname": None}, "user=foo"), ], ) -def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): +def test_connect(monkeypatch, dsn, args, kwargs, want): # Check the main args passing from psycopg.connect to the conn generator # Details of the params manipulation are in test_conninfo. import psycopg.connection @@ -29,7 +32,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): monkeypatch.setattr(psycopg.connection, "connect", mock_connect) conn = psycopg.connect(*args, **kwargs) - assert got_conninfo == want_conninfo + assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want) conn.close() diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 82a5d730c..bc8b1cc95 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -7,6 +7,7 @@ from psycopg.conninfo import conninfo_to_dict from . import dbapi20 from . 
import dbapi20_tpc +from .test_connection import drop_default_args_from_conninfo @pytest.fixture(scope="class") @@ -125,25 +126,25 @@ def test_time_from_ticks(ticks, want): (("host=foo user=bar",), {}, "host=foo user=bar"), (("host=foo",), {"user": "baz"}, "host=foo user=baz"), ( - ("host=foo port=5432",), + ("host=foo port=5433",), {"host": "qux", "user": "joe"}, - "host=qux user=joe port=5432", + "host=qux user=joe port=5433", ), (("host=foo",), {"user": None}, "host=foo"), ], ) def test_connect_args(monkeypatch, pgconn, args, kwargs, want): - the_conninfo: str + got_conninfo: str def fake_connect(conninfo): - nonlocal the_conninfo - the_conninfo = conninfo + nonlocal got_conninfo + got_conninfo = conninfo return pgconn yield monkeypatch.setattr(psycopg.connection, "connect", fake_connect) conn = psycopg.connect(*args, **kwargs) - assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want) conn.close() -- 2.47.2