From: dsuhinin Date: Wed, 10 Sep 2025 15:08:04 +0000 (+0200) Subject: feat(pool): allow conninfo/kwargs to be callable X-Git-Tag: 3.3.0~27^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ca4d36ecf2397073bfdde895bc40d09b2e92188d;p=thirdparty%2Fpsycopg.git feat(pool): allow conninfo/kwargs to be callable Allow to change connection pool credentials dynamically without requiring pool re-creation. --- diff --git a/docs/news_pool.rst b/docs/news_pool.rst index 79bc73bd5..dd65a2d98 100644 --- a/docs/news_pool.rst +++ b/docs/news_pool.rst @@ -15,6 +15,8 @@ psycopg_pool 3.3.0 (unreleased) - Add `!close_returns` for :ref:`integration with SQLAlchemy ` (:ticket:`#1046`). +- Allow `!conninfo` and `!kwargs` to be callable to allow connection + parameters# update (:ticket:`#851`). Current release diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py index 136c948dd..0d1330a09 100644 --- a/psycopg_pool/psycopg_pool/_acompat.py +++ b/psycopg_pool/psycopg_pool/_acompat.py @@ -15,7 +15,7 @@ import queue import asyncio import logging import threading -from typing import Any, ParamSpec, TypeAlias, overload +from typing import Any, Awaitable, ParamSpec, TypeAlias, overload from inspect import isawaitable from collections.abc import Callable, Coroutine @@ -166,7 +166,7 @@ def asleep(seconds: float) -> Coroutine[Any, Any, None]: @overload async def ensure_async( - f: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs + f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs ) -> T: ... diff --git a/psycopg_pool/psycopg_pool/abc.py b/psycopg_pool/psycopg_pool/abc.py index 16d6d745c..12c1d5d32 100644 --- a/psycopg_pool/psycopg_pool/abc.py +++ b/psycopg_pool/psycopg_pool/abc.py @@ -6,14 +6,12 @@ Types used in the psycopg_pool package from __future__ import annotations -from typing import TYPE_CHECKING, TypeAlias, Union +from typing import TYPE_CHECKING, Any, TypeAlias, Union from collections.abc import Awaitable, Callable from ._compat import TypeVar if TYPE_CHECKING: - from typing import Any - from psycopg import AsyncConnection, Connection # noqa: F401 from psycopg.rows import TupleRow # noqa: F401 @@ -34,3 +32,17 @@ AsyncConnectFailedCB: TypeAlias = Union[ Callable[["AsyncConnectionPool[Any]"], None], Callable[["AsyncConnectionPool[Any]"], Awaitable[None]], ] + +# Types of the connection parameters +ConninfoParam: TypeAlias = Union[str, Callable[[], str]] +AsyncConninfoParam: TypeAlias = Union[ + str, + Callable[[], str], + Callable[[], Awaitable[str]], +] +KwargsParam: TypeAlias = Union[dict[str, Any], Callable[[], dict[str, Any]]] +AsyncKwargsParam: TypeAlias = Union[ + dict[str, Any], + Callable[[], dict[str, Any]], + Callable[[], Awaitable[dict[str, Any]]], +] diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index ad109a0d4..745c995c7 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -44,9 +44,7 @@ class BasePool: def __init__( self, - conninfo: str = "", *, - kwargs: dict[str, Any] | None, min_size: int, max_size: int | None, name: str | None, @@ -67,8 +65,6 @@ class BasePool: if num_workers < 1: raise ValueError("num_workers must be at least 1") - self.conninfo = conninfo - self.kwargs: dict[str, Any] = kwargs or {} self.name = name self.close_returns = close_returns self._min_size = min_size diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py index bf027e9dd..59a34d242 100644 --- a/psycopg_pool/psycopg_pool/null_pool.py +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -10,12 +10,12 @@ Psycopg null connection pool module (sync version). from __future__ import annotations import logging -from typing import Any, cast +from typing import cast from psycopg import Connection from psycopg.pq import TransactionStatus -from .abc import CT, ConnectFailedCB, ConnectionCB +from .abc import CT, ConnectFailedCB, ConnectionCB, ConninfoParam, KwargsParam from .pool import AddConnection, ConnectionPool from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout @@ -29,10 +29,10 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]): def __init__( self, - conninfo: str = "", + conninfo: ConninfoParam = "", *, connection_class: type[CT] = cast(type[CT], Connection), - kwargs: dict[str, Any] | None = None, + kwargs: KwargsParam | None = None, min_size: int = 0, max_size: int | None = None, open: bool | None = None, diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py index a18c59c3f..8329c202d 100644 --- a/psycopg_pool/psycopg_pool/null_pool_async.py +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -7,12 +7,13 @@ Psycopg null connection pool module (async version). from __future__ import annotations import logging -from typing import Any, cast +from typing import cast from psycopg import AsyncConnection from psycopg.pq import TransactionStatus -from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB +from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB, AsyncConninfoParam +from .abc import AsyncKwargsParam from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout from ._acompat import AEvent @@ -23,12 +24,13 @@ logger = logging.getLogger("psycopg.pool") class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]): + def __init__( self, - conninfo: str = "", + conninfo: AsyncConninfoParam = "", *, connection_class: type[ACT] = cast(type[ACT], AsyncConnection), - kwargs: dict[str, Any] | None = None, + kwargs: AsyncKwargsParam | None = None, min_size: int = 0, # Note: min_size default value changed to 0. max_size: int | None = None, open: bool | None = None, diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 68cb8b899..619dad34e 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -24,7 +24,7 @@ from psycopg import Connection from psycopg import errors as e from psycopg.pq import TransactionStatus -from .abc import CT, ConnectFailedCB, ConnectionCB +from .abc import CT, ConnectFailedCB, ConnectionCB, ConninfoParam, KwargsParam from .base import AttemptWithBackoff, BasePool from .sched import Scheduler from .errors import PoolClosed, PoolTimeout, TooManyRequests @@ -40,10 +40,10 @@ class ConnectionPool(Generic[CT], BasePool): def __init__( self, - conninfo: str = "", + conninfo: ConninfoParam = "", *, connection_class: type[CT] = cast(type[CT], Connection), - kwargs: dict[str, Any] | None = None, + kwargs: KwargsParam | None = None, min_size: int = 4, max_size: int | None = None, open: bool | None = None, @@ -67,7 +67,8 @@ class ConnectionPool(Generic[CT], BasePool): raise TypeError( "Using 'close_returns=True' and a non-standard 'connection_class' requires psycopg 3.3 or newer. Please check the docs at https://www.psycopg.org/psycopg3/docs/advanced/pool.html#pool-sqlalchemy for a workaround." ) - + self.conninfo = conninfo + self.kwargs = kwargs self.connection_class = connection_class self._check = check self._configure = configure @@ -90,8 +91,6 @@ class ConnectionPool(Generic[CT], BasePool): self._workers: list[Worker] = [] super().__init__( - conninfo, - kwargs=kwargs, min_size=min_size, max_size=max_size, name=name, @@ -599,13 +598,14 @@ class ConnectionPool(Generic[CT], BasePool): def _connect(self, timeout: float | None = None) -> CT: """Return a new connection configured for the pool.""" self._stats[self._CONNECTIONS_NUM] += 1 - kwargs = self.kwargs + conninfo = self._resolve_conninfo() + kwargs = self._resolve_kwargs() if timeout: kwargs = kwargs.copy() kwargs["connect_timeout"] = max(round(timeout), 1) t0 = monotonic() try: - conn = self.connection_class.connect(self.conninfo, **kwargs) + conn = self.connection_class.connect(conninfo, **kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise @@ -627,6 +627,23 @@ class ConnectionPool(Generic[CT], BasePool): self._set_connection_expiry_date(conn) return conn + def _resolve_conninfo(self) -> str: + """Resolve conninfo (static string, sync callable, or async callable).""" + if callable(self.conninfo): + return self.conninfo() + + return self.conninfo or "" + + def _resolve_kwargs(self) -> dict[str, Any]: + """Resolve kwargs (static dict, sync callable, or async callable).""" + if not self.kwargs: + return {} + + if callable(self.kwargs): + return self.kwargs() + + return self.kwargs + def _add_connection( self, attempt: AttemptWithBackoff | None, growing: bool = False ) -> None: diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 7441c3a70..58a943034 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -21,7 +21,8 @@ from psycopg import AsyncConnection from psycopg import errors as e from psycopg.pq import TransactionStatus -from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB +from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB, AsyncConninfoParam +from .abc import AsyncKwargsParam from .base import AttemptWithBackoff, BasePool from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import PSYCOPG_VERSION, AsyncPoolConnection, Self @@ -40,10 +41,10 @@ class AsyncConnectionPool(Generic[ACT], BasePool): def __init__( self, - conninfo: str = "", + conninfo: AsyncConninfoParam = "", *, connection_class: type[ACT] = cast(type[ACT], AsyncConnection), - kwargs: dict[str, Any] | None = None, + kwargs: AsyncKwargsParam | None = None, min_size: int = 4, max_size: int | None = None, open: bool | None = None, @@ -70,7 +71,8 @@ class AsyncConnectionPool(Generic[ACT], BasePool): " https://www.psycopg.org/psycopg3/docs/advanced/pool.html" "#pool-sqlalchemy for a workaround." ) - + self.conninfo = conninfo + self.kwargs = kwargs self.connection_class = connection_class self._check = check self._configure = configure @@ -93,8 +95,6 @@ class AsyncConnectionPool(Generic[ACT], BasePool): self._workers: list[AWorker] = [] super().__init__( - conninfo, - kwargs=kwargs, min_size=min_size, max_size=max_size, name=name, @@ -648,13 +648,14 @@ class AsyncConnectionPool(Generic[ACT], BasePool): async def _connect(self, timeout: float | None = None) -> ACT: """Return a new connection configured for the pool.""" self._stats[self._CONNECTIONS_NUM] += 1 - kwargs = self.kwargs + conninfo = await self._resolve_conninfo() + kwargs = await self._resolve_kwargs() if timeout: kwargs = kwargs.copy() kwargs["connect_timeout"] = max(round(timeout), 1) t0 = monotonic() try: - conn = await self.connection_class.connect(self.conninfo, **kwargs) + conn = await self.connection_class.connect(conninfo, **kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise @@ -677,6 +678,23 @@ class AsyncConnectionPool(Generic[ACT], BasePool): self._set_connection_expiry_date(conn) return conn + async def _resolve_conninfo(self) -> str: + """Resolve conninfo (static string, sync callable, or async callable).""" + if callable(self.conninfo): + return await ensure_async(self.conninfo) + + return self.conninfo or "" + + async def _resolve_kwargs(self) -> dict[str, Any]: + """Resolve kwargs (static dict, sync callable, or async callable).""" + if not self.kwargs: + return {} + + if callable(self.kwargs): + return await ensure_async(self.kwargs) + + return self.kwargs + async def _add_connection( self, attempt: AttemptWithBackoff | None, growing: bool = False ) -> None: diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index ebbd4bfe1..965e5e34d 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -1170,3 +1170,94 @@ def test_override_close_no_loop_subclass(dsn): conn.close() sleep(0.1) assert len(p._pool) == 1 + + +def test_get_config_rotates_connections(dsn): + config_rotation_counter = 0 + + def rotating_config(): + nonlocal config_rotation_counter + config_rotation_counter += 1 + return dsn + + app_names = ["app-1", "app-2"] + kwargs_counter = 0 + + def rotating_kwargs(): + # Return a different application_name for each new connection + nonlocal kwargs_counter + kwargs_counter += 1 + return {"application_name": app_names[kwargs_counter % len(app_names)]} + + p = pool.ConnectionPool( + conninfo=rotating_config, + kwargs=rotating_kwargs, + min_size=2, + max_lifetime=0.2, + open=False, + ) + + try: + p.open() + p.wait() + + # Make sure we created two connections (rotating_config called twice) + assert config_rotation_counter == 2 + + # Acquire both connections and check application_name + with p.connection() as conn1, p.connection() as conn2: + row1 = conn1.execute("SHOW application_name") + row2 = conn2.execute("SHOW application_name") + + name1 = row1.fetchone() + assert ( + name1 is not None + ), "first call to SHOW application_name returned no rows" + assert name1[0] in app_names + + name2 = row2.fetchone() + assert ( + name2 is not None + ), "second call to SHOW application_name returned no rows" + assert name2[0] in app_names + + # Make sure that names are different. + assert name1 != name2 + finally: + p.close() + + +@pytest.mark.slow +def test_get_config_raise_exception(dsn, caplog): + + def failing_conninfo(): + raise RuntimeError("cannot build conninfo") + + def failing_kwargs(): + raise RuntimeError("cannot build kwargs") + + p = pool.ConnectionPool( + conninfo=failing_conninfo, + kwargs=failing_kwargs, + min_size=1, + max_lifetime=0.1, + open=False, + reconnect_timeout=1.0, + ) + + with caplog.at_level("WARNING"): + try: + p.open() + with pytest.raises(pool.PoolTimeout): + p.wait(timeout=2.0) + finally: + p.close() + + # Ensure log contains "reconnection attempt" warning` + reconnection_warnings = [ + rec for rec in caplog.records if "reconnection attempt" in rec.message.lower() + ] + + assert reconnection_warnings, "Expected reconnection attempt logs" + # Make sure that we saw not too many (backoff works) + assert len(reconnection_warnings) < 5, "Too many attempts (likely busyloop)" diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index dfda392b6..e1c52ec27 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -1170,3 +1170,96 @@ async def test_override_close_no_loop_subclass(dsn): await conn.close() await asleep(0.1) assert len(p._pool) == 1 + + +async def test_get_config_rotates_connections(dsn): + + config_rotation_counter = 0 + + async def rotating_config(): + nonlocal config_rotation_counter + config_rotation_counter += 1 + return dsn + + app_names = ["app-1", "app-2"] + kwargs_counter = 0 + + async def rotating_kwargs(): + # Return a different application_name for each new connection + nonlocal kwargs_counter + kwargs_counter += 1 + return {"application_name": app_names[kwargs_counter % len(app_names)]} + + p = pool.AsyncConnectionPool( + conninfo=rotating_config, + kwargs=rotating_kwargs, + min_size=2, + max_lifetime=0.2, + open=False, + ) + + try: + await p.open() + await p.wait() + + # Make sure we created two connections (rotating_config called twice) + assert config_rotation_counter == 2 + + # Acquire both connections and check application_name + async with p.connection() as conn1, p.connection() as conn2: + row1 = await conn1.execute("SHOW application_name") + row2 = await conn2.execute("SHOW application_name") + + name1 = await row1.fetchone() + assert ( + name1 is not None + ), "first call to SHOW application_name returned no rows" + assert name1[0] in app_names + + name2 = await row2.fetchone() + assert ( + name2 is not None + ), "second call to SHOW application_name returned no rows" + assert name2[0] in app_names + + # Make sure that names are different. + assert name1 != name2 + + finally: + await p.close() + + +@pytest.mark.slow +async def test_get_config_raise_exception(dsn, caplog): + + async def failing_conninfo(): + raise RuntimeError("cannot build conninfo") + + async def failing_kwargs(): + raise RuntimeError("cannot build kwargs") + + p = pool.AsyncConnectionPool( + conninfo=failing_conninfo, + kwargs=failing_kwargs, + min_size=1, + max_lifetime=0.1, + open=False, + reconnect_timeout=1.0, + ) + + with caplog.at_level("WARNING"): + try: + await p.open() + with pytest.raises(pool.PoolTimeout): + await p.wait(timeout=2.0) + finally: + await p.close() + + # Ensure log contains "reconnection attempt" warning` + reconnection_warnings = [ + rec for rec in caplog.records if "reconnection attempt" in rec.message.lower() + ] + + assert reconnection_warnings, "Expected reconnection attempt logs" + # Make sure that we saw not too many (backoff works) + assert len(reconnection_warnings) < 5, "Too many attempts (likely busyloop)" diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 7aee1a81e..b23b5d02a 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -309,6 +309,8 @@ class RenameAsyncToSync(ast.NodeTransformer): # type: ignore "AsyncServerCursor": "ServerCursor", "AsyncTransaction": "Transaction", "AsyncWriter": "Writer", + "AsyncKwargsParam": "KwargsParam", + "AsyncConninfoParam": "ConninfoParam", "StopAsyncIteration": "StopIteration", "__aenter__": "__enter__", "__aexit__": "__exit__",