From: Daniele Varrazzo Date: Wed, 13 Dec 2023 03:35:48 +0000 (+0100) Subject: fix: set minimum timeout to 2s X-Git-Tag: 3.1.15~1^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5d18a15772cbe9e74de36addc6bc46c9b4454274;p=thirdparty%2Fpsycopg.git fix: set minimum timeout to 2s This is consistent with what the libpq does. Move timeout calculation to a function in conninfo module and don't change the connect_timeout parameter explicitly in the connection string. Drop awful drop_default_args_from_conninfo() from tests. --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index e04a99403..a38a98fc3 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -31,7 +31,7 @@ from .cursor import Cursor from ._compat import LiteralString from .pq.misc import connection_summary from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo -from .conninfo import conninfo_attempts, ConnDict +from .conninfo import conninfo_attempts, ConnDict, timeout_from_conninfo from ._pipeline import BasePipeline, Pipeline from .generators import notifies, connect, execute from ._encodings import pgconn_encoding @@ -107,11 +107,6 @@ class BaseConnection(Generic[Row]): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - # Default timeout for connection a attempt. - # Arbitrary timeout, what applied by the libpq on my computer. - # Your mileage won't vary. - _DEFAULT_CONNECT_TIMEOUT = 130 - def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn self._autocommit = False @@ -730,7 +725,7 @@ class Connection(BaseConnection[Row]): Connect to a database server and return a new `Connection` instance. """ params = cls._get_connection_params(conninfo, **kwargs) - timeout = int(params["connect_timeout"]) + timeout = timeout_from_conninfo(params) rv = None attempts = conninfo_attempts(params) for attempt in attempts: @@ -803,18 +798,7 @@ class Connection(BaseConnection[Row]): :return: Connection arguments merged and eventually modified, in a format similar to `~conninfo.conninfo_to_dict()`. """ - params = conninfo_to_dict(conninfo, **kwargs) - - # Make sure there is an usable connect_timeout - if "connect_timeout" in params: - params["connect_timeout"] = int(params["connect_timeout"]) - else: - # The sync connect function will stop on the default socket timeout - # Because in async connection mode we need to enforce the timeout - # ourselves, we need a finite value. - params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT - - return params + return conninfo_to_dict(conninfo, **kwargs) def close(self) -> None: """Close the database connection.""" diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index bea077fe8..88544db99 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -20,7 +20,8 @@ from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async +from .conninfo import ConnDict, make_conninfo, conninfo_to_dict +from .conninfo import conninfo_attempts_async, timeout_from_conninfo from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify @@ -118,7 +119,7 @@ class AsyncConnection(BaseConnection[Row]): ) params = await cls._get_connection_params(conninfo, **kwargs) - timeout = int(params["connect_timeout"]) + timeout = timeout_from_conninfo(params) rv = None attempts = await conninfo_attempts_async(params) for attempt in attempts: @@ -185,18 +186,7 @@ class AsyncConnection(BaseConnection[Row]): @classmethod async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict: """Manipulate connection parameters before connecting.""" - params = conninfo_to_dict(conninfo, **kwargs) - - # Make sure there is an usable connect_timeout - if "connect_timeout" in params: - params["connect_timeout"] = int(params["connect_timeout"]) - else: - # The sync connect function will stop on the default socket timeout - # Because in async connection mode we need to enforce the timeout - # ourselves, we need a finite value. - params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT - - return params + return conninfo_to_dict(conninfo, **kwargs) async def close(self) -> None: if self.closed: diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index c38d7af7a..ec36ff09e 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -27,6 +27,11 @@ from ._encodings import pgconn_encoding ConnDict: TypeAlias = "dict[str, Any]" +# Default timeout for connection a attempt. +# Arbitrary timeout, what applied by the libpq on my computer. +# Your mileage won't vary. +_DEFAULT_CONNECT_TIMEOUT = 130 + logger = logging.getLogger("psycopg") @@ -430,6 +435,35 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: return [{**params, "hostaddr": item[4][0]} for item in ans] +def timeout_from_conninfo(params: ConnDict) -> int: + """ + Return the timeout in seconds from the connection parameters. + """ + # Follow the libpq convention: + # + # - 0 or less means no timeout (but we will use a default to simulate + # the socket timeout) + # - at least 2 seconds. + # + # See connectDBComplete in fe-connect.c + value = params.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) + try: + timeout = int(value) + except ValueError: + raise e.ProgrammingError(f"bad value for connect_timeout: {value!r}") + + if timeout <= 0: + # The sync connect function will stop on the default socket timeout + # Because in async connection mode we need to enforce the timeout + # ourselves, we need a finite value. + timeout = _DEFAULT_CONNECT_TIMEOUT + elif timeout < 2: + # Enforce a 2s min + timeout = 2 + + return timeout + + def _get_param(params: ConnDict, name: str) -> str | None: """ Return a value from a connection string. diff --git a/tests/test_connection.py b/tests/test_connection.py index 754acec3a..0ee8aea31 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,13 +9,12 @@ from dataclasses import dataclass import psycopg from psycopg import Notify, pq, errors as e from psycopg.rows import tuple_row -from psycopg.conninfo import conninfo_to_dict, make_conninfo +from psycopg.conninfo import conninfo_to_dict +from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT from .test_cursor import my_row_factory from .test_adapt import make_bin_dumper, make_dumper -DEFAULT_TIMEOUT = psycopg.Connection._DEFAULT_CONNECT_TIMEOUT - def test_connect(conn_cls, dsn): conn = conn_cls.connect(dsn) @@ -44,9 +43,9 @@ def test_connect_bad(conn_cls): def test_connect_timeout(conn_cls, deaf_port): t0 = time.time() with pytest.raises(psycopg.OperationalError, match="timeout expired"): - conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) + conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2) elapsed = time.time() - t0 - assert elapsed == pytest.approx(1.0, abs=0.05) + assert elapsed == pytest.approx(2.0, abs=0.05) @pytest.mark.slow @@ -56,7 +55,7 @@ def test_multi_hosts(conn_cls, proxy, dsn, deaf_port, monkeypatch): args["host"] = f"{proxy.client_host},{proxy.server_host}" args["port"] = f"{deaf_port},{proxy.server_port}" args.pop("hostaddr", None) - monkeypatch.setattr(conn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2) + monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2) t0 = time.time() with conn_cls.connect(**args) as conn: elapsed = time.time() - t0 @@ -72,11 +71,11 @@ def test_multi_hosts_timeout(conn_cls, proxy, dsn, deaf_port): args["host"] = f"{proxy.client_host},{proxy.server_host}" args["port"] = f"{deaf_port},{proxy.server_port}" args.pop("hostaddr", None) - args["connect_timeout"] = "1" + args["connect_timeout"] = "2" t0 = time.time() with conn_cls.connect(**args) as conn: elapsed = time.time() - t0 - assert 1.0 < elapsed < 1.5 + assert 2.0 < elapsed < 2.5 assert conn.info.port == int(proxy.server_port) assert conn.info.host == proxy.server_host @@ -416,8 +415,7 @@ def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, wan setpgenv({}) monkeypatch.setattr(psycopg.connection, "connect", fake_connect) conn = conn_cls.connect(*args, **kwargs) - got_params = drop_default_args_from_conninfo(got_conninfo) - assert got_params == conninfo_to_dict(want) + assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want) conn.close() @@ -801,17 +799,17 @@ conninfo_params_timeout = [ ( "", {"dbname": "mydb", "connect_timeout": None}, - ({"dbname": "mydb"}, DEFAULT_TIMEOUT), + ({"dbname": "mydb"}, _DEFAULT_CONNECT_TIMEOUT), ), ( "", {"dbname": "mydb", "connect_timeout": 1}, - ({"dbname": "mydb", "connect_timeout": "1"}, 1), + ({"dbname": "mydb", "connect_timeout": 1}, 2), ), ( "dbname=postgres", {}, - ({"dbname": "postgres"}, DEFAULT_TIMEOUT), + ({"dbname": "postgres"}, _DEFAULT_CONNECT_TIMEOUT), ), ( "dbname=postgres connect_timeout=2", @@ -821,7 +819,7 @@ conninfo_params_timeout = [ ( "postgresql:///postgres?connect_timeout=2", {"connect_timeout": 10}, - ({"dbname": "postgres", "connect_timeout": "10"}, 10), + ({"dbname": "postgres", "connect_timeout": 10}, 10), ), ] @@ -829,9 +827,8 @@ conninfo_params_timeout = [ @pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout) def test_get_connection_params(conn_cls, dsn, kwargs, exp): params = conn_cls._get_connection_params(dsn, **kwargs) - conninfo = make_conninfo(**params) - assert drop_default_args_from_conninfo(conninfo) == exp[0] - assert params["connect_timeout"] == exp[1] + assert params == exp[0] + assert timeout_from_conninfo(params) == exp[1] def test_connect_context(conn_cls, dsn): @@ -864,18 +861,3 @@ def test_connect_context_copy(conn_cls, dsn, conn): def test_cancel_closed(conn): conn.close() conn.cancel() - - -def drop_default_args_from_conninfo(conninfo): - if isinstance(conninfo, str): - params = conninfo_to_dict(conninfo) - else: - params = conninfo.copy() - - def removeif(key, value): - if params.get(key) == value: - params.pop(key) - - removeif("connect_timeout", str(DEFAULT_TIMEOUT)) - - return params diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index a35973c69..86ccfe10b 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -8,11 +8,11 @@ from typing import List, Any import psycopg from psycopg import Notify, errors as e from psycopg.rows import tuple_row -from psycopg.conninfo import conninfo_to_dict, make_conninfo +from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo from .test_cursor import my_row_factory from .test_connection import tx_params, tx_params_isolation, tx_values_map -from .test_connection import conninfo_params_timeout, drop_default_args_from_conninfo +from .test_connection import conninfo_params_timeout from .test_connection import testctx # noqa: F401 # fixture from .test_adapt import make_bin_dumper, make_dumper from .test_conninfo import fake_resolve # noqa: F401 @@ -52,9 +52,9 @@ async def test_connect_str_subclass(aconn_cls, dsn): async def test_connect_timeout(aconn_cls, deaf_port): t0 = time.time() with pytest.raises(psycopg.OperationalError, match="timeout expired"): - await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1) + await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2) elapsed = time.time() - t0 - assert elapsed == pytest.approx(1.0, abs=0.05) + assert elapsed == pytest.approx(2.0, abs=0.05) @pytest.mark.slow @@ -64,7 +64,7 @@ async def test_multi_hosts(aconn_cls, proxy, dsn, deaf_port, monkeypatch): args["host"] = f"{proxy.client_host},{proxy.server_host}" args["port"] = f"{deaf_port},{proxy.server_port}" args.pop("hostaddr", None) - monkeypatch.setattr(aconn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2) + monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2) t0 = time.time() async with await aconn_cls.connect(**args) as conn: elapsed = time.time() - t0 @@ -80,11 +80,11 @@ async def test_multi_hosts_timeout(aconn_cls, proxy, dsn, deaf_port): args["host"] = f"{proxy.client_host},{proxy.server_host}" args["port"] = f"{deaf_port},{proxy.server_port}" args.pop("hostaddr", None) - args["connect_timeout"] = "1" + args["connect_timeout"] = "2" t0 = time.time() async with await aconn_cls.connect(**args) as conn: elapsed = time.time() - t0 - assert 1.0 < elapsed < 1.5 + assert 2.0 < elapsed < 2.5 assert conn.info.port == int(proxy.server_port) assert conn.info.host == proxy.server_host @@ -423,8 +423,7 @@ async def test_connect_args( setpgenv({}) monkeypatch.setattr(psycopg.connection, "connect", fake_connect) conn = await aconn_cls.connect(*args, **kwargs) - got_params = drop_default_args_from_conninfo(got_conninfo) - assert got_params == conninfo_to_dict(want) + assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want) await conn.close() @@ -769,9 +768,8 @@ async def test_set_transaction_param_strange(aconn): async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv): setpgenv({}) params = await aconn_cls._get_connection_params(dsn, **kwargs) - conninfo = make_conninfo(**params) - assert drop_default_args_from_conninfo(conninfo) == exp[0] - assert params["connect_timeout"] == exp[1] + assert params == exp[0] + assert timeout_from_conninfo(params) == exp[1] async def test_connect_context_adapters(aconn_cls, dsn): @@ -820,4 +818,4 @@ async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811 assert len(got) == 1 want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert drop_default_args_from_conninfo(got[0]) == want + assert conninfo_to_dict(got[0]) == want diff --git a/tests/test_dns.py b/tests/test_dns.py index efbf6f503..a83aaeb66 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -4,7 +4,6 @@ import psycopg from psycopg.conninfo import conninfo_to_dict from .test_conninfo import fake_resolve # noqa: F401 # fixture -from .test_connection import drop_default_args_from_conninfo @pytest.mark.usefixtures("fake_resolve") @@ -22,7 +21,7 @@ async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch): assert len(got) == 1 want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert drop_default_args_from_conninfo(got[0]) == want + assert conninfo_to_dict(got[0]) == want @pytest.mark.dns diff --git a/tests/test_module.py b/tests/test_module.py index 8abc3ebea..49757eb46 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -3,8 +3,6 @@ import pytest from psycopg._cmodule import _psycopg from psycopg.conninfo import conninfo_to_dict -from .test_connection import drop_default_args_from_conninfo - @pytest.mark.parametrize( "args, kwargs, want", @@ -22,7 +20,7 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv): orig_connect = psycopg.connection.connect # type: ignore - got_conninfo = None + got_conninfo: str def mock_connect(conninfo): nonlocal got_conninfo @@ -33,7 +31,7 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv): monkeypatch.setattr(psycopg.connection, "connect", mock_connect) conn = psycopg.connect(*args, **kwargs) - assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want) + assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want) conn.close() diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 69c2fa756..3c4ae3ac5 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -7,7 +7,6 @@ from psycopg.conninfo import conninfo_to_dict from . import dbapi20 from . import dbapi20_tpc -from .test_connection import drop_default_args_from_conninfo @pytest.fixture(scope="class") @@ -145,7 +144,7 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv): setpgenv({}) monkeypatch.setattr(psycopg.connection, "connect", fake_connect) conn = psycopg.connect(*args, **kwargs) - assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want) + assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want) conn.close()