]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(pool): allow conninfo/kwargs to be callable
authordsuhinin <suhinin.dmitriy@gmail.com>
Wed, 10 Sep 2025 15:08:04 +0000 (17:08 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 30 Oct 2025 13:18:54 +0000 (13:18 +0000)
Allow to change connection pool credentials dynamically without
requiring pool re-creation.

docs/news_pool.rst
psycopg_pool/psycopg_pool/_acompat.py
psycopg_pool/psycopg_pool/abc.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tools/async_to_sync.py

index 79bc73bd52cf51257a6d1b5b1bfc15a2aed36af4..dd65a2d98878dadffeca19fa1778b96c505f83e3 100644 (file)
@@ -15,6 +15,8 @@ psycopg_pool 3.3.0 (unreleased)
 
 - Add `!close_returns` for :ref:`integration with SQLAlchemy <pool-sqlalchemy>`
   (:ticket:`#1046`).
+- Allow `!conninfo` and `!kwargs` to be callable to allow connection
+  parameters# update (:ticket:`#851`).
 
 
 Current release
index 136c948dd7ff2bf6dc431694f70a3a8c46b56a12..0d1330a09856abdecb04db8f48dfaeabeb803107 100644 (file)
@@ -15,7 +15,7 @@ import queue
 import asyncio
 import logging
 import threading
-from typing import Any, ParamSpec, TypeAlias, overload
+from typing import Any, Awaitable, ParamSpec, TypeAlias, overload
 from inspect import isawaitable
 from collections.abc import Callable, Coroutine
 
@@ -166,7 +166,7 @@ def asleep(seconds: float) -> Coroutine[Any, Any, None]:
 
 @overload
 async def ensure_async(
-    f: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs
+    f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs
 ) -> T: ...
 
 
index 16d6d745cbd5b8d46191b1df2e5620d73b8ff3a0..12c1d5d32541ab5e0e4f6ddf4e9390ccf280d1ca 100644 (file)
@@ -6,14 +6,12 @@ Types used in the psycopg_pool package
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, TypeAlias, Union
+from typing import TYPE_CHECKING, Any, TypeAlias, Union
 from collections.abc import Awaitable, Callable
 
 from ._compat import TypeVar
 
 if TYPE_CHECKING:
-    from typing import Any
-
     from psycopg import AsyncConnection, Connection  # noqa: F401
     from psycopg.rows import TupleRow  # noqa: F401
 
@@ -34,3 +32,17 @@ AsyncConnectFailedCB: TypeAlias = Union[
     Callable[["AsyncConnectionPool[Any]"], None],
     Callable[["AsyncConnectionPool[Any]"], Awaitable[None]],
 ]
+
+# Types of the connection parameters
+ConninfoParam: TypeAlias = Union[str, Callable[[], str]]
+AsyncConninfoParam: TypeAlias = Union[
+    str,
+    Callable[[], str],
+    Callable[[], Awaitable[str]],
+]
+KwargsParam: TypeAlias = Union[dict[str, Any], Callable[[], dict[str, Any]]]
+AsyncKwargsParam: TypeAlias = Union[
+    dict[str, Any],
+    Callable[[], dict[str, Any]],
+    Callable[[], Awaitable[dict[str, Any]]],
+]
index ad109a0d4f1ff1f8be1b4a84f540f10ad2490ef7..745c995c78136c57b9c7c32115d4538fae737835 100644 (file)
@@ -44,9 +44,7 @@ class BasePool:
 
     def __init__(
         self,
-        conninfo: str = "",
         *,
-        kwargs: dict[str, Any] | None,
         min_size: int,
         max_size: int | None,
         name: str | None,
@@ -67,8 +65,6 @@ class BasePool:
         if num_workers < 1:
             raise ValueError("num_workers must be at least 1")
 
-        self.conninfo = conninfo
-        self.kwargs: dict[str, Any] = kwargs or {}
         self.name = name
         self.close_returns = close_returns
         self._min_size = min_size
index bf027e9dd4783204db1ebf0a54d9e41c27084a39..59a34d2421c0849e4d90894092cc1f9f4f33fa18 100644 (file)
@@ -10,12 +10,12 @@ Psycopg null connection pool module (sync version).
 from __future__ import annotations
 
 import logging
-from typing import Any, cast
+from typing import cast
 
 from psycopg import Connection
 from psycopg.pq import TransactionStatus
 
-from .abc import CT, ConnectFailedCB, ConnectionCB
+from .abc import CT, ConnectFailedCB, ConnectionCB, ConninfoParam, KwargsParam
 from .pool import AddConnection, ConnectionPool
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
@@ -29,10 +29,10 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
 
     def __init__(
         self,
-        conninfo: str = "",
+        conninfo: ConninfoParam = "",
         *,
         connection_class: type[CT] = cast(type[CT], Connection),
-        kwargs: dict[str, Any] | None = None,
+        kwargs: KwargsParam | None = None,
         min_size: int = 0,
         max_size: int | None = None,
         open: bool | None = None,
index a18c59c3fa8867cc1494c499556b9f350c6111c3..8329c202d99cdc0ce4f632fee7b8a1aee191761c 100644 (file)
@@ -7,12 +7,13 @@ Psycopg null connection pool module (async version).
 from __future__ import annotations
 
 import logging
-from typing import Any, cast
+from typing import cast
 
 from psycopg import AsyncConnection
 from psycopg.pq import TransactionStatus
 
-from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB
+from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB, AsyncConninfoParam
+from .abc import AsyncKwargsParam
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
 from ._acompat import AEvent
@@ -23,12 +24,13 @@ logger = logging.getLogger("psycopg.pool")
 
 
 class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]):
+
     def __init__(
         self,
-        conninfo: str = "",
+        conninfo: AsyncConninfoParam = "",
         *,
         connection_class: type[ACT] = cast(type[ACT], AsyncConnection),
-        kwargs: dict[str, Any] | None = None,
+        kwargs: AsyncKwargsParam | None = None,
         min_size: int = 0,  # Note: min_size default value changed to 0.
         max_size: int | None = None,
         open: bool | None = None,
index 68cb8b899754c971fd26a2a11b6d5484e592ac77..619dad34ebf59616a0db831c7f08f4de2376f9aa 100644 (file)
@@ -24,7 +24,7 @@ from psycopg import Connection
 from psycopg import errors as e
 from psycopg.pq import TransactionStatus
 
-from .abc import CT, ConnectFailedCB, ConnectionCB
+from .abc import CT, ConnectFailedCB, ConnectionCB, ConninfoParam, KwargsParam
 from .base import AttemptWithBackoff, BasePool
 from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
@@ -40,10 +40,10 @@ class ConnectionPool(Generic[CT], BasePool):
 
     def __init__(
         self,
-        conninfo: str = "",
+        conninfo: ConninfoParam = "",
         *,
         connection_class: type[CT] = cast(type[CT], Connection),
-        kwargs: dict[str, Any] | None = None,
+        kwargs: KwargsParam | None = None,
         min_size: int = 4,
         max_size: int | None = None,
         open: bool | None = None,
@@ -67,7 +67,8 @@ class ConnectionPool(Generic[CT], BasePool):
                 raise TypeError(
                     "Using 'close_returns=True' and a non-standard 'connection_class' requires psycopg 3.3 or newer. Please check the docs at https://www.psycopg.org/psycopg3/docs/advanced/pool.html#pool-sqlalchemy for a workaround."
                 )
-
+        self.conninfo = conninfo
+        self.kwargs = kwargs
         self.connection_class = connection_class
         self._check = check
         self._configure = configure
@@ -90,8 +91,6 @@ class ConnectionPool(Generic[CT], BasePool):
         self._workers: list[Worker] = []
 
         super().__init__(
-            conninfo,
-            kwargs=kwargs,
             min_size=min_size,
             max_size=max_size,
             name=name,
@@ -599,13 +598,14 @@ class ConnectionPool(Generic[CT], BasePool):
     def _connect(self, timeout: float | None = None) -> CT:
         """Return a new connection configured for the pool."""
         self._stats[self._CONNECTIONS_NUM] += 1
-        kwargs = self.kwargs
+        conninfo = self._resolve_conninfo()
+        kwargs = self._resolve_kwargs()
         if timeout:
             kwargs = kwargs.copy()
             kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
-            conn = self.connection_class.connect(self.conninfo, **kwargs)
+            conn = self.connection_class.connect(conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
@@ -627,6 +627,23 @@ class ConnectionPool(Generic[CT], BasePool):
         self._set_connection_expiry_date(conn)
         return conn
 
+    def _resolve_conninfo(self) -> str:
+        """Resolve conninfo (static string, sync callable, or async callable)."""
+        if callable(self.conninfo):
+            return self.conninfo()
+
+        return self.conninfo or ""
+
+    def _resolve_kwargs(self) -> dict[str, Any]:
+        """Resolve kwargs (static dict, sync callable, or async callable)."""
+        if not self.kwargs:
+            return {}
+
+        if callable(self.kwargs):
+            return self.kwargs()
+
+        return self.kwargs
+
     def _add_connection(
         self, attempt: AttemptWithBackoff | None, growing: bool = False
     ) -> None:
index 7441c3a70bfe1e2728603c8c1f968a4e26d55a79..58a943034f6bd607d377df215a586b730d3cbe75 100644 (file)
@@ -21,7 +21,8 @@ from psycopg import AsyncConnection
 from psycopg import errors as e
 from psycopg.pq import TransactionStatus
 
-from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB
+from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB, AsyncConninfoParam
+from .abc import AsyncKwargsParam
 from .base import AttemptWithBackoff, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import PSYCOPG_VERSION, AsyncPoolConnection, Self
@@ -40,10 +41,10 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
     def __init__(
         self,
-        conninfo: str = "",
+        conninfo: AsyncConninfoParam = "",
         *,
         connection_class: type[ACT] = cast(type[ACT], AsyncConnection),
-        kwargs: dict[str, Any] | None = None,
+        kwargs: AsyncKwargsParam | None = None,
         min_size: int = 4,
         max_size: int | None = None,
         open: bool | None = None,
@@ -70,7 +71,8 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                     " https://www.psycopg.org/psycopg3/docs/advanced/pool.html"
                     "#pool-sqlalchemy for a workaround."
                 )
-
+        self.conninfo = conninfo
+        self.kwargs = kwargs
         self.connection_class = connection_class
         self._check = check
         self._configure = configure
@@ -93,8 +95,6 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         self._workers: list[AWorker] = []
 
         super().__init__(
-            conninfo,
-            kwargs=kwargs,
             min_size=min_size,
             max_size=max_size,
             name=name,
@@ -648,13 +648,14 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
     async def _connect(self, timeout: float | None = None) -> ACT:
         """Return a new connection configured for the pool."""
         self._stats[self._CONNECTIONS_NUM] += 1
-        kwargs = self.kwargs
+        conninfo = await self._resolve_conninfo()
+        kwargs = await self._resolve_kwargs()
         if timeout:
             kwargs = kwargs.copy()
             kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
-            conn = await self.connection_class.connect(self.conninfo, **kwargs)
+            conn = await self.connection_class.connect(conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
@@ -677,6 +678,23 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         self._set_connection_expiry_date(conn)
         return conn
 
+    async def _resolve_conninfo(self) -> str:
+        """Resolve conninfo (static string, sync callable, or async callable)."""
+        if callable(self.conninfo):
+            return await ensure_async(self.conninfo)
+
+        return self.conninfo or ""
+
+    async def _resolve_kwargs(self) -> dict[str, Any]:
+        """Resolve kwargs (static dict, sync callable, or async callable)."""
+        if not self.kwargs:
+            return {}
+
+        if callable(self.kwargs):
+            return await ensure_async(self.kwargs)
+
+        return self.kwargs
+
     async def _add_connection(
         self, attempt: AttemptWithBackoff | None, growing: bool = False
     ) -> None:
index ebbd4bfe13756ec4711a91a537b93648b55bbcb2..965e5e34d0bdba9f8d0f18bef6fab95666112b30 100644 (file)
@@ -1170,3 +1170,94 @@ def test_override_close_no_loop_subclass(dsn):
         conn.close()
         sleep(0.1)
         assert len(p._pool) == 1
+
+
+def test_get_config_rotates_connections(dsn):
+    config_rotation_counter = 0
+
+    def rotating_config():
+        nonlocal config_rotation_counter
+        config_rotation_counter += 1
+        return dsn
+
+    app_names = ["app-1", "app-2"]
+    kwargs_counter = 0
+
+    def rotating_kwargs():
+        # Return a different application_name for each new connection
+        nonlocal kwargs_counter
+        kwargs_counter += 1
+        return {"application_name": app_names[kwargs_counter % len(app_names)]}
+
+    p = pool.ConnectionPool(
+        conninfo=rotating_config,
+        kwargs=rotating_kwargs,
+        min_size=2,
+        max_lifetime=0.2,
+        open=False,
+    )
+
+    try:
+        p.open()
+        p.wait()
+
+        # Make sure we created two connections (rotating_config called twice)
+        assert config_rotation_counter == 2
+
+        # Acquire both connections and check application_name
+        with p.connection() as conn1, p.connection() as conn2:
+            row1 = conn1.execute("SHOW application_name")
+            row2 = conn2.execute("SHOW application_name")
+
+            name1 = row1.fetchone()
+            assert (
+                name1 is not None
+            ), "first call to SHOW application_name returned no rows"
+            assert name1[0] in app_names
+
+            name2 = row2.fetchone()
+            assert (
+                name2 is not None
+            ), "second call to SHOW application_name returned no rows"
+            assert name2[0] in app_names
+
+            # Make sure that names are different.
+            assert name1 != name2
+    finally:
+        p.close()
+
+
+@pytest.mark.slow
+def test_get_config_raise_exception(dsn, caplog):
+
+    def failing_conninfo():
+        raise RuntimeError("cannot build conninfo")
+
+    def failing_kwargs():
+        raise RuntimeError("cannot build kwargs")
+
+    p = pool.ConnectionPool(
+        conninfo=failing_conninfo,
+        kwargs=failing_kwargs,
+        min_size=1,
+        max_lifetime=0.1,
+        open=False,
+        reconnect_timeout=1.0,
+    )
+
+    with caplog.at_level("WARNING"):
+        try:
+            p.open()
+            with pytest.raises(pool.PoolTimeout):
+                p.wait(timeout=2.0)
+        finally:
+            p.close()
+
+    # Ensure log contains "reconnection attempt" warning`
+    reconnection_warnings = [
+        rec for rec in caplog.records if "reconnection attempt" in rec.message.lower()
+    ]
+
+    assert reconnection_warnings, "Expected reconnection attempt logs"
+    # Make sure that we saw not too many (backoff works)
+    assert len(reconnection_warnings) < 5, "Too many attempts (likely busyloop)"
index dfda392b663e52aed41898c5495fbeaa254ade64..e1c52ec273720231a2a78295b987a5243ea69bec 100644 (file)
@@ -1170,3 +1170,96 @@ async def test_override_close_no_loop_subclass(dsn):
         await conn.close()
         await asleep(0.1)
         assert len(p._pool) == 1
+
+
+async def test_get_config_rotates_connections(dsn):
+
+    config_rotation_counter = 0
+
+    async def rotating_config():
+        nonlocal config_rotation_counter
+        config_rotation_counter += 1
+        return dsn
+
+    app_names = ["app-1", "app-2"]
+    kwargs_counter = 0
+
+    async def rotating_kwargs():
+        # Return a different application_name for each new connection
+        nonlocal kwargs_counter
+        kwargs_counter += 1
+        return {"application_name": app_names[kwargs_counter % len(app_names)]}
+
+    p = pool.AsyncConnectionPool(
+        conninfo=rotating_config,
+        kwargs=rotating_kwargs,
+        min_size=2,
+        max_lifetime=0.2,
+        open=False,
+    )
+
+    try:
+        await p.open()
+        await p.wait()
+
+        # Make sure we created two connections (rotating_config called twice)
+        assert config_rotation_counter == 2
+
+        # Acquire both connections and check application_name
+        async with p.connection() as conn1, p.connection() as conn2:
+            row1 = await conn1.execute("SHOW application_name")
+            row2 = await conn2.execute("SHOW application_name")
+
+            name1 = await row1.fetchone()
+            assert (
+                name1 is not None
+            ), "first call to SHOW application_name returned no rows"
+            assert name1[0] in app_names
+
+            name2 = await row2.fetchone()
+            assert (
+                name2 is not None
+            ), "second call to SHOW application_name returned no rows"
+            assert name2[0] in app_names
+
+            # Make sure that names are different.
+            assert name1 != name2
+
+    finally:
+        await p.close()
+
+
+@pytest.mark.slow
+async def test_get_config_raise_exception(dsn, caplog):
+
+    async def failing_conninfo():
+        raise RuntimeError("cannot build conninfo")
+
+    async def failing_kwargs():
+        raise RuntimeError("cannot build kwargs")
+
+    p = pool.AsyncConnectionPool(
+        conninfo=failing_conninfo,
+        kwargs=failing_kwargs,
+        min_size=1,
+        max_lifetime=0.1,
+        open=False,
+        reconnect_timeout=1.0,
+    )
+
+    with caplog.at_level("WARNING"):
+        try:
+            await p.open()
+            with pytest.raises(pool.PoolTimeout):
+                await p.wait(timeout=2.0)
+        finally:
+            await p.close()
+
+    # Ensure log contains "reconnection attempt" warning`
+    reconnection_warnings = [
+        rec for rec in caplog.records if "reconnection attempt" in rec.message.lower()
+    ]
+
+    assert reconnection_warnings, "Expected reconnection attempt logs"
+    # Make sure that we saw not too many (backoff works)
+    assert len(reconnection_warnings) < 5, "Too many attempts (likely busyloop)"
index 7aee1a81e40873d85cf8d683bf9d539648eb7030..b23b5d02a6d2f8b5a1c1f2b77c34de9fda72c60c 100755 (executable)
@@ -309,6 +309,8 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
         "AsyncServerCursor": "ServerCursor",
         "AsyncTransaction": "Transaction",
         "AsyncWriter": "Writer",
+        "AsyncKwargsParam": "KwargsParam",
+        "AsyncConninfoParam": "ConninfoParam",
         "StopAsyncIteration": "StopIteration",
         "__aenter__": "__enter__",
         "__aexit__": "__exit__",