]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: set minimum timeout to 2s
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 03:35:48 +0000 (04:35 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 14 Dec 2023 11:51:12 +0000 (12:51 +0100)
This is consistent with what the libpq does.

Move timeout calculation to a function in conninfo module and don't
change the connect_timeout parameter explicitly in the connection
string.

Drop awful drop_default_args_from_conninfo() from tests.

psycopg/psycopg/_connection_base.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/_test_connection.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_dns.py
tests/test_module.py
tests/test_psycopg_dbapi20.py

index 19e3d1b848f3299cd5fdf72861ec442465064928..cf391998bd39ae16bb0900a67e65c7c9f2f5d869 100644 (file)
@@ -97,11 +97,6 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    # Default timeout for connection a attempt.
-    # Arbitrary timeout, what applied by the libpq on my computer.
-    # Your mileage won't vary.
-    _DEFAULT_CONNECT_TIMEOUT = 130
-
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn
         self._autocommit = False
index 475617c96a012e3ca25841195953ee3b5a9cbfed..2f3f89595f07b662f7db734f0707202957fc913b 100644 (file)
@@ -23,7 +23,8 @@ from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import conninfo_attempts, timeout_from_conninfo
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -119,7 +120,7 @@ class Connection(BaseConnection[Row]):
         """
 
         params = cls._get_connection_params(conninfo, **kwargs)
-        timeout = int(params["connect_timeout"])
+        timeout = timeout_from_conninfo(params)
         rv = None
         attempts = conninfo_attempts(params)
         for attempt in attempts:
@@ -180,25 +181,8 @@ class Connection(BaseConnection[Row]):
 
     @classmethod
     def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
-        """Manipulate connection parameters before connecting.
-
-        :param conninfo: Connection string as received by `~Connection.connect()`.
-        :param kwargs: Overriding connection arguments as received by `!connect()`.
-        :return: Connection arguments merged and eventually modified, in a
-            format similar to `~conninfo.conninfo_to_dict()`.
-        """
-        params = conninfo_to_dict(conninfo, **kwargs)
-
-        # Make sure there is an usable connect_timeout
-        if "connect_timeout" in params:
-            params["connect_timeout"] = int(params["connect_timeout"])
-        else:
-            # The sync connect function will stop on the default socket timeout
-            # Because in async connection mode we need to enforce the timeout
-            # ourselves, we need a finite value.
-            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
-
-        return params
+        """Manipulate connection parameters before connecting."""
+        return conninfo_to_dict(conninfo, **kwargs)
 
     def close(self) -> None:
         """Close the database connection."""
index 032b175ccf08380bb058b4c1f1866e6074a55299..c69c483ea7160fa7d8ce3d593ac26d90a1eb2700 100644 (file)
@@ -20,7 +20,8 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -134,7 +135,7 @@ class AsyncConnection(BaseConnection[Row]):
                     )
 
         params = await cls._get_connection_params(conninfo, **kwargs)
-        timeout = int(params["connect_timeout"])
+        timeout = timeout_from_conninfo(params)
         rv = None
         attempts = await conninfo_attempts_async(params)
         for attempt in attempts:
@@ -195,25 +196,8 @@ class AsyncConnection(BaseConnection[Row]):
 
     @classmethod
     async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
-        """Manipulate connection parameters before connecting.
-
-        :param conninfo: Connection string as received by `~Connection.connect()`.
-        :param kwargs: Overriding connection arguments as received by `!connect()`.
-        :return: Connection arguments merged and eventually modified, in a
-            format similar to `~conninfo.conninfo_to_dict()`.
-        """
-        params = conninfo_to_dict(conninfo, **kwargs)
-
-        # Make sure there is an usable connect_timeout
-        if "connect_timeout" in params:
-            params["connect_timeout"] = int(params["connect_timeout"])
-        else:
-            # The sync connect function will stop on the default socket timeout
-            # Because in async connection mode we need to enforce the timeout
-            # ourselves, we need a finite value.
-            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
-
-        return params
+        """Manipulate connection parameters before connecting."""
+        return conninfo_to_dict(conninfo, **kwargs)
 
     async def close(self) -> None:
         """Close the database connection."""
index c38d7af7a0b56c2f932573ae16d29600ee69fdcc..ec36ff09ebf7c80e77511e553eb88422b375547c 100644 (file)
@@ -27,6 +27,11 @@ from ._encodings import pgconn_encoding
 
 ConnDict: TypeAlias = "dict[str, Any]"
 
+# Default timeout for connection a attempt.
+# Arbitrary timeout, what applied by the libpq on my computer.
+# Your mileage won't vary.
+_DEFAULT_CONNECT_TIMEOUT = 130
+
 logger = logging.getLogger("psycopg")
 
 
@@ -430,6 +435,35 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
     return [{**params, "hostaddr": item[4][0]} for item in ans]
 
 
+def timeout_from_conninfo(params: ConnDict) -> int:
+    """
+    Return the timeout in seconds from the connection parameters.
+    """
+    # Follow the libpq convention:
+    #
+    # - 0 or less means no timeout (but we will use a default to simulate
+    #   the socket timeout)
+    # - at least 2 seconds.
+    #
+    # See connectDBComplete in fe-connect.c
+    value = params.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
+    try:
+        timeout = int(value)
+    except ValueError:
+        raise e.ProgrammingError(f"bad value for connect_timeout: {value!r}")
+
+    if timeout <= 0:
+        # The sync connect function will stop on the default socket timeout
+        # Because in async connection mode we need to enforce the timeout
+        # ourselves, we need a finite value.
+        timeout = _DEFAULT_CONNECT_TIMEOUT
+    elif timeout < 2:
+        # Enforce a 2s min
+        timeout = 2
+
+    return timeout
+
+
 def _get_param(params: ConnDict, name: str) -> str | None:
     """
     Return a value from a connection string.
index dc5db6972d8618e2070d8ea037b428e56495f53b..56d3859266f333432b4b985ef5b15cf8a4ede422 100644 (file)
@@ -9,11 +9,13 @@ import pytest
 import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
-# Don't import this to allow tests to import (not necessarily to pass all)
-# if the psycopg module imported is not the one expected (e.g. running
-# psycopg pool tests on the master branch with psycopg 3.1.x imported).
-# psycopg._connection_base.BaseConnection._DEFAULT_CONNECT_TIMEOUT
-DEFAULT_TIMEOUT = 130
+try:
+    from psycopg.conninfo import _DEFAULT_CONNECT_TIMEOUT as DEFAULT_TIMEOUT
+except ImportError:
+    # Allow tests to import (not necessarily to pass all) if the psycopg module
+    # imported is not the one expected (e.g. running psycopg pool tests on the
+    # master branch with psycopg 3.1.x imported).
+    DEFAULT_TIMEOUT = 130
 
 
 @pytest.fixture
@@ -87,7 +89,7 @@ conninfo_params_timeout = [
     (
         "",
         {"dbname": "mydb", "connect_timeout": 1},
-        ({"dbname": "mydb", "connect_timeout": "1"}, 1),
+        ({"dbname": "mydb", "connect_timeout": 1}, 2),
     ),
     (
         "dbname=postgres",
@@ -102,7 +104,7 @@ conninfo_params_timeout = [
     (
         "postgresql:///postgres?connect_timeout=2",
         {"connect_timeout": 10},
-        ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+        ({"dbname": "postgres", "connect_timeout": 10}, 10),
     ),
 ]
 
index 5afb4e91e1407bd06f2d6ea13de1448056ff1c50..9cfa7459f53084a7f6ff49501808364c7ab0ecfc 100644 (file)
@@ -11,12 +11,12 @@ from typing import Any, List
 import psycopg
 from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
-from psycopg.conninfo import conninfo_to_dict, make_conninfo
+from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
 from .acompat import is_async, skip_sync, skip_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
-from ._test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
+from ._test_connection import conninfo_params_timeout
 from ._test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 
@@ -48,9 +48,9 @@ def test_connect_str_subclass(conn_cls, dsn):
 def test_connect_timeout(conn_cls, deaf_port):
     t0 = time.time()
     with pytest.raises(psycopg.OperationalError, match="timeout expired"):
-        conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+        conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2)
     elapsed = time.time() - t0
-    assert elapsed == pytest.approx(1.0, abs=0.05)
+    assert elapsed == pytest.approx(2.0, abs=0.05)
 
 
 @pytest.mark.slow
@@ -60,7 +60,7 @@ def test_multi_hosts(conn_cls, proxy, dsn, deaf_port, monkeypatch):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    monkeypatch.setattr(conn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2)
     t0 = time.time()
     with conn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
@@ -76,11 +76,11 @@ def test_multi_hosts_timeout(conn_cls, proxy, dsn, deaf_port):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    args["connect_timeout"] = "1"
+    args["connect_timeout"] = "2"
     t0 = time.time()
     with conn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
-        assert 1.0 < elapsed < 1.5
+        assert 2.0 < elapsed < 2.5
         assert conn.info.port == int(proxy.server_port)
         assert conn.info.host == proxy.server_host
 
@@ -452,8 +452,7 @@ def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, wan
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = conn_cls.connect(*args, **kwargs)
-    got_params = drop_default_args_from_conninfo(got_conninfo)
-    assert got_params == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
@@ -824,9 +823,8 @@ def test_set_transaction_param_strange_property(conn):
 def test_get_connection_params(conn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = conn_cls._get_connection_params(dsn, **kwargs)
-    conninfo = make_conninfo(**params)
-    assert drop_default_args_from_conninfo(conninfo) == exp[0]
-    assert params["connect_timeout"] == exp[1]
+    assert params == exp[0]
+    assert timeout_from_conninfo(params) == exp[1]
 
 
 def test_connect_context_adapters(conn_cls, dsn):
index 5b4a66dfb6dbe6187fb0c58ebf27cce864cde319..754e57b386e7bd31c3714417774c5b8895153f31 100644 (file)
@@ -8,12 +8,12 @@ from typing import Any, List
 import psycopg
 from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
-from psycopg.conninfo import conninfo_to_dict, make_conninfo
+from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
 from .acompat import is_async, skip_sync, skip_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
-from ._test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
+from ._test_connection import conninfo_params_timeout
 from ._test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 
@@ -45,9 +45,9 @@ async def test_connect_str_subclass(aconn_cls, dsn):
 async def test_connect_timeout(aconn_cls, deaf_port):
     t0 = time.time()
     with pytest.raises(psycopg.OperationalError, match="timeout expired"):
-        await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+        await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2)
     elapsed = time.time() - t0
-    assert elapsed == pytest.approx(1.0, abs=0.05)
+    assert elapsed == pytest.approx(2.0, abs=0.05)
 
 
 @pytest.mark.slow
@@ -57,7 +57,7 @@ async def test_multi_hosts(aconn_cls, proxy, dsn, deaf_port, monkeypatch):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    monkeypatch.setattr(aconn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2)
     t0 = time.time()
     async with await aconn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
@@ -73,11 +73,11 @@ async def test_multi_hosts_timeout(aconn_cls, proxy, dsn, deaf_port):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    args["connect_timeout"] = "1"
+    args["connect_timeout"] = "2"
     t0 = time.time()
     async with await aconn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
-        assert 1.0 < elapsed < 1.5
+        assert 2.0 < elapsed < 2.5
         assert conn.info.port == int(proxy.server_port)
         assert conn.info.host == proxy.server_host
 
@@ -452,8 +452,7 @@ async def test_connect_args(
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = await aconn_cls.connect(*args, **kwargs)
-    got_params = drop_default_args_from_conninfo(got_conninfo)
-    assert got_params == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     await conn.close()
 
 
@@ -832,9 +831,8 @@ def test_set_transaction_param_strange_property(conn):
 async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = await aconn_cls._get_connection_params(dsn, **kwargs)
-    conninfo = make_conninfo(**params)
-    assert drop_default_args_from_conninfo(conninfo) == exp[0]
-    assert params["connect_timeout"] == exp[1]
+    assert params == exp[0]
+    assert timeout_from_conninfo(params) == exp[1]
 
 
 async def test_connect_context_adapters(aconn_cls, dsn):
index a118786e527be88fa4cae4eba2b20513e82ccecd..a83aaeb6694165ebd0c257c650ad7361063185f0 100644 (file)
@@ -4,7 +4,6 @@ import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
 from .test_conninfo import fake_resolve  # noqa: F401  # fixture
-from ._test_connection import drop_default_args_from_conninfo
 
 
 @pytest.mark.usefixtures("fake_resolve")
@@ -22,7 +21,7 @@ async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert drop_default_args_from_conninfo(got[0]) == want
+    assert conninfo_to_dict(got[0]) == want
 
 
 @pytest.mark.dns
index 2e31c6feadb0200845b73568e0f1f8194b639baf..9b144d7d6be31bc0edb32884db1c89e39b69737c 100644 (file)
@@ -3,8 +3,6 @@ import pytest
 from psycopg._cmodule import _psycopg
 from psycopg.conninfo import conninfo_to_dict
 
-from ._test_connection import drop_default_args_from_conninfo
-
 
 @pytest.mark.parametrize(
     "args, kwargs, want",
@@ -22,7 +20,7 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv):
 
     orig_connect = psycopg.generators.connect
 
-    got_conninfo = None
+    got_conninfo: str
 
     def mock_connect(conninfo):
         nonlocal got_conninfo
@@ -33,7 +31,7 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv):
     monkeypatch.setattr(psycopg.generators, "connect", mock_connect)
 
     conn = psycopg.connect(*args, **kwargs)
-    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index e87937dafcd9c2b16397ae2dcb589dd7d2cb1342..2e429eac95c62a5d5ff93448651ffae5e75c8de2 100644 (file)
@@ -7,7 +7,6 @@ from psycopg.conninfo import conninfo_to_dict
 
 from . import dbapi20
 from . import dbapi20_tpc
-from ._test_connection import drop_default_args_from_conninfo
 
 
 @pytest.fixture(scope="class")
@@ -145,7 +144,7 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv):
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = psycopg.connect(*args, **kwargs)
-    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()