- Add `!close_returns` for :ref:`integration with SQLAlchemy <pool-sqlalchemy>`
(:ticket:`#1046`).
+- Allow `!conninfo` and `!kwargs` to be callable to allow connection
+ parameters# update (:ticket:`#851`).
Current release
import asyncio
import logging
import threading
-from typing import Any, ParamSpec, TypeAlias, overload
+from typing import Any, Awaitable, ParamSpec, TypeAlias, overload
from inspect import isawaitable
from collections.abc import Callable, Coroutine
@overload
async def ensure_async(
- f: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs
+ f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs
) -> T: ...
from __future__ import annotations
-from typing import TYPE_CHECKING, TypeAlias, Union
+from typing import TYPE_CHECKING, Any, TypeAlias, Union
from collections.abc import Awaitable, Callable
from ._compat import TypeVar
if TYPE_CHECKING:
- from typing import Any
-
from psycopg import AsyncConnection, Connection # noqa: F401
from psycopg.rows import TupleRow # noqa: F401
Callable[["AsyncConnectionPool[Any]"], None],
Callable[["AsyncConnectionPool[Any]"], Awaitable[None]],
]
+
+# Types of the connection parameters
+ConninfoParam: TypeAlias = Union[str, Callable[[], str]]
+AsyncConninfoParam: TypeAlias = Union[
+ str,
+ Callable[[], str],
+ Callable[[], Awaitable[str]],
+]
+KwargsParam: TypeAlias = Union[dict[str, Any], Callable[[], dict[str, Any]]]
+AsyncKwargsParam: TypeAlias = Union[
+ dict[str, Any],
+ Callable[[], dict[str, Any]],
+ Callable[[], Awaitable[dict[str, Any]]],
+]
def __init__(
self,
- conninfo: str = "",
*,
- kwargs: dict[str, Any] | None,
min_size: int,
max_size: int | None,
name: str | None,
if num_workers < 1:
raise ValueError("num_workers must be at least 1")
- self.conninfo = conninfo
- self.kwargs: dict[str, Any] = kwargs or {}
self.name = name
self.close_returns = close_returns
self._min_size = min_size
from __future__ import annotations
import logging
-from typing import Any, cast
+from typing import cast
from psycopg import Connection
from psycopg.pq import TransactionStatus
-from .abc import CT, ConnectFailedCB, ConnectionCB
+from .abc import CT, ConnectFailedCB, ConnectionCB, ConninfoParam, KwargsParam
from .pool import AddConnection, ConnectionPool
from .errors import PoolTimeout, TooManyRequests
from ._compat import ConnectionTimeout
def __init__(
self,
- conninfo: str = "",
+ conninfo: ConninfoParam = "",
*,
connection_class: type[CT] = cast(type[CT], Connection),
- kwargs: dict[str, Any] | None = None,
+ kwargs: KwargsParam | None = None,
min_size: int = 0,
max_size: int | None = None,
open: bool | None = None,
from __future__ import annotations
import logging
-from typing import Any, cast
+from typing import cast
from psycopg import AsyncConnection
from psycopg.pq import TransactionStatus
-from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB
+from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB, AsyncConninfoParam
+from .abc import AsyncKwargsParam
from .errors import PoolTimeout, TooManyRequests
from ._compat import ConnectionTimeout
from ._acompat import AEvent
class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]):
+
def __init__(
self,
- conninfo: str = "",
+ conninfo: AsyncConninfoParam = "",
*,
connection_class: type[ACT] = cast(type[ACT], AsyncConnection),
- kwargs: dict[str, Any] | None = None,
+ kwargs: AsyncKwargsParam | None = None,
min_size: int = 0, # Note: min_size default value changed to 0.
max_size: int | None = None,
open: bool | None = None,
from psycopg import errors as e
from psycopg.pq import TransactionStatus
-from .abc import CT, ConnectFailedCB, ConnectionCB
+from .abc import CT, ConnectFailedCB, ConnectionCB, ConninfoParam, KwargsParam
from .base import AttemptWithBackoff, BasePool
from .sched import Scheduler
from .errors import PoolClosed, PoolTimeout, TooManyRequests
def __init__(
self,
- conninfo: str = "",
+ conninfo: ConninfoParam = "",
*,
connection_class: type[CT] = cast(type[CT], Connection),
- kwargs: dict[str, Any] | None = None,
+ kwargs: KwargsParam | None = None,
min_size: int = 4,
max_size: int | None = None,
open: bool | None = None,
raise TypeError(
"Using 'close_returns=True' and a non-standard 'connection_class' requires psycopg 3.3 or newer. Please check the docs at https://www.psycopg.org/psycopg3/docs/advanced/pool.html#pool-sqlalchemy for a workaround."
)
-
+ self.conninfo = conninfo
+ self.kwargs = kwargs
self.connection_class = connection_class
self._check = check
self._configure = configure
self._workers: list[Worker] = []
super().__init__(
- conninfo,
- kwargs=kwargs,
min_size=min_size,
max_size=max_size,
name=name,
def _connect(self, timeout: float | None = None) -> CT:
"""Return a new connection configured for the pool."""
self._stats[self._CONNECTIONS_NUM] += 1
- kwargs = self.kwargs
+ conninfo = self._resolve_conninfo()
+ kwargs = self._resolve_kwargs()
if timeout:
kwargs = kwargs.copy()
kwargs["connect_timeout"] = max(round(timeout), 1)
t0 = monotonic()
try:
- conn = self.connection_class.connect(self.conninfo, **kwargs)
+ conn = self.connection_class.connect(conninfo, **kwargs)
except Exception:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
self._set_connection_expiry_date(conn)
return conn
+ def _resolve_conninfo(self) -> str:
+ """Resolve conninfo (static string, sync callable, or async callable)."""
+ if callable(self.conninfo):
+ return self.conninfo()
+
+ return self.conninfo or ""
+
+ def _resolve_kwargs(self) -> dict[str, Any]:
+ """Resolve kwargs (static dict, sync callable, or async callable)."""
+ if not self.kwargs:
+ return {}
+
+ if callable(self.kwargs):
+ return self.kwargs()
+
+ return self.kwargs
+
def _add_connection(
self, attempt: AttemptWithBackoff | None, growing: bool = False
) -> None:
from psycopg import errors as e
from psycopg.pq import TransactionStatus
-from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB
+from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB, AsyncConninfoParam
+from .abc import AsyncKwargsParam
from .base import AttemptWithBackoff, BasePool
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from ._compat import PSYCOPG_VERSION, AsyncPoolConnection, Self
def __init__(
self,
- conninfo: str = "",
+ conninfo: AsyncConninfoParam = "",
*,
connection_class: type[ACT] = cast(type[ACT], AsyncConnection),
- kwargs: dict[str, Any] | None = None,
+ kwargs: AsyncKwargsParam | None = None,
min_size: int = 4,
max_size: int | None = None,
open: bool | None = None,
" https://www.psycopg.org/psycopg3/docs/advanced/pool.html"
"#pool-sqlalchemy for a workaround."
)
-
+ self.conninfo = conninfo
+ self.kwargs = kwargs
self.connection_class = connection_class
self._check = check
self._configure = configure
self._workers: list[AWorker] = []
super().__init__(
- conninfo,
- kwargs=kwargs,
min_size=min_size,
max_size=max_size,
name=name,
async def _connect(self, timeout: float | None = None) -> ACT:
"""Return a new connection configured for the pool."""
self._stats[self._CONNECTIONS_NUM] += 1
- kwargs = self.kwargs
+ conninfo = await self._resolve_conninfo()
+ kwargs = await self._resolve_kwargs()
if timeout:
kwargs = kwargs.copy()
kwargs["connect_timeout"] = max(round(timeout), 1)
t0 = monotonic()
try:
- conn = await self.connection_class.connect(self.conninfo, **kwargs)
+ conn = await self.connection_class.connect(conninfo, **kwargs)
except Exception:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
self._set_connection_expiry_date(conn)
return conn
+ async def _resolve_conninfo(self) -> str:
+ """Resolve conninfo (static string, sync callable, or async callable)."""
+ if callable(self.conninfo):
+ return await ensure_async(self.conninfo)
+
+ return self.conninfo or ""
+
+ async def _resolve_kwargs(self) -> dict[str, Any]:
+ """Resolve kwargs (static dict, sync callable, or async callable)."""
+ if not self.kwargs:
+ return {}
+
+ if callable(self.kwargs):
+ return await ensure_async(self.kwargs)
+
+ return self.kwargs
+
async def _add_connection(
self, attempt: AttemptWithBackoff | None, growing: bool = False
) -> None:
conn.close()
sleep(0.1)
assert len(p._pool) == 1
+
+
+def test_get_config_rotates_connections(dsn):
+ config_rotation_counter = 0
+
+ def rotating_config():
+ nonlocal config_rotation_counter
+ config_rotation_counter += 1
+ return dsn
+
+ app_names = ["app-1", "app-2"]
+ kwargs_counter = 0
+
+ def rotating_kwargs():
+ # Return a different application_name for each new connection
+ nonlocal kwargs_counter
+ kwargs_counter += 1
+ return {"application_name": app_names[kwargs_counter % len(app_names)]}
+
+ p = pool.ConnectionPool(
+ conninfo=rotating_config,
+ kwargs=rotating_kwargs,
+ min_size=2,
+ max_lifetime=0.2,
+ open=False,
+ )
+
+ try:
+ p.open()
+ p.wait()
+
+ # Make sure we created two connections (rotating_config called twice)
+ assert config_rotation_counter == 2
+
+ # Acquire both connections and check application_name
+ with p.connection() as conn1, p.connection() as conn2:
+ row1 = conn1.execute("SHOW application_name")
+ row2 = conn2.execute("SHOW application_name")
+
+ name1 = row1.fetchone()
+ assert (
+ name1 is not None
+ ), "first call to SHOW application_name returned no rows"
+ assert name1[0] in app_names
+
+ name2 = row2.fetchone()
+ assert (
+ name2 is not None
+ ), "second call to SHOW application_name returned no rows"
+ assert name2[0] in app_names
+
+ # Make sure that names are different.
+ assert name1 != name2
+ finally:
+ p.close()
+
+
+@pytest.mark.slow
+def test_get_config_raise_exception(dsn, caplog):
+
+ def failing_conninfo():
+ raise RuntimeError("cannot build conninfo")
+
+ def failing_kwargs():
+ raise RuntimeError("cannot build kwargs")
+
+ p = pool.ConnectionPool(
+ conninfo=failing_conninfo,
+ kwargs=failing_kwargs,
+ min_size=1,
+ max_lifetime=0.1,
+ open=False,
+ reconnect_timeout=1.0,
+ )
+
+ with caplog.at_level("WARNING"):
+ try:
+ p.open()
+ with pytest.raises(pool.PoolTimeout):
+ p.wait(timeout=2.0)
+ finally:
+ p.close()
+
+ # Ensure log contains "reconnection attempt" warning`
+ reconnection_warnings = [
+ rec for rec in caplog.records if "reconnection attempt" in rec.message.lower()
+ ]
+
+ assert reconnection_warnings, "Expected reconnection attempt logs"
+ # Make sure that we saw not too many (backoff works)
+ assert len(reconnection_warnings) < 5, "Too many attempts (likely busyloop)"
await conn.close()
await asleep(0.1)
assert len(p._pool) == 1
+
+
+async def test_get_config_rotates_connections(dsn):
+
+ config_rotation_counter = 0
+
+ async def rotating_config():
+ nonlocal config_rotation_counter
+ config_rotation_counter += 1
+ return dsn
+
+ app_names = ["app-1", "app-2"]
+ kwargs_counter = 0
+
+ async def rotating_kwargs():
+ # Return a different application_name for each new connection
+ nonlocal kwargs_counter
+ kwargs_counter += 1
+ return {"application_name": app_names[kwargs_counter % len(app_names)]}
+
+ p = pool.AsyncConnectionPool(
+ conninfo=rotating_config,
+ kwargs=rotating_kwargs,
+ min_size=2,
+ max_lifetime=0.2,
+ open=False,
+ )
+
+ try:
+ await p.open()
+ await p.wait()
+
+ # Make sure we created two connections (rotating_config called twice)
+ assert config_rotation_counter == 2
+
+ # Acquire both connections and check application_name
+ async with p.connection() as conn1, p.connection() as conn2:
+ row1 = await conn1.execute("SHOW application_name")
+ row2 = await conn2.execute("SHOW application_name")
+
+ name1 = await row1.fetchone()
+ assert (
+ name1 is not None
+ ), "first call to SHOW application_name returned no rows"
+ assert name1[0] in app_names
+
+ name2 = await row2.fetchone()
+ assert (
+ name2 is not None
+ ), "second call to SHOW application_name returned no rows"
+ assert name2[0] in app_names
+
+ # Make sure that names are different.
+ assert name1 != name2
+
+ finally:
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_get_config_raise_exception(dsn, caplog):
+
+ async def failing_conninfo():
+ raise RuntimeError("cannot build conninfo")
+
+ async def failing_kwargs():
+ raise RuntimeError("cannot build kwargs")
+
+ p = pool.AsyncConnectionPool(
+ conninfo=failing_conninfo,
+ kwargs=failing_kwargs,
+ min_size=1,
+ max_lifetime=0.1,
+ open=False,
+ reconnect_timeout=1.0,
+ )
+
+ with caplog.at_level("WARNING"):
+ try:
+ await p.open()
+ with pytest.raises(pool.PoolTimeout):
+ await p.wait(timeout=2.0)
+ finally:
+ await p.close()
+
+ # Ensure log contains "reconnection attempt" warning`
+ reconnection_warnings = [
+ rec for rec in caplog.records if "reconnection attempt" in rec.message.lower()
+ ]
+
+ assert reconnection_warnings, "Expected reconnection attempt logs"
+ # Make sure that we saw not too many (backoff works)
+ assert len(reconnection_warnings) < 5, "Too many attempts (likely busyloop)"
"AsyncServerCursor": "ServerCursor",
"AsyncTransaction": "Transaction",
"AsyncWriter": "Writer",
+ "AsyncKwargsParam": "KwargsParam",
+ "AsyncConninfoParam": "ConninfoParam",
"StopAsyncIteration": "StopIteration",
"__aenter__": "__enter__",
"__aexit__": "__exit__",