]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: set minimum timeout to 2s
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 03:35:48 +0000 (04:35 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 04:00:28 +0000 (05:00 +0100)
This is consistent with what the libpq does.

Move timeout calculation to a function in conninfo module and don't
change the connect_timeout parameter explicitly in the connection
string.

Drop awful drop_default_args_from_conninfo() from tests.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_dns.py
tests/test_module.py
tests/test_psycopg_dbapi20.py

index e04a994039a0cfb92c2c3ea081073da18e28e37b..a38a98fc39b39b246f41d44cdb7714a149b4a975 100644 (file)
@@ -31,7 +31,7 @@ from .cursor import Cursor
 from ._compat import LiteralString
 from .pq.misc import connection_summary
 from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
-from .conninfo import conninfo_attempts, ConnDict
+from .conninfo import conninfo_attempts, ConnDict, timeout_from_conninfo
 from ._pipeline import BasePipeline, Pipeline
 from .generators import notifies, connect, execute
 from ._encodings import pgconn_encoding
@@ -107,11 +107,6 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    # Default timeout for connection a attempt.
-    # Arbitrary timeout, what applied by the libpq on my computer.
-    # Your mileage won't vary.
-    _DEFAULT_CONNECT_TIMEOUT = 130
-
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn
         self._autocommit = False
@@ -730,7 +725,7 @@ class Connection(BaseConnection[Row]):
         Connect to a database server and return a new `Connection` instance.
         """
         params = cls._get_connection_params(conninfo, **kwargs)
-        timeout = int(params["connect_timeout"])
+        timeout = timeout_from_conninfo(params)
         rv = None
         attempts = conninfo_attempts(params)
         for attempt in attempts:
@@ -803,18 +798,7 @@ class Connection(BaseConnection[Row]):
         :return: Connection arguments merged and eventually modified, in a
             format similar to `~conninfo.conninfo_to_dict()`.
         """
-        params = conninfo_to_dict(conninfo, **kwargs)
-
-        # Make sure there is an usable connect_timeout
-        if "connect_timeout" in params:
-            params["connect_timeout"] = int(params["connect_timeout"])
-        else:
-            # The sync connect function will stop on the default socket timeout
-            # Because in async connection mode we need to enforce the timeout
-            # ourselves, we need a finite value.
-            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
-
-        return params
+        return conninfo_to_dict(conninfo, **kwargs)
 
     def close(self) -> None:
         """Close the database connection."""
index bea077fe86801f7ca9d9a2355f46d442dd71d63e..88544db997e7dc57e60ecafa31763737b9c47ffa 100644 (file)
@@ -20,7 +20,8 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict, conninfo_attempts_async
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .connection import BaseConnection, CursorRow, Notify
@@ -118,7 +119,7 @@ class AsyncConnection(BaseConnection[Row]):
                 )
 
         params = await cls._get_connection_params(conninfo, **kwargs)
-        timeout = int(params["connect_timeout"])
+        timeout = timeout_from_conninfo(params)
         rv = None
         attempts = await conninfo_attempts_async(params)
         for attempt in attempts:
@@ -185,18 +186,7 @@ class AsyncConnection(BaseConnection[Row]):
     @classmethod
     async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
         """Manipulate connection parameters before connecting."""
-        params = conninfo_to_dict(conninfo, **kwargs)
-
-        # Make sure there is an usable connect_timeout
-        if "connect_timeout" in params:
-            params["connect_timeout"] = int(params["connect_timeout"])
-        else:
-            # The sync connect function will stop on the default socket timeout
-            # Because in async connection mode we need to enforce the timeout
-            # ourselves, we need a finite value.
-            params["connect_timeout"] = cls._DEFAULT_CONNECT_TIMEOUT
-
-        return params
+        return conninfo_to_dict(conninfo, **kwargs)
 
     async def close(self) -> None:
         if self.closed:
index c38d7af7a0b56c2f932573ae16d29600ee69fdcc..ec36ff09ebf7c80e77511e553eb88422b375547c 100644 (file)
@@ -27,6 +27,11 @@ from ._encodings import pgconn_encoding
 
 ConnDict: TypeAlias = "dict[str, Any]"
 
+# Default timeout for connection a attempt.
+# Arbitrary timeout, what applied by the libpq on my computer.
+# Your mileage won't vary.
+_DEFAULT_CONNECT_TIMEOUT = 130
+
 logger = logging.getLogger("psycopg")
 
 
@@ -430,6 +435,35 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
     return [{**params, "hostaddr": item[4][0]} for item in ans]
 
 
+def timeout_from_conninfo(params: ConnDict) -> int:
+    """
+    Return the timeout in seconds from the connection parameters.
+    """
+    # Follow the libpq convention:
+    #
+    # - 0 or less means no timeout (but we will use a default to simulate
+    #   the socket timeout)
+    # - at least 2 seconds.
+    #
+    # See connectDBComplete in fe-connect.c
+    value = params.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
+    try:
+        timeout = int(value)
+    except ValueError:
+        raise e.ProgrammingError(f"bad value for connect_timeout: {value!r}")
+
+    if timeout <= 0:
+        # The sync connect function will stop on the default socket timeout
+        # Because in async connection mode we need to enforce the timeout
+        # ourselves, we need a finite value.
+        timeout = _DEFAULT_CONNECT_TIMEOUT
+    elif timeout < 2:
+        # Enforce a 2s min
+        timeout = 2
+
+    return timeout
+
+
 def _get_param(params: ConnDict, name: str) -> str | None:
     """
     Return a value from a connection string.
index 754acec3adb228a7326b434b7fa3a1a948b99ac3..0ee8aea3136e7c737616092d6c5ccc3cbd3e8329 100644 (file)
@@ -9,13 +9,12 @@ from dataclasses import dataclass
 import psycopg
 from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
-from psycopg.conninfo import conninfo_to_dict, make_conninfo
+from psycopg.conninfo import conninfo_to_dict
+from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT
 
 from .test_cursor import my_row_factory
 from .test_adapt import make_bin_dumper, make_dumper
 
-DEFAULT_TIMEOUT = psycopg.Connection._DEFAULT_CONNECT_TIMEOUT
-
 
 def test_connect(conn_cls, dsn):
     conn = conn_cls.connect(dsn)
@@ -44,9 +43,9 @@ def test_connect_bad(conn_cls):
 def test_connect_timeout(conn_cls, deaf_port):
     t0 = time.time()
     with pytest.raises(psycopg.OperationalError, match="timeout expired"):
-        conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+        conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2)
     elapsed = time.time() - t0
-    assert elapsed == pytest.approx(1.0, abs=0.05)
+    assert elapsed == pytest.approx(2.0, abs=0.05)
 
 
 @pytest.mark.slow
@@ -56,7 +55,7 @@ def test_multi_hosts(conn_cls, proxy, dsn, deaf_port, monkeypatch):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    monkeypatch.setattr(conn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2)
     t0 = time.time()
     with conn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
@@ -72,11 +71,11 @@ def test_multi_hosts_timeout(conn_cls, proxy, dsn, deaf_port):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    args["connect_timeout"] = "1"
+    args["connect_timeout"] = "2"
     t0 = time.time()
     with conn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
-        assert 1.0 < elapsed < 1.5
+        assert 2.0 < elapsed < 2.5
         assert conn.info.port == int(proxy.server_port)
         assert conn.info.host == proxy.server_host
 
@@ -416,8 +415,7 @@ def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, wan
     setpgenv({})
     monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
     conn = conn_cls.connect(*args, **kwargs)
-    got_params = drop_default_args_from_conninfo(got_conninfo)
-    assert got_params == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
@@ -801,17 +799,17 @@ conninfo_params_timeout = [
     (
         "",
         {"dbname": "mydb", "connect_timeout": None},
-        ({"dbname": "mydb"}, DEFAULT_TIMEOUT),
+        ({"dbname": "mydb"}, _DEFAULT_CONNECT_TIMEOUT),
     ),
     (
         "",
         {"dbname": "mydb", "connect_timeout": 1},
-        ({"dbname": "mydb", "connect_timeout": "1"}, 1),
+        ({"dbname": "mydb", "connect_timeout": 1}, 2),
     ),
     (
         "dbname=postgres",
         {},
-        ({"dbname": "postgres"}, DEFAULT_TIMEOUT),
+        ({"dbname": "postgres"}, _DEFAULT_CONNECT_TIMEOUT),
     ),
     (
         "dbname=postgres connect_timeout=2",
@@ -821,7 +819,7 @@ conninfo_params_timeout = [
     (
         "postgresql:///postgres?connect_timeout=2",
         {"connect_timeout": 10},
-        ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+        ({"dbname": "postgres", "connect_timeout": 10}, 10),
     ),
 ]
 
@@ -829,9 +827,8 @@ conninfo_params_timeout = [
 @pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
 def test_get_connection_params(conn_cls, dsn, kwargs, exp):
     params = conn_cls._get_connection_params(dsn, **kwargs)
-    conninfo = make_conninfo(**params)
-    assert drop_default_args_from_conninfo(conninfo) == exp[0]
-    assert params["connect_timeout"] == exp[1]
+    assert params == exp[0]
+    assert timeout_from_conninfo(params) == exp[1]
 
 
 def test_connect_context(conn_cls, dsn):
@@ -864,18 +861,3 @@ def test_connect_context_copy(conn_cls, dsn, conn):
 def test_cancel_closed(conn):
     conn.close()
     conn.cancel()
-
-
-def drop_default_args_from_conninfo(conninfo):
-    if isinstance(conninfo, str):
-        params = conninfo_to_dict(conninfo)
-    else:
-        params = conninfo.copy()
-
-    def removeif(key, value):
-        if params.get(key) == value:
-            params.pop(key)
-
-    removeif("connect_timeout", str(DEFAULT_TIMEOUT))
-
-    return params
index a35973c69042ed1ed4425aaaf8beca83f12bb141..86ccfe10bf07d55b25315575e6447698c048be57 100644 (file)
@@ -8,11 +8,11 @@ from typing import List, Any
 import psycopg
 from psycopg import Notify, errors as e
 from psycopg.rows import tuple_row
-from psycopg.conninfo import conninfo_to_dict, make_conninfo
+from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
 from .test_cursor import my_row_factory
 from .test_connection import tx_params, tx_params_isolation, tx_values_map
-from .test_connection import conninfo_params_timeout, drop_default_args_from_conninfo
+from .test_connection import conninfo_params_timeout
 from .test_connection import testctx  # noqa: F401  # fixture
 from .test_adapt import make_bin_dumper, make_dumper
 from .test_conninfo import fake_resolve  # noqa: F401
@@ -52,9 +52,9 @@ async def test_connect_str_subclass(aconn_cls, dsn):
 async def test_connect_timeout(aconn_cls, deaf_port):
     t0 = time.time()
     with pytest.raises(psycopg.OperationalError, match="timeout expired"):
-        await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+        await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2)
     elapsed = time.time() - t0
-    assert elapsed == pytest.approx(1.0, abs=0.05)
+    assert elapsed == pytest.approx(2.0, abs=0.05)
 
 
 @pytest.mark.slow
@@ -64,7 +64,7 @@ async def test_multi_hosts(aconn_cls, proxy, dsn, deaf_port, monkeypatch):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    monkeypatch.setattr(aconn_cls, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2)
     t0 = time.time()
     async with await aconn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
@@ -80,11 +80,11 @@ async def test_multi_hosts_timeout(aconn_cls, proxy, dsn, deaf_port):
     args["host"] = f"{proxy.client_host},{proxy.server_host}"
     args["port"] = f"{deaf_port},{proxy.server_port}"
     args.pop("hostaddr", None)
-    args["connect_timeout"] = "1"
+    args["connect_timeout"] = "2"
     t0 = time.time()
     async with await aconn_cls.connect(**args) as conn:
         elapsed = time.time() - t0
-        assert 1.0 < elapsed < 1.5
+        assert 2.0 < elapsed < 2.5
         assert conn.info.port == int(proxy.server_port)
         assert conn.info.host == proxy.server_host
 
@@ -423,8 +423,7 @@ async def test_connect_args(
     setpgenv({})
     monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
     conn = await aconn_cls.connect(*args, **kwargs)
-    got_params = drop_default_args_from_conninfo(got_conninfo)
-    assert got_params == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     await conn.close()
 
 
@@ -769,9 +768,8 @@ async def test_set_transaction_param_strange(aconn):
 async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = await aconn_cls._get_connection_params(dsn, **kwargs)
-    conninfo = make_conninfo(**params)
-    assert drop_default_args_from_conninfo(conninfo) == exp[0]
-    assert params["connect_timeout"] == exp[1]
+    assert params == exp[0]
+    assert timeout_from_conninfo(params) == exp[1]
 
 
 async def test_connect_context_adapters(aconn_cls, dsn):
@@ -820,4 +818,4 @@ async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):  # noqa: F811
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert drop_default_args_from_conninfo(got[0]) == want
+    assert conninfo_to_dict(got[0]) == want
index efbf6f5030825258133d6333a62a092861429dec..a83aaeb6694165ebd0c257c650ad7361063185f0 100644 (file)
@@ -4,7 +4,6 @@ import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
 from .test_conninfo import fake_resolve  # noqa: F401  # fixture
-from .test_connection import drop_default_args_from_conninfo
 
 
 @pytest.mark.usefixtures("fake_resolve")
@@ -22,7 +21,7 @@ async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert drop_default_args_from_conninfo(got[0]) == want
+    assert conninfo_to_dict(got[0]) == want
 
 
 @pytest.mark.dns
index 8abc3ebeab87815de2cce83fd755e3b591947006..49757eb46e547fec8732cacaff76b77a73771c0e 100644 (file)
@@ -3,8 +3,6 @@ import pytest
 from psycopg._cmodule import _psycopg
 from psycopg.conninfo import conninfo_to_dict
 
-from .test_connection import drop_default_args_from_conninfo
-
 
 @pytest.mark.parametrize(
     "args, kwargs, want",
@@ -22,7 +20,7 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv):
 
     orig_connect = psycopg.connection.connect  # type: ignore
 
-    got_conninfo = None
+    got_conninfo: str
 
     def mock_connect(conninfo):
         nonlocal got_conninfo
@@ -33,7 +31,7 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv):
     monkeypatch.setattr(psycopg.connection, "connect", mock_connect)
 
     conn = psycopg.connect(*args, **kwargs)
-    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index 69c2fa756708c680b60b22e033cbe4de3ba69add..3c4ae3ac5fccdf3a3bd19a868782b4d737a76ca2 100644 (file)
@@ -7,7 +7,6 @@ from psycopg.conninfo import conninfo_to_dict
 
 from . import dbapi20
 from . import dbapi20_tpc
-from .test_connection import drop_default_args_from_conninfo
 
 
 @pytest.fixture(scope="class")
@@ -145,7 +144,7 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv):
     setpgenv({})
     monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
     conn = psycopg.connect(*args, **kwargs)
-    assert drop_default_args_from_conninfo(got_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()