git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add async connection pool
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 27 Feb 2021 01:09:44 +0000 (02:09 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/pool/__init__.py
psycopg3/psycopg3/pool/async_pool.py [new file with mode: 0644]
psycopg3/psycopg3/pool/base.py
psycopg3/psycopg3/pool/pool.py
psycopg3/psycopg3/pool/tasks.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py [new file with mode: 0644]

index 8ed35f7f9b141155ffb341854ff76dbb0e2c46bd..3997061e354998be6d5c7eba5b173b7d1b62a7bb 100644 (file)
@@ -46,7 +46,7 @@ execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
 if TYPE_CHECKING:
     from .pq.proto import PGconn, PGresult
-    from .pool import ConnectionPool
+    from .pool.base import BasePool
 
 if pq.__impl__ == "c":
     from psycopg3_c import _psycopg3
@@ -127,7 +127,7 @@ class BaseConnection(AdaptContext):
 
         # Attribute is only set if the connection is from a pool so we can tell
         # apart a connection in the pool too (when _pool = None)
-        self._pool: Optional["ConnectionPool"]
+        self._pool: Optional["BasePool[Any]"]
 
     def __del__(self) -> None:
         # If fails on connection we might not have this attribute yet
@@ -644,7 +644,9 @@ class AsyncConnection(BaseConnection):
         else:
             await self.commit()
 
-        await self.close()
+        # Close the connection only if it doesn't belong to a pool.
+        if not getattr(self, "_pool", None):
+            await self.close()
 
     async def close(self) -> None:
         self.pgconn.finish()
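
The practical effect of this hunk is that a pooled connection survives its own `async with` block: `__aexit__` still commits or rolls back, but leaves the socket open because `getconn()` sets `conn._pool`. A minimal sketch of the resulting pattern (not part of the commit; `p` stands for an AsyncConnectionPool instance):

    async def use(p):
        conn = await p.getconn()
        async with conn:                       # commit on success, rollback on error
            await conn.execute("select 1")
        # the connection is still open here: hand it back to the pool explicitly
        await p.putconn(conn)
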
index 327dcfc86c507d4af072e8933f55aa9c5cfb8b44..91f3496934b228bae3bcf1efcf9f32ac7038bb5c 100644 (file)
@@ -5,6 +5,12 @@ psycopg3 connection pool package
 # Copyright (C) 2021 The Psycopg Team
 
 from .pool import ConnectionPool
+from .async_pool import AsyncConnectionPool
 from .errors import PoolClosed, PoolTimeout
 
-__all__ = ["ConnectionPool", "PoolClosed", "PoolTimeout"]
+__all__ = [
+    "AsyncConnectionPool",
+    "ConnectionPool",
+    "PoolClosed",
+    "PoolTimeout",
+]
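
For reference, the new class is importable next to the existing names (a trivial sketch, not part of the commit):

    from psycopg3.pool import AsyncConnectionPool, ConnectionPool, PoolClosed, PoolTimeout
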
diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py
new file mode 100644 (file)
index 0000000..528ffa4
--- /dev/null
@@ -0,0 +1,444 @@
+"""
+psycopg3 asynchronous connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import asyncio
+import logging
+from time import monotonic
+from typing import Any, Awaitable, Callable, Deque, AsyncIterator, Optional
+from contextlib import asynccontextmanager
+from collections import deque
+
+from ..pq import TransactionStatus
+from ..connection import AsyncConnection
+
+from . import tasks
+from .base import ConnectionAttempt, BasePool
+from .errors import PoolClosed, PoolTimeout
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncConnectionPool(BasePool[AsyncConnection]):
+    def __init__(
+        self,
+        conninfo: str = "",
+        configure: Optional[
+            Callable[[AsyncConnection], Awaitable[None]]
+        ] = None,
+        **kwargs: Any,
+    ):
+        self._configure = configure
+
+        self._lock = asyncio.Lock()
+        self._waiting: Deque["AsyncClient"] = deque()
+
+        # to notify that the pool is full
+        self._pool_full_event: Optional[asyncio.Event] = None
+
+        self.loop = asyncio.get_event_loop()
+
+        super().__init__(conninfo, **kwargs)
+
+    async def wait_ready(self, timeout: float = 30.0) -> None:
+        """
+        Wait for the pool to be full after init.
+
+        Raise `PoolTimeout` if not ready within *timeout* sec.
+        """
+        async with self._lock:
+            assert not self._pool_full_event
+            if len(self._pool) >= self._nconns:
+                return
+            self._pool_full_event = asyncio.Event()
+
+        try:
+            await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+        except asyncio.TimeoutError:
+            await self.close()  # stop all the threads
+            raise PoolTimeout(
+                f"pool initialization incomplete after {timeout} sec"
+            )
+
+        async with self._lock:
+            self._pool_full_event = None
+
+    @asynccontextmanager
+    async def connection(
+        self, timeout: Optional[float] = None
+    ) -> AsyncIterator[AsyncConnection]:
+        """Context manager to obtain a connection from the pool.
+
+        Return a connection from the pool immediately if one is available,
+        otherwise wait up to *timeout* or `self.timeout` seconds and raise
+        `PoolTimeout` if a connection is not available in time.
+
+        Upon context exit, return the connection to the pool. Apply the normal
+        connection context behaviour (commit/rollback the transaction in case
+        of success/error). If the connection is no longer in a working state,
+        replace it with a new one.
+        """
+        conn = await self.getconn(timeout=timeout)
+        try:
+            async with conn:
+                yield conn
+        finally:
+            await self.putconn(conn)
+
+    async def getconn(
+        self, timeout: Optional[float] = None
+    ) -> AsyncConnection:
+        """Obtain a contection from the pool.
+
+        You should preferrably use `connection()`. Use this function only if
+        it is not possible to use the connection as context manager.
+
+        After using this function you *must* call a corresponding `putconn()`:
+        failing to do so will deplete the pool. A depleted pool is a sad pool:
+        you don't want a depleted pool.
+        """
+        logger.info("connection requested to %r", self.name)
+        # Critical section: decide here if there's a connection ready
+        # or if the client needs to wait.
+        async with self._lock:
+            if self._closed:
+                raise PoolClosed(f"the pool {self.name!r} is closed")
+
+            pos: Optional[AsyncClient] = None
+            if self._pool:
+                # Take a connection ready out of the pool
+                conn = self._pool.popleft()
+                if len(self._pool) < self._nconns_min:
+                    self._nconns_min = len(self._pool)
+            else:
+                # No connection available: put the client in the waiting queue
+                pos = AsyncClient()
+                self._waiting.append(pos)
+
+                # If there is space for the pool to grow, let's do it
+                if self._nconns < self.maxconn:
+                    self._nconns += 1
+                    logger.info(
+                        "growing pool %r to %s", self.name, self._nconns
+                    )
+                    self.run_task(tasks.AddConnection(self))
+
+        # If we are in the waiting queue, wait to be assigned a connection
+        # (outside the critical section, so only the waiting client is locked)
+        if pos:
+            if timeout is None:
+                timeout = self.timeout
+            conn = await pos.wait(timeout=timeout)
+
+        # Tell the connection it belongs to a pool so it isn't closed on __aexit__
+        # Note that this property shouldn't be set while the connection is in
+        # the pool, to avoid creating a reference loop.
+        conn._pool = self
+        logger.info("connection given by %r", self.name)
+        return conn
+
+    async def putconn(self, conn: AsyncConnection) -> None:
+        """Return a connection to the loving hands of its pool.
+
+        Use this function only paired with a `getconn()`. You don't need to use
+        it if you use the much more comfortable `connection()` context manager.
+        """
+        # Quick check to discard the wrong connection
+        pool = getattr(conn, "_pool", None)
+        if pool is not self:
+            if pool:
+                msg = f"it comes from pool {pool.name!r}"
+            else:
+                msg = "it doesn't come from any pool"
+            raise ValueError(
+                f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+            )
+
+        logger.info("returning connection to %r", self.name)
+
+        # If the pool is closed just close the connection instead of returning
+        # it to the pool. Also drop the pool reference to avoid a reference cycle.
+        if self._closed:
+            conn._pool = None
+            await conn.close()
+            return
+
+        # Use a worker to perform any needed maintenance work in a separate thread
+        self.run_task(tasks.ReturnConnection(self, conn))
+
+    async def close(self, timeout: float = 1.0) -> None:
+        """Close the pool and make it unavailable to new clients.
+
+        All waiting and future clients will fail to acquire a connection
+        with a `PoolClosed` exception. Connections currently in use will not
+        be closed until they are returned to the pool.
+
+        Wait up to *timeout* seconds for worker threads to terminate, if positive.
+        """
+        if self._closed:
+            return
+
+        async with self._lock:
+            self._closed = True
+            logger.debug("pool %r closed", self.name)
+
+            # Take waiting client and pool connections out of the state
+            waiting = list(self._waiting)
+            self._waiting.clear()
+            pool = list(self._pool)
+            self._pool.clear()
+
+        # Now that the flag _closed is set, getconn will fail immediately,
+        # putconn will just close the returned connection.
+
+        # Stop the scheduler
+        self._sched.enter(0, None)
+
+        # Stop the worker threads
+        for i in range(len(self._workers)):
+            self.run_task(tasks.StopWorker(self))
+
+        # Signal to any clients still in the queue that business is closed.
+        for pos in waiting:
+            await pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+        # Close the connections still in the pool
+        for conn in pool:
+            await conn.close()
+
+        # Wait for the worker threads to terminate
+        if timeout > 0:
+            loop = asyncio.get_running_loop()
+            for t in [self._sched_runner] + self._workers:
+                if not t.is_alive():
+                    continue
+                await loop.run_in_executor(None, lambda: t.join(timeout))
+                if t.is_alive():
+                    logger.warning(
+                        "couldn't stop thread %s in pool %r within %s seconds",
+                        t,
+                        self.name,
+                        timeout,
+                    )
+
+    async def configure(self, conn: AsyncConnection) -> None:
+        """Configure a connection after creation."""
+        if self._configure:
+            await self._configure(conn)
+
+    def reconnect_failed(self) -> None:
+        """
+        Called when reconnection failed for longer than `reconnect_timeout`.
+        """
+        self._reconnect_failed(self)
+
+    async def _connect(self) -> AsyncConnection:
+        """Return a new connection configured for the pool."""
+        conn = await AsyncConnection.connect(self.conninfo, **self.kwargs)
+        await self.configure(conn)
+        conn._pool = self
+        return conn
+
+    async def _add_connection(
+        self, attempt: Optional[ConnectionAttempt]
+    ) -> None:
+        """Try to connect and add the connection to the pool.
+
+        If it fails, reschedule a new attempt a few times in the future; then
+        give up, decrease the number of pool connections, and call
+        `self.reconnect_failed()`.
+
+        """
+        now = monotonic()
+        if not attempt:
+            attempt = ConnectionAttempt(
+                reconnect_timeout=self.reconnect_timeout
+            )
+
+        try:
+            conn = await self._connect()
+        except Exception as e:
+            logger.warning(f"error connecting in {self.name!r}: {e}")
+            if attempt.time_to_give_up(now):
+                logger.warning(
+                    "reconnection attempt in pool %r failed after %s sec",
+                    self.name,
+                    self.reconnect_timeout,
+                )
+                async with self._lock:
+                    self._nconns -= 1
+                self.reconnect_failed()
+            else:
+                attempt.update_delay(now)
+                self.schedule_task(
+                    tasks.AddConnection(self, attempt), attempt.delay
+                )
+        else:
+            await self._add_to_pool(conn)
+
+    async def _return_connection(self, conn: AsyncConnection) -> None:
+        """
+        Return a connection to the pool after usage.
+        """
+        await self._reset_connection(conn)
+        if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+            # Connection no longer in a working state: create a new one.
+            logger.warning("discarding closed connection: %s", conn)
+            self.run_task(tasks.AddConnection(self))
+        else:
+            await self._add_to_pool(conn)
+
+    async def _add_to_pool(self, conn: AsyncConnection) -> None:
+        """
+        Add a connection to the pool.
+
+        The connection can be a fresh one or one already used in the pool.
+
+        If a client is already waiting for a connection, pass it on; otherwise
+        put it back into the pool.
+        """
+        # Remove the pool reference from the connection before returning it
+        # to the state, to avoid creating a reference loop. This also disables
+        # the warning for open connections in conn.__del__.
+        conn._pool = None
+
+        pos: Optional[AsyncClient] = None
+
+        # Critical section: if there is a client waiting give it the connection
+        # otherwise put it back into the pool.
+        async with self._lock:
+            while self._waiting:
+                # If there is a client waiting (which is still waiting and
+                # hasn't timed out), give it the connection and notify it.
+                pos = self._waiting.popleft()
+                if await pos.set(conn):
+                    break
+            else:
+                # No client waiting for a connection: put it back into the pool
+                self._pool.append(conn)
+
+                # If we have been asked to wait for pool init, notify the
+                # waiter if the pool is full.
+                if self._pool_full_event and len(self._pool) >= self._nconns:
+                    self._pool_full_event.set()
+
+    async def _reset_connection(self, conn: AsyncConnection) -> None:
+        """
+        Bring a connection to IDLE state or close it.
+        """
+        status = conn.pgconn.transaction_status
+        if status == TransactionStatus.IDLE:
+            return
+
+        if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+            # Connection returned with an active transaction
+            logger.warning("rolling back returned connection: %s", conn)
+            try:
+                await conn.rollback()
+            except Exception as e:
+                logger.warning(
+                    "rollback failed: %s: %s. Discarding connection %s",
+                    e.__class__.__name__,
+                    e,
+                    conn,
+                )
+                await conn.close()
+
+        elif status == TransactionStatus.ACTIVE:
+            # Connection returned during an operation. Bad... just close it.
+            logger.warning("closing returned connection: %s", conn)
+            await conn.close()
+
+    async def _shrink_pool(self) -> None:
+        to_close: Optional[AsyncConnection] = None
+
+        async with self._lock:
+            # Reset the min number of connections used
+            nconns_min = self._nconns_min
+            self._nconns_min = len(self._pool)
+
+            # If the pool can shrink and connections were unused, drop one
+            if self._nconns > self.minconn and nconns_min > 0:
+                to_close = self._pool.popleft()
+                self._nconns -= 1
+
+        if to_close:
+            logger.info(
+                "shrinking pool %r to %s because %s unused connections"
+                " in the last %s sec",
+                self.name,
+                self._nconns,
+                nconns_min,
+                self.max_idle,
+            )
+            await to_close.close()
+
+
+class AsyncClient:
+    """A position in a queue for a client waiting for a connection."""
+
+    __slots__ = ("conn", "error", "_cond")
+
+    def __init__(self) -> None:
+        self.conn: Optional[AsyncConnection] = None
+        self.error: Optional[Exception] = None
+
+        # The AsyncClient behaves in a way similar to an Event, but the
+        # notifier needs to know reliably whether the waiter has "accepted"
+        # the message before timing out; otherwise the pool may give a
+        # connection to a client whose getconn() has already timed out, and
+        # that connection would be lost.
+        self._cond = asyncio.Condition()
+
+    async def wait(self, timeout: float) -> AsyncConnection:
+        """Wait for a connection to be set and return it.
+
+        Raise an exception if the wait times out or if fail() is called.
+        """
+        async with self._cond:
+            if not (self.conn or self.error):
+                try:
+                    await asyncio.wait_for(self._cond.wait(), timeout)
+                except asyncio.TimeoutError:
+                    self.error = PoolTimeout(
+                        f"couldn't get a connection after {timeout} sec"
+                    )
+
+        if self.conn:
+            return self.conn
+        else:
+            assert self.error
+            raise self.error
+
+    async def set(self, conn: AsyncConnection) -> bool:
+        """Signal the client waiting that a connection is ready.
+
+        Return True if the client has "accepted" the connection, False
+        otherwise (typically because wait() has timed out).
+        """
+        async with self._cond:
+            if self.conn or self.error:
+                return False
+
+            self.conn = conn
+            self._cond.notify_all()
+            return True
+
+    async def fail(self, error: Exception) -> bool:
+        """Signal the client that, alas, they won't have a connection today.
+
+        Return True if the client has "accepted" the error, False otherwise
+        (typically because wait() has timed out).
+        """
+        async with self._cond:
+            if self.conn or self.error:
+                return False
+
+            self.error = error
+            self._cond.notify_all()
+            return True
+
+
+tasks.AsyncConnectionPool = AsyncConnectionPool  # type: ignore
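
A minimal usage sketch of the new class, distilled from the tests below (the conninfo value and Python 3.7+ `asyncio.run` are assumptions):

    import asyncio
    from psycopg3.pool import AsyncConnectionPool

    async def main():
        # minconn/maxconn and timeout mirror the synchronous ConnectionPool
        p = AsyncConnectionPool("dbname=test", minconn=2, maxconn=4)
        await p.wait_ready(5.0)             # PoolTimeout if the pool doesn't fill in time
        async with p.connection() as conn:  # getconn()/putconn() wrapped in a context manager
            cur = await conn.execute("select pg_backend_pid()")
            print(await cur.fetchone())
        await p.close()

    asyncio.run(main())
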
index 4d279cb8726a8ea976d0966deebb55089acef881..0f74ab5975039c859cc461a38e5e0b1cdeccc21f 100644 (file)
@@ -31,7 +31,6 @@ class BasePool(Generic[ConnectionType]):
         conninfo: str = "",
         *,
         kwargs: Optional[Dict[str, Any]] = None,
-        configure: Optional[Callable[[ConnectionType], None]] = None,
         minconn: int = 4,
         maxconn: Optional[int] = None,
         name: Optional[str] = None,
@@ -52,15 +51,13 @@ class BasePool(Generic[ConnectionType]):
             )
         if not name:
             num = BasePool._num_pool = BasePool._num_pool + 1
-            name = f"pool-{num - 1}"
+            name = f"pool-{num}"
 
         if num_workers < 1:
             raise ValueError("num_workers must be at least 1")
 
         self.conninfo = conninfo
         self.kwargs: Dict[str, Any] = kwargs or {}
-        self._configure: Callable[[ConnectionType], None]
-        self._configure = configure or (lambda conn: None)
         self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None]
         self._reconnect_failed = reconnect_failed or (lambda pool: None)
         self.name = name
@@ -86,12 +83,15 @@ class BasePool(Generic[ConnectionType]):
         self._workers: List[threading.Thread] = []
         for i in range(num_workers):
             t = threading.Thread(
-                target=self.worker, args=(self._tasks,), daemon=True
+                target=self.worker,
+                args=(self._tasks,),
+                name=f"{self.name}-worker-{i}",
+                daemon=True,
             )
             self._workers.append(t)
 
         self._sched_runner = threading.Thread(
-            target=self._sched.run, daemon=True
+            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
         )
 
         # _close should be the last property to be set in the state
@@ -145,7 +145,7 @@ class BasePool(Generic[ConnectionType]):
         self._tasks.put_nowait(task)
 
     def schedule_task(
-        self, task: tasks.MaintenanceTask[ConnectionType], delay: float
+        self, task: tasks.MaintenanceTask[Any], delay: float
     ) -> None:
         """Run a maintenance task in a worker thread in the future."""
         self._sched.enter(delay, task.tick)
@@ -162,7 +162,7 @@ class BasePool(Generic[ConnectionType]):
         # Don't make all the workers time out at the same moment
         timeout = WORKER_TIMEOUT * (0.9 + 0.1 * random.random())
         while True:
-            # Use a timeout to make the wait unterruptable
+            # Use a timeout to make the wait interruptible
             try:
                 task = q.get(timeout=timeout)
             except Empty:
@@ -177,6 +177,10 @@ class BasePool(Generic[ConnectionType]):
                 )
 
             if isinstance(task, tasks.StopWorker):
+                logger.debug(
+                    "terminating working thread %s",
+                    threading.current_thread().name,
+                )
                 return
 
 
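
Since the worker and scheduler threads now carry the pool name, they are easy to spot while debugging; a small sketch using only the standard library (thread name patterns taken from the f-strings above):

    import threading

    # list the pool's service threads by the names assigned in BasePool.__init__
    for t in threading.enumerate():
        if "-worker-" in t.name or t.name.endswith("-scheduler"):
            print(t.name, t.is_alive())
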
index dfdd7c8dfd6428ddcd906ec4d66eb16eee63db55..b6d0e9257d288bcbd0d3d8d537e4e700ffd931a4 100644 (file)
@@ -4,10 +4,10 @@ psycopg3 synchronous connection pool
 
 # Copyright (C) 2021 The Psycopg Team
 
-import time
 import logging
 import threading
-from typing import Any, Deque, Iterator, Optional
+from time import monotonic
+from typing import Any, Callable, Deque, Iterator, Optional
 from contextlib import contextmanager
 from collections import deque
 
@@ -22,7 +22,15 @@ logger = logging.getLogger(__name__)
 
 
 class ConnectionPool(BasePool[Connection]):
-    def __init__(self, conninfo: str = "", **kwargs: Any):
+    def __init__(
+        self,
+        conninfo: str = "",
+        configure: Optional[Callable[[Connection], None]] = None,
+        **kwargs: Any,
+    ):
+        self._configure: Callable[[Connection], None]
+        self._configure = configure or (lambda conn: None)
+
         self._lock = threading.RLock()
         self._waiting: Deque["WaitingClient"] = deque()
 
@@ -167,6 +175,7 @@ class ConnectionPool(BasePool[Connection]):
 
         with self._lock:
             self._closed = True
+            logger.debug("pool %r closed", self.name)
 
             # Take waiting client and pool connections out of the state
             waiting = list(self._waiting)
@@ -231,7 +240,7 @@ class ConnectionPool(BasePool[Connection]):
         `self.reconnect_failed()`.
 
         """
-        now = time.monotonic()
+        now = monotonic()
         if not attempt:
             attempt = ConnectionAttempt(
                 reconnect_timeout=self.reconnect_timeout
@@ -295,7 +304,6 @@ class ConnectionPool(BasePool[Connection]):
                 pos = self._waiting.popleft()
                 if pos.set(conn):
                     break
-
             else:
                 # No client waiting for a connection: put it back into the pool
                 self._pool.append(conn)
@@ -371,7 +379,7 @@ class WaitingClient:
         # message and it hasn't timed out yet, otherwise the pool may give a
         # connection to a client that has already timed out getconn(), which
         # will be lost.
-        self._cond = threading.Condition(threading.Lock())
+        self._cond = threading.Condition()
 
     def wait(self, timeout: float) -> Connection:
         """Wait for a connection to be set and return it.
@@ -418,3 +426,6 @@ class WaitingClient:
             self.error = error
             self._cond.notify_all()
             return True
+
+
+tasks.ConnectionPool = ConnectionPool  # type: ignore
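
With the `configure` callback moved from `BasePool` to the concrete classes, the synchronous pool takes a plain function and the asynchronous one a coroutine. A hedged sketch of the synchronous variant (the session setting is arbitrary):

    from psycopg3 import Connection
    from psycopg3.pool import ConnectionPool

    def configure(conn: Connection) -> None:
        # runs once for each new connection, before it is handed out or pooled
        cur = conn.cursor()
        cur.execute("set timezone to 'UTC'")
        conn.commit()

    p = ConnectionPool("dbname=test", configure=configure, minconn=2)
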
index e9fac5ddc2f5f0c089037ae25d3ad9e466452641..6028c3a88b1ad268d7988459bd87147d4981f6df 100644 (file)
@@ -4,16 +4,24 @@ Maintenance tasks for the connection pools.
 
 # Copyright (C) 2021 The Psycopg Team
 
+import asyncio
 import logging
+import threading
 from abc import ABC, abstractmethod
-from typing import Any, Generic, Optional, TYPE_CHECKING
+from typing import Any, cast, Generic, Optional, Type, TYPE_CHECKING
 from weakref import ref
 
 from ..proto import ConnectionType
+from .. import Connection, AsyncConnection
 
 if TYPE_CHECKING:
+    from .pool import ConnectionPool
+    from .async_pool import AsyncConnectionPool
     from .base import BasePool, ConnectionAttempt
-    from ..connection import Connection
+else:
+    # Injected by pool.py and async_pool.py at import time
+    ConnectionPool: "Type[BasePool[Connection]]"
+    AsyncConnectionPool: "Type[BasePool[AsyncConnection]]"
 
 logger = logging.getLogger(__name__)
 
@@ -21,9 +29,16 @@ logger = logging.getLogger(__name__)
 class MaintenanceTask(ABC, Generic[ConnectionType]):
     """A task to run asynchronously to maintain the pool state."""
 
+    TIMEOUT = 10.0
+
     def __init__(self, pool: "BasePool[Any]"):
+        if isinstance(pool, AsyncConnectionPool):
+            self.event = threading.Event()
+
         self.pool = ref(pool)
-        logger.debug("task created: %s", self)
+        logger.debug(
+            "task created in %s: %s", threading.current_thread().name, self
+        )
 
     def __repr__(self) -> str:
         pool = self.pool()
@@ -41,8 +56,20 @@ class MaintenanceTask(ABC, Generic[ConnectionType]):
             # Pool is no more working. Quietly discard the operation.
             return
 
-        logger.debug("task running: %s", self)
-        self._run(pool)
+        logger.debug(
+            "task running in %s: %s", threading.current_thread().name, self
+        )
+        if isinstance(pool, ConnectionPool):
+            self._run(pool)
+        elif isinstance(pool, AsyncConnectionPool):
+            self.event.clear()
+            asyncio.run_coroutine_threadsafe(self._run_async(pool), pool.loop)
+            if not self.event.wait(self.TIMEOUT):
+                logger.warning(
+                    "event %s didn't terminate after %s sec", self.TIMEOUT
+                )
+        else:
+            logger.error("%s run got %s instead of a pool", self, pool)
 
     def tick(self) -> None:
         """Run the scheduled task
@@ -58,16 +85,23 @@ class MaintenanceTask(ABC, Generic[ConnectionType]):
         pool.run_task(self)
 
     @abstractmethod
-    def _run(self, pool: "BasePool[Any]") -> None:
+    def _run(self, pool: "ConnectionPool") -> None:
         ...
 
+    @abstractmethod
+    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+        self.event.set()
+
 
 class StopWorker(MaintenanceTask[ConnectionType]):
     """Signal the maintenance thread to terminate."""
 
-    def _run(self, pool: "BasePool[Any]") -> None:
+    def _run(self, pool: "ConnectionPool") -> None:
         pass
 
+    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+        await super()._run_async(pool)
+
 
 class AddConnection(MaintenanceTask[ConnectionType]):
     def __init__(
@@ -78,29 +112,30 @@ class AddConnection(MaintenanceTask[ConnectionType]):
         super().__init__(pool)
         self.attempt = attempt
 
-    def _run(self, pool: "BasePool[Any]") -> None:
-        from . import ConnectionPool
+    def _run(self, pool: "ConnectionPool") -> None:
+        pool._add_connection(self.attempt)
 
-        if isinstance(pool, ConnectionPool):
-            pool._add_connection(self.attempt)
-        else:
-            assert False
+    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+        logger.debug("run async 1")
+        await pool._add_connection(self.attempt)
+        logger.debug("run async 2")
+        await super()._run_async(pool)
+        logger.debug("run async 3")
 
 
 class ReturnConnection(MaintenanceTask[ConnectionType]):
     """Clean up and return a connection to the pool."""
 
-    def __init__(self, pool: "BasePool[Any]", conn: "Connection"):
+    def __init__(self, pool: "BasePool[Any]", conn: "ConnectionType"):
         super().__init__(pool)
         self.conn = conn
 
-    def _run(self, pool: "BasePool[Any]") -> None:
-        from . import ConnectionPool
+    def _run(self, pool: "ConnectionPool") -> None:
+        pool._return_connection(cast(Connection, self.conn))
 
-        if isinstance(pool, ConnectionPool):
-            pool._return_connection(self.conn)
-        else:
-            assert False
+    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+        await pool._return_connection(cast(AsyncConnection, self.conn))
+        await super()._run_async(pool)
 
 
 class ShrinkPool(MaintenanceTask[ConnectionType]):
@@ -110,14 +145,13 @@ class ShrinkPool(MaintenanceTask[ConnectionType]):
     in the pool.
     """
 
-    def _run(self, pool: "BasePool[Any]") -> None:
+    def _run(self, pool: "ConnectionPool") -> None:
         # Reschedule the task now so that in case of any error we don't lose
         # the periodic run.
         pool.schedule_task(self, pool.max_idle)
+        pool._shrink_pool()
 
-        from . import ConnectionPool
-
-        if isinstance(pool, ConnectionPool):
-            pool._shrink_pool()
-        else:
-            assert False
+    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+        pool.schedule_task(self, pool.max_idle)
+        await pool._shrink_pool()
+        await super()._run_async(pool)
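
The dispatch mechanism above — a worker thread scheduling a coroutine on the pool's event loop and blocking on an `Event` until it completes — can be shown in isolation. A minimal sketch, independent of the pool classes:

    import asyncio
    import threading

    async def job(done: threading.Event) -> None:
        # stand-in for pool._add_connection(), pool._return_connection(), etc.
        await asyncio.sleep(0.1)
        done.set()          # what super()._run_async() does in the real tasks

    def worker(loop: asyncio.AbstractEventLoop) -> None:
        # what MaintenanceTask.run() does when the pool is an AsyncConnectionPool
        done = threading.Event()
        asyncio.run_coroutine_threadsafe(job(done), loop)
        if not done.wait(10.0):
            print("task didn't terminate in time")

    loop = asyncio.new_event_loop()
    threading.Thread(target=worker, args=(loop,), daemon=True).start()
    loop.run_until_complete(asyncio.sleep(0.5))  # give the loop time to run the job
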
index 87309b8eebc065d5a922ae63322b3f9f596b3f41..bb48f77c14380dfe7244126fdc39c41475d499c1 100644 (file)
@@ -524,10 +524,10 @@ def test_shrink(dsn, monkeypatch):
 def test_reconnect(proxy, caplog, monkeypatch):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
 
-    assert pool.pool.ConnectionAttempt.INITIAL_DELAY == 1.0
-    assert pool.pool.ConnectionAttempt.DELAY_JITTER == 0.1
-    monkeypatch.setattr(pool.pool.ConnectionAttempt, "INITIAL_DELAY", 0.1)
-    monkeypatch.setattr(pool.pool.ConnectionAttempt, "DELAY_JITTER", 0.0)
+    assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+    assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+    monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+    monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
 
     proxy.start()
     p = pool.ConnectionPool(proxy.client_dsn, minconn=1)
diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py
new file mode 100644 (file)
index 0000000..f407360
--- /dev/null
@@ -0,0 +1,656 @@
+import sys
+import asyncio
+import logging
+import weakref
+from time import time
+from collections import Counter
+
+import pytest
+
+import psycopg3
+from psycopg3 import pool
+from psycopg3.pq import TransactionStatus
+
+create_task = (
+    asyncio.create_task
+    if sys.version_info >= (3, 7)
+    else asyncio.ensure_future
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+async def test_defaults(dsn):
+    p = pool.AsyncConnectionPool(dsn)
+    assert p.minconn == p.maxconn == 4
+    assert p.timeout == 30
+    assert p.max_idle == 600
+    assert p.num_workers == 3
+    await p.close()
+
+
+async def test_minconn_maxconn(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2)
+    assert p.minconn == p.maxconn == 2
+    await p.close()
+
+    p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4)
+    assert p.minconn == 2
+    assert p.maxconn == 4
+    await p.close()
+
+    with pytest.raises(ValueError):
+        pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2)
+
+
+async def test_kwargs(dsn):
+    p = pool.AsyncConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1)
+    async with p.connection() as conn:
+        assert conn.autocommit
+
+    await p.close()
+
+
+async def test_its_really_a_pool(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2)
+    async with p.connection() as conn:
+        cur = await conn.execute("select pg_backend_pid()")
+        (pid1,) = await cur.fetchone()
+
+        async with p.connection() as conn2:
+            cur = await conn2.execute("select pg_backend_pid()")
+            (pid2,) = await cur.fetchone()
+
+    async with p.connection() as conn:
+        assert conn.pgconn.backend_pid in (pid1, pid2)
+
+    await p.close()
+
+
+async def test_connection_not_lost(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    with pytest.raises(ZeroDivisionError):
+        async with p.connection() as conn:
+            pid = conn.pgconn.backend_pid
+            1 / 0
+
+    async with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid == pid
+
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_concurrent_filling(dsn, monkeypatch):
+    delay_connection(monkeypatch, 0.1)
+    t0 = time()
+    times = []
+
+    add_orig = pool.AsyncConnectionPool._add_to_pool
+
+    async def add_time(self, conn):
+        times.append(time() - t0)
+        await add_orig(self, conn)
+
+    monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time)
+
+    p = pool.AsyncConnectionPool(dsn, minconn=5, num_workers=2)
+    await p.wait_ready(5.0)
+    want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+    assert len(times) == len(want_times)
+    for got, want in zip(times, want_times):
+        assert got == pytest.approx(want, 0.2), times
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_wait_ready(dsn, monkeypatch):
+    delay_connection(monkeypatch, 0.1)
+    with pytest.raises(pool.PoolTimeout):
+        p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1)
+        await p.wait_ready(0.3)
+
+    p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1)
+    await p.wait_ready(0.5)
+    await p.close()
+    p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2)
+    await p.wait_ready(0.3)
+    await p.wait_ready(0.0001)  # idempotent
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+    with pytest.raises(pool.PoolTimeout):
+        p = pool.AsyncConnectionPool(
+            proxy.client_dsn, minconn=1, num_workers=1
+        )
+        await p.wait_ready(0.2)
+
+    p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1, num_workers=1)
+    await asyncio.sleep(0.5)
+    assert not p._pool
+    proxy.start()
+
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_queue(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2)
+    results = []
+
+    async def worker(n):
+        t0 = time()
+        async with p.connection() as conn:
+            cur = await conn.execute(
+                "select pg_backend_pid() from pg_sleep(0.2)"
+            )
+            (pid,) = await cur.fetchone()
+        t1 = time()
+        results.append((n, t1 - t0, pid))
+
+    ts = [create_task(worker(i)) for i in range(6)]
+    await asyncio.gather(*ts)
+    await p.close()
+
+    times = [item[1] for item in results]
+    want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+    for got, want in zip(times, want_times):
+        assert got == pytest.approx(want, 0.2), times
+
+    assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_timeout(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1)
+    results = []
+    errors = []
+
+    async def worker(n):
+        t0 = time()
+        try:
+            async with p.connection() as conn:
+                cur = await conn.execute(
+                    "select pg_backend_pid() from pg_sleep(0.2)"
+                )
+                (pid,) = await cur.fetchone()
+        except pool.PoolTimeout as e:
+            t1 = time()
+            errors.append((n, t1 - t0, e))
+        else:
+            t1 = time()
+            results.append((n, t1 - t0, pid))
+
+    ts = [create_task(worker(i)) for i in range(4)]
+    await asyncio.gather(*ts)
+
+    assert len(results) == 2
+    assert len(errors) == 2
+    for e in errors:
+        assert 0.1 < e[1] < 0.15
+
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_dead_client(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2)
+
+    results = []
+
+    async def worker(i, timeout):
+        try:
+            async with p.connection(timeout=timeout) as conn:
+                await conn.execute("select pg_sleep(0.3)")
+                results.append(i)
+        except pool.PoolTimeout:
+            if timeout > 0.2:
+                raise
+
+    ts = [
+        create_task(worker(i, timeout))
+        for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+    ]
+    await asyncio.gather(*ts)
+
+    await asyncio.sleep(0.2)
+    assert set(results) == set([0, 1, 3, 4])
+    assert len(p._pool) == 2  # no connection was lost
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_queue_timeout_override(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1)
+    results = []
+    errors = []
+
+    async def worker(n):
+        t0 = time()
+        timeout = 0.25 if n == 3 else None
+        try:
+            async with p.connection(timeout=timeout) as conn:
+                cur = await conn.execute(
+                    "select pg_backend_pid() from pg_sleep(0.2)"
+                )
+                (pid,) = await cur.fetchone()
+        except pool.PoolTimeout as e:
+            t1 = time()
+            errors.append((n, t1 - t0, e))
+        else:
+            t1 = time()
+            results.append((n, t1 - t0, pid))
+
+    ts = [create_task(worker(i)) for i in range(4)]
+    await asyncio.gather(*ts)
+    await p.close()
+
+    assert len(results) == 3
+    assert len(errors) == 1
+    for e in errors:
+        assert 0.1 < e[1] < 0.15
+
+
+async def test_broken_reconnect(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    with pytest.raises(psycopg3.OperationalError):
+        async with p.connection() as conn:
+            cur = await conn.execute("select pg_backend_pid()")
+            (pid1,) = await cur.fetchone()
+            await conn.close()
+
+    async with p.connection() as conn2:
+        cur = await conn2.execute("select pg_backend_pid()")
+        (pid2,) = await cur.fetchone()
+
+    await p.close()
+    assert pid1 != pid2
+
+
+async def test_intrans_rollback(dsn, caplog):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    conn = await p.getconn()
+    pid = conn.pgconn.backend_pid
+    await conn.execute("create table test_intrans_rollback ()")
+    assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+    await p.putconn(conn)
+
+    async with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid == pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+        cur = await conn.execute(
+            "select 1 from pg_class where relname = 'test_intrans_rollback'"
+        )
+        assert not await cur.fetchone()
+
+    await p.close()
+    recs = [
+        r
+        for r in caplog.records
+        if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+    ]
+    assert len(recs) == 1
+    assert "INTRANS" in recs[0].message
+
+
+async def test_inerror_rollback(dsn, caplog):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    conn = await p.getconn()
+    pid = conn.pgconn.backend_pid
+    with pytest.raises(psycopg3.ProgrammingError):
+        await conn.execute("wat")
+    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+    await p.putconn(conn)
+
+    async with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid == pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+    recs = [
+        r
+        for r in caplog.records
+        if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+    ]
+    assert len(recs) == 1
+    assert "INERROR" in recs[0].message
+
+    await p.close()
+
+
+async def test_active_close(dsn, caplog):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    conn = await p.getconn()
+    pid = conn.pgconn.backend_pid
+    cur = conn.cursor()
+    async with cur.copy(
+        "copy (select * from generate_series(1, 10)) to stdout"
+    ):
+        pass
+    assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+    await p.putconn(conn)
+
+    async with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid != pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+    await p.close()
+    recs = [
+        r
+        for r in caplog.records
+        if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+    ]
+    assert len(recs) == 2
+    assert "ACTIVE" in recs[0].message
+    assert "BAD" in recs[1].message
+
+
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    conn = await p.getconn()
+
+    # Make the rollback fail
+    orig_rollback = conn.rollback
+
+    async def bad_rollback():
+        conn.pgconn.finish()
+        await orig_rollback()
+
+    monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+    pid = conn.pgconn.backend_pid
+    with pytest.raises(psycopg3.ProgrammingError):
+        await conn.execute("wat")
+    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+    await p.putconn(conn)
+
+    async with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid != pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+    await p.close()
+
+    recs = [
+        r
+        for r in caplog.records
+        if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+    ]
+    assert len(recs) == 3
+    assert "INERROR" in recs[0].message
+    assert "OperationalError" in recs[1].message
+    assert "BAD" in recs[2].message
+
+
+async def test_close_no_threads(dsn):
+    p = pool.AsyncConnectionPool(dsn)
+    assert p._sched_runner.is_alive()
+    for t in p._workers:
+        assert t.is_alive()
+
+    await p.close()
+    assert not p._sched_runner.is_alive()
+    for t in p._workers:
+        assert not t.is_alive()
+
+
+async def test_putconn_no_pool(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    conn = psycopg3.connect(dsn)
+    with pytest.raises(ValueError):
+        await p.putconn(conn)
+    await p.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+    p1 = pool.AsyncConnectionPool(dsn, minconn=1)
+    p2 = pool.AsyncConnectionPool(dsn, minconn=1)
+    conn = await p1.getconn()
+    with pytest.raises(ValueError):
+        await p2.putconn(conn)
+    await p1.close()
+    await p2.close()
+
+
+async def test_del_no_warning(dsn, recwarn):
+    p = pool.AsyncConnectionPool(dsn, minconn=2)
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+
+    await p.wait_ready()
+    ref = weakref.ref(p)
+    del p
+    assert not ref()
+    assert not recwarn
+
+
+@pytest.mark.slow
+async def test_del_stop_threads(dsn):
+    p = pool.AsyncConnectionPool(dsn)
+    ts = [p._sched_runner] + p._workers
+    del p
+    await asyncio.sleep(0.2)
+    for t in ts:
+        assert not t.is_alive()
+
+
+async def test_closed_getconn(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    assert not p.closed
+    async with p.connection():
+        pass
+
+    await p.close()
+    assert p.closed
+
+    with pytest.raises(pool.PoolClosed):
+        async with p.connection():
+            pass
+
+
+async def test_closed_putconn(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+
+    async with p.connection() as conn:
+        pass
+    assert not conn.closed
+
+    async with p.connection() as conn:
+        await p.close()
+    assert conn.closed
+
+
+@pytest.mark.slow
+async def test_closed_queue(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=1)
+    success = []
+
+    async def w1():
+        async with p.connection() as conn:
+            res = await conn.execute("select 1 from pg_sleep(0.2)")
+            assert await res.fetchone() == (1,)
+        success.append("w1")
+
+    async def w2():
+        with pytest.raises(pool.PoolClosed):
+            async with p.connection():
+                pass
+        success.append("w2")
+
+    t1 = create_task(w1())
+    await asyncio.sleep(0.1)
+    t2 = create_task(w2())
+    await p.close()
+    await asyncio.gather(t1, t2)
+    assert len(success) == 2
+
+
+@pytest.mark.slow
+async def test_grow(dsn, monkeypatch):
+    p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, num_workers=3)
+    await p.wait_ready(5.0)
+    delay_connection(monkeypatch, 0.1)
+    ts = []
+    results = []
+
+    async def worker(n):
+        t0 = time()
+        async with p.connection() as conn:
+            await conn.execute("select 1 from pg_sleep(0.2)")
+        t1 = time()
+        results.append((n, t1 - t0))
+
+    ts = [create_task(worker(i)) for i in range(6)]
+    await asyncio.gather(*ts)
+    await p.close()
+
+    want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
+    times = [item[1] for item in results]
+    for got, want in zip(times, want_times):
+        assert got == pytest.approx(want, 0.2), times
+
+
+@pytest.mark.slow
+async def test_shrink(dsn, monkeypatch):
+
+    from psycopg3.pool.tasks import ShrinkPool
+
+    orig_run = ShrinkPool._run_async
+    results = []
+
+    async def run_async_hacked(self, pool):
+        n0 = pool._nconns
+        await orig_run(self, pool)
+        n1 = pool._nconns
+        results.append((n0, n1))
+
+    monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked)
+
+    p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2)
+    await p.wait_ready(5.0)
+    assert p.max_idle == 0.2
+
+    async def worker(n):
+        async with p.connection() as conn:
+            await conn.execute("select pg_sleep(0.1)")
+
+    ts = [create_task(worker(i)) for i in range(4)]
+    await asyncio.gather(*ts)
+
+    await asyncio.sleep(1)
+    await p.close()
+    assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
+
+
+@pytest.mark.slow
+async def test_reconnect(proxy, caplog, monkeypatch):
+    assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+    assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+    monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+    monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
+
+    proxy.start()
+    p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1)
+    await p.wait_ready(2.0)
+    proxy.stop()
+
+    with pytest.raises(psycopg3.OperationalError):
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+
+    await asyncio.sleep(1.0)
+    proxy.start()
+    await p.wait_ready()
+
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+
+    await p.close()
+
+    recs = [
+        r
+        for r in caplog.records
+        if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+    ]
+    assert "BAD" in recs[0].message
+    times = [rec.created for rec in recs]
+    assert times[1] - times[0] < 0.05
+    deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+    assert len(deltas) == 3
+    want = 0.1
+    for delta in deltas:
+        assert delta == pytest.approx(want, 0.05), deltas
+        want *= 2
+
+
+@pytest.mark.slow
+async def test_reconnect_failure(proxy):
+    proxy.start()
+
+    t1 = None
+
+    def failed(pool):
+        assert pool.name == "this-one"
+        nonlocal t1
+        t1 = time()
+
+    p = pool.AsyncConnectionPool(
+        proxy.client_dsn,
+        name="this-one",
+        minconn=1,
+        reconnect_timeout=1.0,
+        reconnect_failed=failed,
+    )
+    await p.wait_ready(2.0)
+    proxy.stop()
+
+    with pytest.raises(psycopg3.OperationalError):
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+
+    t0 = time()
+    await asyncio.sleep(1.5)
+    assert t1
+    assert t1 - t0 == pytest.approx(1.0, 0.1)
+    assert p._nconns == 0
+
+    proxy.start()
+    t0 = time()
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+    t1 = time()
+    assert t1 - t0 < 0.2
+    await p.close()
+
+
+@pytest.mark.slow
+async def test_uniform_use(dsn):
+    p = pool.AsyncConnectionPool(dsn, minconn=4)
+    counts = Counter()
+    for i in range(8):
+        async with p.connection() as conn:
+            await asyncio.sleep(0.1)
+            counts[id(conn)] += 1
+
+    await p.close()
+    assert len(counts) == 4
+    assert set(counts.values()) == set([2])
+
+
+def delay_connection(monkeypatch, sec):
+    """
+    Monkeypatch AsyncConnection.connect so connecting takes at least *sec* seconds.
+    """
+    connect_orig = psycopg3.AsyncConnection.connect
+
+    async def connect_delay(*args, **kwargs):
+        t0 = time()
+        rv = await connect_orig(*args, **kwargs)
+        t1 = time()
+        await asyncio.sleep(sec - (t1 - t0))
+        return rv
+
+    monkeypatch.setattr(psycopg3.AsyncConnection, "connect", connect_delay)