]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: define reconnect_failed callback in each pool implementation, allow async functi...
authorDan Shick <dan.shick@nydig.com>
Tue, 14 Mar 2023 17:32:52 +0000 (13:32 -0400)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 15 Mar 2023 02:51:53 +0000 (03:51 +0100)
reuse existing test with pytest parameterization

psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool_async.py

index fe29408c3034b2c6abe5bf1f818d8b3ddb4161cb..6d66f5206ec07bc30287e416b059f41dac71b031 100644 (file)
@@ -6,7 +6,7 @@ psycopg connection pool base class and functionalities.
 
 from time import monotonic
 from random import random
-from typing import Any, Callable, Dict, Generic, Optional, Tuple
+from typing import Any, Dict, Generic, Optional, Tuple
 
 from psycopg import errors as e
 from psycopg.abc import ConnectionType
@@ -50,7 +50,6 @@ class BasePool(Generic[ConnectionType]):
         max_lifetime: float,
         max_idle: float,
         reconnect_timeout: float,
-        reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]],
         num_workers: int,
     ):
         min_size, max_size = self._check_size(min_size, max_size)
@@ -64,8 +63,6 @@ class BasePool(Generic[ConnectionType]):
 
         self.conninfo = conninfo
         self.kwargs: Dict[str, Any] = kwargs or {}
-        self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None]
-        self._reconnect_failed = reconnect_failed or (lambda pool: None)
         self.name = name
         self._min_size = min_size
         self._max_size = max_size
index 20f9811b77a748b5060fb8368e90c2dd708554e7..76d2414583de012deac9f73797f1bb5fe9434ac0 100644 (file)
@@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, Optional, Tuple, Type
 from psycopg import Connection
 from psycopg.pq import TransactionStatus
 
-from .base import BasePool
 from .pool import ConnectionPool, AddConnection
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
@@ -60,7 +59,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
         max_lifetime: float = 60 * 60.0,
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
-        reconnect_failed: Optional[Callable[[BasePool[Connection[Any]]], None]] = None,
+        reconnect_failed: Optional[Callable[["NullConnectionPool"], None]] = None,
         num_workers: int = 3,
     ):
         super().__init__(
@@ -78,7 +77,6 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
             max_lifetime=max_lifetime,
             max_idle=max_idle,
             reconnect_timeout=reconnect_timeout,
-            reconnect_failed=reconnect_failed,
             num_workers=num_workers,
         )
 
index 9f566c66360347dcf06863bb98c83f89b6b0c14b..cb1db2d7a4f30ee480fb2808d4b8176473ac155f 100644 (file)
@@ -6,12 +6,11 @@ psycopg asynchronous null connection pool
 
 import asyncio
 import logging
-from typing import Any, Awaitable, Callable, Dict, Optional, Type
+from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union
 
 from psycopg import AsyncConnection
 from psycopg.pq import TransactionStatus
 
-from .base import BasePool
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
 from .null_pool import _BaseNullConnectionPool
@@ -40,7 +39,10 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
         reconnect_failed: Optional[
-            Callable[[BasePool[AsyncConnection[None]]], None]
+            Union[
+                Callable[["AsyncNullConnectionPool"], None],
+                Callable[["AsyncNullConnectionPool"], Awaitable[None]],
+            ]
         ] = None,
         num_workers: int = 3,
     ):
@@ -59,7 +61,6 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
             max_lifetime=max_lifetime,
             max_idle=max_idle,
             reconnect_timeout=reconnect_timeout,
-            reconnect_failed=reconnect_failed,
             num_workers=num_workers,
         )
 
index dd50d73a6de4aea2c8bc0bef7b001a81694a9d5d..327a2b6d7a37399b6cb50e13827336ef67f852eb 100644 (file)
@@ -45,13 +45,16 @@ class ConnectionPool(BasePool[Connection[Any]]):
         max_lifetime: float = 60 * 60.0,
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
-        reconnect_failed: Optional[Callable[[BasePool[Connection[Any]]], None]] = None,
+        reconnect_failed: Optional[Callable[["ConnectionPool"], None]] = None,
         num_workers: int = 3,
     ):
         self.connection_class = connection_class
         self._configure = configure
         self._reset = reset
 
+        self._reconnect_failed: Callable[["ConnectionPool"], None]
+        self._reconnect_failed = reconnect_failed or (lambda pool: None)
+
         self._lock = threading.RLock()
         self._waiting = Deque["WaitingClient"]()
 
@@ -75,7 +78,6 @@ class ConnectionPool(BasePool[Connection[Any]]):
             max_lifetime=max_lifetime,
             max_idle=max_idle,
             reconnect_timeout=reconnect_timeout,
-            reconnect_failed=reconnect_failed,
             num_workers=num_workers,
         )
 
index 1cffcce68c61e874c3d7823fe61f76c1fee70807..3150da02ab9d7e2d9863d0c749999735477157e9 100644 (file)
@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
 from typing import Any, AsyncIterator, Awaitable, Callable
-from typing import Dict, List, Optional, Sequence, Type
+from typing import Dict, List, Optional, Sequence, Type, Union
 from weakref import ref
 from contextlib import asynccontextmanager
 
@@ -45,7 +45,10 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
         reconnect_failed: Optional[
-            Callable[[BasePool[AsyncConnection[Any]]], None]
+            Union[
+                Callable[["AsyncConnectionPool"], None],
+                Callable[["AsyncConnectionPool"], Awaitable[None]],
+            ]
         ] = None,
         num_workers: int = 3,
     ):
@@ -53,6 +56,12 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self._configure = configure
         self._reset = reset
 
+        self._reconnect_failed: Union[
+            Callable[["AsyncConnectionPool"], None],
+            Callable[["AsyncConnectionPool"], Awaitable[None]],
+        ]
+        self._reconnect_failed = reconnect_failed or (lambda pool: None)
+
         # asyncio objects, created on open to attach them to the right loop.
         self._lock: asyncio.Lock
         self._sched: AsyncScheduler
@@ -78,7 +87,6 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             max_lifetime=max_lifetime,
             max_idle=max_idle,
             reconnect_timeout=reconnect_timeout,
-            reconnect_failed=reconnect_failed,
             num_workers=num_workers,
         )
 
@@ -381,11 +389,14 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             else:
                 await self._add_to_pool(conn)
 
-    def reconnect_failed(self) -> None:
+    async def reconnect_failed(self) -> None:
         """
         Called when reconnection failed for longer than `reconnect_timeout`.
         """
-        self._reconnect_failed(self)
+        if asyncio.iscoroutinefunction(self._reconnect_failed):
+            await self._reconnect_failed(self)
+        else:
+            self._reconnect_failed(self)
 
     def run_task(self, task: "MaintenanceTask") -> None:
         """Run a maintenance task in a worker."""
@@ -484,7 +495,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                     # If we have given up with a growing attempt, allow a new one.
                     if growing and self._growing:
                         self._growing = False
-                self.reconnect_failed()
+                await self.reconnect_failed()
             else:
                 attempt.update_delay(now)
                 await self.schedule_task(
index 1f16ae2f3d73c57c347db3ca4842107c116a9ba0..490a05690b72c18cc17202d8c29ba22c38617473 100644 (file)
@@ -834,15 +834,25 @@ async def test_reconnect(proxy, caplog, monkeypatch):
 
 @pytest.mark.slow
 @pytest.mark.timing
-async def test_reconnect_failure(proxy):
+@pytest.mark.parametrize("async_cb", [True, False])
+async def test_reconnect_failure(proxy, async_cb):
     proxy.start()
 
     t1 = None
 
-    def failed(pool):
-        assert pool.name == "this-one"
-        nonlocal t1
-        t1 = time()
+    if async_cb:
+
+        async def failed(pool):
+            assert pool.name == "this-one"
+            nonlocal t1
+            t1 = time()
+
+    else:
+
+        def failed(pool):
+            assert pool.name == "this-one"
+            nonlocal t1
+            t1 = time()
 
     async with pool.AsyncConnectionPool(
         proxy.client_dsn,