From: Dan Shick Date: Tue, 14 Mar 2023 17:32:52 +0000 (-0400) Subject: fix: define reconnect_failed callback in each pool implementation, allow async functi... X-Git-Tag: pool-3.2.0~118^2~3 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=c15565f0b3236d676c7fc33dd5f241348c3c7067;p=thirdparty%2Fpsycopg.git fix: define reconnect_failed callback in each pool implementation, allow async functions in async pool reuse existing test with pytest parameterization --- diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index fe29408c3..6d66f5206 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -6,7 +6,7 @@ psycopg connection pool base class and functionalities. from time import monotonic from random import random -from typing import Any, Callable, Dict, Generic, Optional, Tuple +from typing import Any, Dict, Generic, Optional, Tuple from psycopg import errors as e from psycopg.abc import ConnectionType @@ -50,7 +50,6 @@ class BasePool(Generic[ConnectionType]): max_lifetime: float, max_idle: float, reconnect_timeout: float, - reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]], num_workers: int, ): min_size, max_size = self._check_size(min_size, max_size) @@ -64,8 +63,6 @@ class BasePool(Generic[ConnectionType]): self.conninfo = conninfo self.kwargs: Dict[str, Any] = kwargs or {} - self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None] - self._reconnect_failed = reconnect_failed or (lambda pool: None) self.name = name self._min_size = min_size self._max_size = max_size diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py index 20f9811b7..76d241458 100644 --- a/psycopg_pool/psycopg_pool/null_pool.py +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type from psycopg import Connection from psycopg.pq import TransactionStatus -from .base import BasePool from .pool import ConnectionPool, AddConnection from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout @@ -60,7 +59,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): max_lifetime: float = 60 * 60.0, max_idle: float = 10 * 60.0, reconnect_timeout: float = 5 * 60.0, - reconnect_failed: Optional[Callable[[BasePool[Connection[Any]]], None]] = None, + reconnect_failed: Optional[Callable[["NullConnectionPool"], None]] = None, num_workers: int = 3, ): super().__init__( @@ -78,7 +77,6 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): max_lifetime=max_lifetime, max_idle=max_idle, reconnect_timeout=reconnect_timeout, - reconnect_failed=reconnect_failed, num_workers=num_workers, ) diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py index 9f566c663..cb1db2d7a 100644 --- a/psycopg_pool/psycopg_pool/null_pool_async.py +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -6,12 +6,11 @@ psycopg asynchronous null connection pool import asyncio import logging -from typing import Any, Awaitable, Callable, Dict, Optional, Type +from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union from psycopg import AsyncConnection from psycopg.pq import TransactionStatus -from .base import BasePool from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout from .null_pool import _BaseNullConnectionPool @@ -40,7 +39,10 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): max_idle: float = 10 * 60.0, reconnect_timeout: float = 5 * 60.0, reconnect_failed: Optional[ - Callable[[BasePool[AsyncConnection[None]]], None] + Union[ + Callable[["AsyncNullConnectionPool"], None], + Callable[["AsyncNullConnectionPool"], Awaitable[None]], + ] ] = None, num_workers: int = 3, ): @@ -59,7 +61,6 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): max_lifetime=max_lifetime, max_idle=max_idle, reconnect_timeout=reconnect_timeout, - reconnect_failed=reconnect_failed, num_workers=num_workers, ) diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index dd50d73a6..327a2b6d7 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -45,13 +45,16 @@ class ConnectionPool(BasePool[Connection[Any]]): max_lifetime: float = 60 * 60.0, max_idle: float = 10 * 60.0, reconnect_timeout: float = 5 * 60.0, - reconnect_failed: Optional[Callable[[BasePool[Connection[Any]]], None]] = None, + reconnect_failed: Optional[Callable[["ConnectionPool"], None]] = None, num_workers: int = 3, ): self.connection_class = connection_class self._configure = configure self._reset = reset + self._reconnect_failed: Callable[["ConnectionPool"], None] + self._reconnect_failed = reconnect_failed or (lambda pool: None) + self._lock = threading.RLock() self._waiting = Deque["WaitingClient"]() @@ -75,7 +78,6 @@ class ConnectionPool(BasePool[Connection[Any]]): max_lifetime=max_lifetime, max_idle=max_idle, reconnect_timeout=reconnect_timeout, - reconnect_failed=reconnect_failed, num_workers=num_workers, ) diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 1cffcce68..3150da02a 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from time import monotonic from types import TracebackType from typing import Any, AsyncIterator, Awaitable, Callable -from typing import Dict, List, Optional, Sequence, Type +from typing import Dict, List, Optional, Sequence, Type, Union from weakref import ref from contextlib import asynccontextmanager @@ -45,7 +45,10 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): max_idle: float = 10 * 60.0, reconnect_timeout: float = 5 * 60.0, reconnect_failed: Optional[ - Callable[[BasePool[AsyncConnection[Any]]], None] + Union[ + Callable[["AsyncConnectionPool"], None], + Callable[["AsyncConnectionPool"], Awaitable[None]], + ] ] = None, num_workers: int = 3, ): @@ -53,6 +56,12 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): self._configure = configure self._reset = reset + self._reconnect_failed: Union[ + Callable[["AsyncConnectionPool"], None], + Callable[["AsyncConnectionPool"], Awaitable[None]], + ] + self._reconnect_failed = reconnect_failed or (lambda pool: None) + # asyncio objects, created on open to attach them to the right loop. self._lock: asyncio.Lock self._sched: AsyncScheduler @@ -78,7 +87,6 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): max_lifetime=max_lifetime, max_idle=max_idle, reconnect_timeout=reconnect_timeout, - reconnect_failed=reconnect_failed, num_workers=num_workers, ) @@ -381,11 +389,14 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): else: await self._add_to_pool(conn) - def reconnect_failed(self) -> None: + async def reconnect_failed(self) -> None: """ Called when reconnection failed for longer than `reconnect_timeout`. """ - self._reconnect_failed(self) + if asyncio.iscoroutinefunction(self._reconnect_failed): + await self._reconnect_failed(self) + else: + self._reconnect_failed(self) def run_task(self, task: "MaintenanceTask") -> None: """Run a maintenance task in a worker.""" @@ -484,7 +495,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): # If we have given up with a growing attempt, allow a new one. if growing and self._growing: self._growing = False - self.reconnect_failed() + await self.reconnect_failed() else: attempt.update_delay(now) await self.schedule_task( diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 1f16ae2f3..490a05690 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -834,15 +834,25 @@ async def test_reconnect(proxy, caplog, monkeypatch): @pytest.mark.slow @pytest.mark.timing -async def test_reconnect_failure(proxy): +@pytest.mark.parametrize("async_cb", [True, False]) +async def test_reconnect_failure(proxy, async_cb): proxy.start() t1 = None - def failed(pool): - assert pool.name == "this-one" - nonlocal t1 - t1 = time() + if async_cb: + + async def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + else: + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() async with pool.AsyncConnectionPool( proxy.client_dsn,