]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add NullPool and AsyncNullPool
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 3 Jan 2022 19:20:29 +0000 (20:20 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 9 Jan 2022 17:47:22 +0000 (18:47 +0100)
Close #148

psycopg_pool/psycopg_pool/__init__.py
psycopg_pool/psycopg_pool/_compat.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py [new file with mode: 0644]
psycopg_pool/psycopg_pool/null_pool_async.py [new file with mode: 0644]
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_null_pool.py [new file with mode: 0644]
tests/pool/test_null_pool_async.py [new file with mode: 0644]
tests/pool/test_pool.py

index 49b035b31aa026798c4978ae2638cb9145e844cd..e4d975feddf29fa5483941b5820fcd55e1fa65c3 100644 (file)
@@ -6,12 +6,16 @@ psycopg connection pool package
 
 from .pool import ConnectionPool
 from .pool_async import AsyncConnectionPool
+from .null_pool import NullConnectionPool
+from .null_pool_async import AsyncNullConnectionPool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from .version import __version__ as __version__  # noqa: F401
 
 __all__ = [
     "AsyncConnectionPool",
+    "AsyncNullConnectionPool",
     "ConnectionPool",
+    "NullConnectionPool",
     "PoolClosed",
     "PoolTimeout",
     "TooManyRequests",
index f666e677cd8656d590f315672a83afacdf2dc363..c1b14f2fe414c94ccb3d9f1170a42c642d35da92 100644 (file)
@@ -6,7 +6,9 @@ compatibility functions for different Python versions
 
 import sys
 import asyncio
-from typing import Any, Awaitable, Generator, Optional, Union, TypeVar
+from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar
+
+import psycopg.errors as e
 
 T = TypeVar("T")
 FutureT = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
@@ -35,3 +37,14 @@ __all__ = [
     "Task",
     "create_task",
 ]
+
+# Workaround for psycopg < 3.0.8.
+# Timeout on NullPool connection mignt not work correctly.
+try:
+    ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout
+except AttributeError:
+
+    class DummyConnectionTimeout(e.OperationalError):
+        pass
+
+    ConnectionTimeout = DummyConnectionTimeout
index 7c9d96223d7a5de5fe915c23d585c88b651e6019..1e3187b5510d8d5f6c5be7c5a791e89ecb6f3cef 100644 (file)
@@ -121,11 +121,11 @@ class BasePool(Generic[ConnectionType]):
     def _check_size(
         self, min_size: int, max_size: Optional[int]
     ) -> Tuple[int, int]:
-        if min_size < 0:
-            raise ValueError("min_size cannot be negative")
-
         if max_size is None:
             max_size = min_size
+
+        if min_size < 0:
+            raise ValueError("min_size cannot be negative")
         if max_size < min_size:
             raise ValueError("max_size must be greater or equal than min_size")
         if min_size == max_size == 0:
diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py
new file mode 100644 (file)
index 0000000..58823cb
--- /dev/null
@@ -0,0 +1,198 @@
+"""
+Psycopg null connection pools
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import logging
+import threading
+from time import monotonic
+from typing import Any, Optional, Tuple
+
+from psycopg import Connection
+from psycopg.pq import TransactionStatus
+
+from .pool import ConnectionPool, WaitingClient
+from .pool import AddConnection, ReturnConnection
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class _BaseNullConnectionPool:
+    def __init__(
+        self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any
+    ):
+        super().__init__(  # type: ignore[call-arg]
+            conninfo, *args, min_size=min_size, **kwargs
+        )
+
+    def _check_size(
+        self, min_size: int, max_size: Optional[int]
+    ) -> Tuple[int, int]:
+        if max_size is None:
+            max_size = min_size
+
+        if min_size != 0:
+            raise ValueError("null pools must have min_size = 0")
+        if max_size < min_size:
+            raise ValueError("max_size must be greater or equal than min_size")
+
+        return min_size, max_size
+
+    def _start_initial_tasks(self) -> None:
+        # Null pools don't have background tasks to fill connections
+        # or to grow/shrink.
+        return
+
+
+class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+    def wait(self, timeout: float = 30.0) -> None:
+        """
+        Create a connection for test.
+
+        Calling this function will verify that the connectivity with the
+        database works as expected. However the connection will not be stored
+        in the pool.
+
+        Raise `PoolTimeout` if not ready within *timeout* sec.
+        """
+        self._check_open_getconn()
+
+        with self._lock:
+            assert not self._pool_full_event
+            self._pool_full_event = threading.Event()
+
+        logger.info("waiting for pool %r initialization", self.name)
+        self.run_task(AddConnection(self))
+        if not self._pool_full_event.wait(timeout):
+            self.close()  # stop all the threads
+            raise PoolTimeout(
+                f"pool initialization incomplete after {timeout} sec"
+            )
+
+        with self._lock:
+            assert self._pool_full_event
+            self._pool_full_event = None
+
+        logger.info("pool %r is ready to use", self.name)
+
+    def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
+        logger.info("connection requested from %r", self.name)
+        self._stats[self._REQUESTS_NUM] += 1
+
+        # Critical section: decide here if there's a connection ready
+        # or if the client needs to wait.
+        with self._lock:
+            self._check_open_getconn()
+
+            pos: Optional[WaitingClient] = None
+            if self.max_size == 0 or self._nconns < self.max_size:
+                # Create a new connection for the client
+                try:
+                    conn = self._connect(timeout=timeout)
+                except ConnectionTimeout as ex:
+                    raise PoolTimeout(str(ex)) from None
+                self._nconns += 1
+            else:
+                if self.max_waiting and len(self._waiting) >= self.max_waiting:
+                    self._stats[self._REQUESTS_ERRORS] += 1
+                    raise TooManyRequests(
+                        f"the pool {self.name!r} has aleady"
+                        f" {len(self._waiting)} requests waiting"
+                    )
+
+                # No connection available: put the client in the waiting queue
+                t0 = monotonic()
+                pos = WaitingClient()
+                self._waiting.append(pos)
+                self._stats[self._REQUESTS_QUEUED] += 1
+
+        # If we are in the waiting queue, wait to be assigned a connection
+        # (outside the critical section, so only the waiting client is locked)
+        if pos:
+            if timeout is None:
+                timeout = self.timeout
+            try:
+                conn = pos.wait(timeout=timeout)
+            except Exception:
+                self._stats[self._REQUESTS_ERRORS] += 1
+                raise
+            finally:
+                t1 = monotonic()
+                self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+        # Tell the connection it belongs to a pool to avoid closing on __exit__
+        conn._pool = self
+        logger.info("connection given by %r", self.name)
+        return conn
+
+    def putconn(self, conn: Connection[Any]) -> None:
+        # Quick check to discard the wrong connection
+        self._check_pool_putconn(conn)
+
+        logger.info("returning connection to %r", self.name)
+
+        # Close the connection if no client is waiting for it, or if the pool
+        # is closed. For extra refcare remove the pool reference from it.
+        # Maintain the stats.
+        with self._lock:
+            if self._closed or not self._waiting:
+                conn._pool = None
+                if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+                    self._stats[self._RETURNS_BAD] += 1
+                conn.close()
+                self._nconns -= 1
+                return
+
+        # Use a worker to perform eventual maintenance work in a separate thread
+        if self._reset:
+            self.run_task(ReturnConnection(self, conn))
+        else:
+            self._return_connection(conn)
+
+    def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+        min_size, max_size = self._check_size(min_size, max_size)
+
+        logger.info(
+            "resizing %r to min_size=%s max_size=%s",
+            self.name,
+            min_size,
+            max_size,
+        )
+        with self._lock:
+            self._min_size = min_size
+            self._max_size = max_size
+
+    def check(self) -> None:
+        """No-op, as the pool doesn't have connections in its state."""
+        pass
+
+    def _add_to_pool(self, conn: Connection[Any]) -> None:
+        # Remove the pool reference from the connection before returning it
+        # to the state, to avoid to create a reference loop.
+        # Also disable the warning for open connection in conn.__del__
+        conn._pool = None
+
+        # Critical section: if there is a client waiting give it the connection
+        # otherwise put it back into the pool.
+        with self._lock:
+            while self._waiting:
+                # If there is a client waiting (which is still waiting and
+                # hasn't timed out), give it the connection and notify it.
+                pos = self._waiting.popleft()
+                if pos.set(conn):
+                    break
+            else:
+                # No client waiting for a connection: close the connection
+                conn.close()
+
+                # If we have been asked to wait for pool init, notify the
+                # waiter if the pool is ready.
+                if self._pool_full_event:
+                    self._pool_full_event.set()
+                else:
+                    # The connection created by wait shoudn't decrease the
+                    # count of the number of connection used.
+                    self._nconns -= 1
diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py
new file mode 100644 (file)
index 0000000..6901223
--- /dev/null
@@ -0,0 +1,168 @@
+"""
+psycopg asynchronous null connection pool
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import asyncio
+import logging
+from time import monotonic
+from typing import Any, Optional
+
+from psycopg.pq import TransactionStatus
+from psycopg.connection_async import AsyncConnection
+
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+from .null_pool import _BaseNullConnectionPool
+from .pool_async import AsyncConnectionPool, AsyncClient
+from .pool_async import AddConnection, ReturnConnection
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+    async def wait(self, timeout: float = 30.0) -> None:
+        self._check_open_getconn()
+
+        async with self._lock:
+            assert not self._pool_full_event
+            self._pool_full_event = asyncio.Event()
+
+        logger.info("waiting for pool %r initialization", self.name)
+        self.run_task(AddConnection(self))
+        try:
+            await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+        except asyncio.TimeoutError:
+            await self.close()  # stop all the tasks
+            raise PoolTimeout(
+                f"pool initialization incomplete after {timeout} sec"
+            ) from None
+
+        async with self._lock:
+            assert self._pool_full_event
+            self._pool_full_event = None
+
+        logger.info("pool %r is ready to use", self.name)
+
+    async def getconn(
+        self, timeout: Optional[float] = None
+    ) -> AsyncConnection[Any]:
+        logger.info("connection requested from %r", self.name)
+        self._stats[self._REQUESTS_NUM] += 1
+
+        # Critical section: decide here if there's a connection ready
+        # or if the client needs to wait.
+        async with self._lock:
+            self._check_open_getconn()
+
+            pos: Optional[AsyncClient] = None
+            if self.max_size == 0 or self._nconns < self.max_size:
+                # Create a new connection for the client
+                try:
+                    conn = await self._connect(timeout=timeout)
+                except ConnectionTimeout as ex:
+                    raise PoolTimeout(str(ex)) from None
+                self._nconns += 1
+            else:
+                if self.max_waiting and len(self._waiting) >= self.max_waiting:
+                    self._stats[self._REQUESTS_ERRORS] += 1
+                    raise TooManyRequests(
+                        f"the pool {self.name!r} has aleady"
+                        f" {len(self._waiting)} requests waiting"
+                    )
+
+                # No connection available: put the client in the waiting queue
+                t0 = monotonic()
+                pos = AsyncClient()
+                self._waiting.append(pos)
+                self._stats[self._REQUESTS_QUEUED] += 1
+
+        # If we are in the waiting queue, wait to be assigned a connection
+        # (outside the critical section, so only the waiting client is locked)
+        if pos:
+            if timeout is None:
+                timeout = self.timeout
+            try:
+                conn = await pos.wait(timeout=timeout)
+            except Exception:
+                self._stats[self._REQUESTS_ERRORS] += 1
+                raise
+            finally:
+                t1 = monotonic()
+                self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+        # Tell the connection it belongs to a pool to avoid closing on __exit__
+        conn._pool = self
+        logger.info("connection given by %r", self.name)
+        return conn
+
+    async def putconn(self, conn: AsyncConnection[Any]) -> None:
+        # Quick check to discard the wrong connection
+        self._check_pool_putconn(conn)
+
+        logger.info("returning connection to %r", self.name)
+
+        # Close the connection if no client is waiting for it, or if the pool
+        # is closed. For extra refcare remove the pool reference from it.
+        # Maintain the stats.
+        async with self._lock:
+            if self._closed or not self._waiting:
+                conn._pool = None
+                if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+                    self._stats[self._RETURNS_BAD] += 1
+                await conn.close()
+                self._nconns -= 1
+                return
+
+        # Use a worker to perform eventual maintenance work in a separate task
+        if self._reset:
+            self.run_task(ReturnConnection(self, conn))
+        else:
+            await self._return_connection(conn)
+
+    async def resize(
+        self, min_size: int, max_size: Optional[int] = None
+    ) -> None:
+        min_size, max_size = self._check_size(min_size, max_size)
+
+        logger.info(
+            "resizing %r to min_size=%s max_size=%s",
+            self.name,
+            min_size,
+            max_size,
+        )
+        async with self._lock:
+            self._min_size = min_size
+            self._max_size = max_size
+
+    async def check(self) -> None:
+        pass
+
+    async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+        # Remove the pool reference from the connection before returning it
+        # to the state, to avoid to create a reference loop.
+        # Also disable the warning for open connection in conn.__del__
+        conn._pool = None
+
+        # Critical section: if there is a client waiting give it the connection
+        # otherwise put it back into the pool.
+        async with self._lock:
+            while self._waiting:
+                # If there is a client waiting (which is still waiting and
+                # hasn't timed out), give it the connection and notify it.
+                pos = self._waiting.popleft()
+                if await pos.set(conn):
+                    break
+            else:
+                # No client waiting for a connection: close the connection
+                await conn.close()
+
+                # If we have been asked to wait for pool init, notify the
+                # waiter if the pool is ready.
+                if self._pool_full_event:
+                    self._pool_full_event.set()
+                else:
+                    # The connection created by wait shoudn't decrease the
+                    # count of the number of connection used.
+                    self._nconns -= 1
index aa1dd20b3def97e04b0eb51968101438421a2743..06683ba90e9ae146215387f805fc310d0940d52b 100644 (file)
@@ -454,13 +454,17 @@ class ConnectionPool(BasePool[Connection[Any]]):
                     ex,
                 )
 
-    def _connect(self) -> Connection[Any]:
+    def _connect(self, timeout: Optional[float] = None) -> Connection[Any]:
         """Return a new connection configured for the pool."""
         self._stats[self._CONNECTIONS_NUM] += 1
+        kwargs = self.kwargs
+        if timeout:
+            kwargs = kwargs.copy()
+            kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
             conn: Connection[Any]
-            conn = self.connection_class.connect(self.conninfo, **self.kwargs)
+            conn = self.connection_class.connect(self.conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
index 8a4391ceb5403827579de16e2e10e945d98daee0..7fe772ebc531bdd02789dc672fa3a4d4b6c2bc06 100644 (file)
@@ -368,15 +368,18 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                     ex,
                 )
 
-    async def _connect(self) -> AsyncConnection[Any]:
-        """Return a new connection configured for the pool."""
+    async def _connect(
+        self, timeout: Optional[float] = None
+    ) -> AsyncConnection[Any]:
         self._stats[self._CONNECTIONS_NUM] += 1
+        kwargs = self.kwargs
+        if timeout:
+            kwargs = kwargs.copy()
+            kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
             conn: AsyncConnection[Any]
-            conn = await self.connection_class.connect(
-                self.conninfo, **self.kwargs
-            )
+            conn = await self.connection_class.connect(self.conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py
new file mode 100644 (file)
index 0000000..9747c5e
--- /dev/null
@@ -0,0 +1,898 @@
+import logging
+from time import sleep, time
+from threading import Thread, Event
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver  # noqa: F401  # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+from .test_pool import delay_connection
+
+try:
+    from psycopg_pool import NullConnectionPool
+    from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+    pass
+
+
+def test_defaults(dsn):
+    with NullConnectionPool(dsn) as p:
+        assert p.min_size == p.max_size == 0
+        assert p.timeout == 30
+        assert p.max_idle == 10 * 60
+        assert p.max_lifetime == 60 * 60
+        assert p.num_workers == 3
+
+
+def test_min_size_max_size(dsn):
+    with NullConnectionPool(dsn, min_size=0, max_size=2) as p:
+        assert p.min_size == 0
+        assert p.max_size == 2
+
+
+@pytest.mark.parametrize(
+    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+def test_bad_size(dsn, min_size, max_size):
+    with pytest.raises(ValueError):
+        NullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+def test_connection_class(dsn):
+    class MyConn(psycopg.Connection[Any]):
+        pass
+
+    with NullConnectionPool(dsn, connection_class=MyConn) as p:
+        with p.connection() as conn:
+            assert isinstance(conn, MyConn)
+
+
+def test_kwargs(dsn):
+    with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+        with p.connection() as conn:
+            assert conn.autocommit
+
+
+def test_its_no_pool_at_all(dsn):
+    with NullConnectionPool(dsn, max_size=2) as p:
+        with p.connection() as conn:
+            with conn.execute("select pg_backend_pid()") as cur:
+                (pid1,) = cur.fetchone()  # type: ignore[misc]
+
+            with p.connection() as conn2:
+                with conn2.execute("select pg_backend_pid()") as cur:
+                    (pid2,) = cur.fetchone()  # type: ignore[misc]
+
+        with p.connection() as conn:
+            assert conn.pgconn.backend_pid not in (pid1, pid2)
+
+
+def test_context(dsn):
+    with NullConnectionPool(dsn) as p:
+        assert not p.closed
+    assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+    delay_connection(monkeypatch, 0.2)
+    with pytest.raises(PoolTimeout):
+        with NullConnectionPool(dsn, num_workers=1) as p:
+            p.wait(0.1)
+
+    with NullConnectionPool(dsn, num_workers=1) as p:
+        p.wait(0.4)
+
+
+def test_wait_closed(dsn):
+    with NullConnectionPool(dsn) as p:
+        pass
+
+    with pytest.raises(PoolClosed):
+        p.wait()
+
+
+@pytest.mark.slow
+def test_setup_no_timeout(dsn, proxy):
+    with pytest.raises(PoolTimeout):
+        with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+            p.wait(0.2)
+
+    with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+        sleep(0.5)
+        assert not p._pool
+        proxy.start()
+
+        with p.connection() as conn:
+            conn.execute("select 1")
+
+
+def test_configure(dsn):
+    inits = 0
+
+    def configure(conn):
+        nonlocal inits
+        inits += 1
+        with conn.transaction():
+            conn.execute("set default_transaction_read_only to on")
+
+    with NullConnectionPool(dsn, configure=configure) as p:
+        with p.connection() as conn:
+            assert inits == 1
+            res = conn.execute("show default_transaction_read_only")
+            assert res.fetchone()[0] == "on"  # type: ignore[index]
+
+        with p.connection() as conn:
+            assert inits == 2
+            res = conn.execute("show default_transaction_read_only")
+            assert res.fetchone()[0] == "on"  # type: ignore[index]
+            conn.close()
+
+        with p.connection() as conn:
+            assert inits == 3
+            res = conn.execute("show default_transaction_read_only")
+            assert res.fetchone()[0] == "on"  # type: ignore[index]
+
+
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    def configure(conn):
+        conn.execute("select 1")
+
+    with NullConnectionPool(dsn, configure=configure) as p:
+        with pytest.raises(PoolTimeout):
+            p.wait(timeout=0.5)
+
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_configure_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    def configure(conn):
+        with conn.transaction():
+            conn.execute("WAT")
+
+    with NullConnectionPool(dsn, configure=configure) as p:
+        with pytest.raises(PoolTimeout):
+            p.wait(timeout=0.5)
+
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
+def test_reset(dsn):
+    resets = 0
+
+    def setup(conn):
+        with conn.transaction():
+            conn.execute("set timezone to '+1:00'")
+
+    def reset(conn):
+        nonlocal resets
+        resets += 1
+        with conn.transaction():
+            conn.execute("set timezone to utc")
+
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            assert resets == 1
+            with conn.execute("show timezone") as cur:
+                assert cur.fetchone() == ("UTC",)
+            pids.append(conn.pgconn.backend_pid)
+
+    with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        with p.connection() as conn:
+
+            # Queue the worker so it will take the same connection a second time
+            # instead of making a new one.
+            t = Thread(target=worker)
+            t.start()
+
+            assert resets == 0
+            conn.execute("set timezone to '+2:00'")
+            pids.append(conn.pgconn.backend_pid)
+
+        t.join()
+        p.wait()
+
+    assert resets == 1
+    assert pids[0] == pids[1]
+
+
+def test_reset_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    def reset(conn):
+        conn.execute("reset all")
+
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+    with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        with p.connection() as conn:
+
+            t = Thread(target=worker)
+            t.start()
+
+            conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+        t.join()
+
+    assert pids[0] != pids[1]
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+def test_reset_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    def reset(conn):
+        with conn.transaction():
+            conn.execute("WAT")
+
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+    with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        with p.connection() as conn:
+
+            t = Thread(target=worker)
+            t.start()
+
+            conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+        t.join()
+
+    assert pids[0] != pids[1]
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+def test_no_queue_timeout(deaf_port):
+    with NullConnectionPool(
+        kwargs={"host": "localhost", "port": deaf_port}
+    ) as p:
+        with pytest.raises(PoolTimeout):
+            with p.connection(timeout=1):
+                pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_queue(dsn, retries):
+    def worker(n):
+        t0 = time()
+        with p.connection() as conn:
+            (pid,) = conn.execute(
+                "select pg_backend_pid() from pg_sleep(0.2)"
+            ).fetchone()  # type: ignore[misc]
+        t1 = time()
+        results.append((n, t1 - t0, pid))
+
+    for retry in retries:
+        with retry:
+            results: List[Tuple[int, float, int]] = []
+            with NullConnectionPool(dsn, max_size=2) as p:
+                p.wait()
+                ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+                for t in ts:
+                    t.start()
+                for t in ts:
+                    t.join()
+
+            times = [item[1] for item in results]
+            want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+            for got, want in zip(times, want_times):
+                assert got == pytest.approx(want, 0.2), times
+
+            assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+def test_queue_size(dsn):
+    def worker(t, ev=None):
+        try:
+            with p.connection():
+                if ev:
+                    ev.set()
+                sleep(t)
+        except TooManyRequests as e:
+            errors.append(e)
+        else:
+            success.append(True)
+
+    errors: List[Exception] = []
+    success: List[bool] = []
+
+    with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+        p.wait()
+        ev = Event()
+        t = Thread(target=worker, args=(0.3, ev))
+        t.start()
+        ev.wait()
+
+        ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+        for t in ts:
+            t.start()
+        for t in ts:
+            t.join()
+
+    assert len(success) == 4
+    assert len(errors) == 1
+    assert isinstance(errors[0], TooManyRequests)
+    assert p.name in str(errors[0])
+    assert str(p.max_waiting) in str(errors[0])
+    assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_queue_timeout(dsn, retries):
+    def worker(n):
+        t0 = time()
+        try:
+            with p.connection() as conn:
+                (pid,) = conn.execute(  # type: ignore[misc]
+                    "select pg_backend_pid() from pg_sleep(0.2)"
+                ).fetchone()
+        except PoolTimeout as e:
+            t1 = time()
+            errors.append((n, t1 - t0, e))
+        else:
+            t1 = time()
+            results.append((n, t1 - t0, pid))
+
+    for retry in retries:
+        with retry:
+            results: List[Tuple[int, float, int]] = []
+            errors: List[Tuple[int, float, Exception]] = []
+
+            with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+                ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+                for t in ts:
+                    t.start()
+                for t in ts:
+                    t.join()
+
+            assert len(results) == 2
+            assert len(errors) == 2
+            for e in errors:
+                assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_dead_client(dsn):
+    def worker(i, timeout):
+        try:
+            with p.connection(timeout=timeout) as conn:
+                conn.execute("select pg_sleep(0.3)")
+                results.append(i)
+        except PoolTimeout:
+            if timeout > 0.2:
+                raise
+
+    results: List[int] = []
+
+    with NullConnectionPool(dsn, max_size=2) as p:
+        ts = [
+            Thread(target=worker, args=(i, timeout))
+            for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+        ]
+        for t in ts:
+            t.start()
+        for t in ts:
+            t.join()
+        sleep(0.2)
+        assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_queue_timeout_override(dsn, retries):
+    def worker(n):
+        t0 = time()
+        timeout = 0.25 if n == 3 else None
+        try:
+            with p.connection(timeout=timeout) as conn:
+                (pid,) = conn.execute(  # type: ignore[misc]
+                    "select pg_backend_pid() from pg_sleep(0.2)"
+                ).fetchone()
+        except PoolTimeout as e:
+            t1 = time()
+            errors.append((n, t1 - t0, e))
+        else:
+            t1 = time()
+            results.append((n, t1 - t0, pid))
+
+    for retry in retries:
+        with retry:
+            results: List[Tuple[int, float, int]] = []
+            errors: List[Tuple[int, float, Exception]] = []
+
+            with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+                ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+                for t in ts:
+                    t.start()
+                for t in ts:
+                    t.join()
+
+            assert len(results) == 3
+            assert len(errors) == 1
+            for e in errors:
+                assert 0.1 < e[1] < 0.15
+
+
+def test_broken_reconnect(dsn):
+    with NullConnectionPool(dsn, max_size=1) as p:
+        with p.connection() as conn:
+            with conn.execute("select pg_backend_pid()") as cur:
+                (pid1,) = cur.fetchone()  # type: ignore[misc]
+            conn.close()
+
+        with p.connection() as conn2:
+            with conn2.execute("select pg_backend_pid()") as cur:
+                (pid2,) = cur.fetchone()  # type: ignore[misc]
+
+    assert pid1 != pid2
+
+
+def test_intrans_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+            assert not conn.execute(
+                "select 1 from pg_class where relname = 'test_intrans_rollback'"
+            ).fetchone()
+
+    with NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        # Queue the worker so it will take the connection a second time instead
+        # of making a new one.
+        t = Thread(target=worker)
+        t.start()
+
+        pids.append(conn.pgconn.backend_pid)
+        conn.execute("create table test_intrans_rollback ()")
+        assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+        p.putconn(conn)
+        t.join()
+
+    assert pids[0] == pids[1]
+    assert len(caplog.records) == 1
+    assert "INTRANS" in caplog.records[0].message
+
+
+def test_inerror_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+    with NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        # Queue the worker so it will take the connection a second time instead
+        # of making a new one.
+        t = Thread(target=worker)
+        t.start()
+
+        pids.append(conn.pgconn.backend_pid)
+        with pytest.raises(psycopg.ProgrammingError):
+            conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        p.putconn(conn)
+        t.join()
+
+    assert pids[0] == pids[1]
+    assert len(caplog.records) == 1
+    assert "INERROR" in caplog.records[0].message
+
+
+def test_active_close(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+    with NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        t = Thread(target=worker)
+        t.start()
+
+        pids.append(conn.pgconn.backend_pid)
+        cur = conn.cursor()
+        with cur.copy("copy (select * from generate_series(1, 10)) to stdout"):
+            pass
+        assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+        p.putconn(conn)
+        t.join()
+
+    assert pids[0] != pids[1]
+    assert len(caplog.records) == 2
+    assert "ACTIVE" in caplog.records[0].message
+    assert "BAD" in caplog.records[1].message
+
+
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker(p):
+        with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+    with NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        def bad_rollback():
+            conn.pgconn.finish()
+            orig_rollback()
+
+        # Make the rollback fail
+        orig_rollback = conn.rollback
+        monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+        t = Thread(target=worker, args=(p,))
+        t.start()
+
+        pids.append(conn.pgconn.backend_pid)
+        with pytest.raises(psycopg.ProgrammingError):
+            conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        p.putconn(conn)
+        t.join()
+
+    assert pids[0] != pids[1]
+    assert len(caplog.records) == 3
+    assert "INERROR" in caplog.records[0].message
+    assert "OperationalError" in caplog.records[1].message
+    assert "BAD" in caplog.records[2].message
+
+
+def test_close_no_threads(dsn):
+    p = NullConnectionPool(dsn)
+    assert p._sched_runner and p._sched_runner.is_alive()
+    workers = p._workers[:]
+    assert workers
+    for t in workers:
+        assert t.is_alive()
+
+    p.close()
+    assert p._sched_runner is None
+    assert not p._workers
+    for t in workers:
+        assert not t.is_alive()
+
+
+def test_putconn_no_pool(dsn):
+    with NullConnectionPool(dsn) as p:
+        conn = psycopg.connect(dsn)
+        with pytest.raises(ValueError):
+            p.putconn(conn)
+
+    conn.close()
+
+
+def test_putconn_wrong_pool(dsn):
+    with NullConnectionPool(dsn) as p1:
+        with NullConnectionPool(dsn) as p2:
+            conn = p1.getconn()
+            with pytest.raises(ValueError):
+                p2.putconn(conn)
+
+
+@pytest.mark.slow
+def test_del_stop_threads(dsn):
+    p = NullConnectionPool(dsn)
+    assert p._sched_runner is not None
+    ts = [p._sched_runner] + p._workers
+    del p
+    sleep(0.1)
+    for t in ts:
+        assert not t.is_alive()
+
+
+def test_closed_getconn(dsn):
+    p = NullConnectionPool(dsn)
+    assert not p.closed
+    with p.connection():
+        pass
+
+    p.close()
+    assert p.closed
+
+    with pytest.raises(PoolClosed):
+        with p.connection():
+            pass
+
+
+def test_closed_putconn(dsn):
+    p = NullConnectionPool(dsn)
+
+    with p.connection() as conn:
+        pass
+    assert conn.closed
+
+    with p.connection() as conn:
+        p.close()
+    assert conn.closed
+
+
+def test_closed_queue(dsn):
+    def w1():
+        with p.connection() as conn:
+            e1.set()  # Tell w0 that w1 got a connection
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+            e2.wait()  # Wait until w0 has tested w2
+        success.append("w1")
+
+    def w2():
+        try:
+            with p.connection():
+                pass  # unexpected
+        except PoolClosed:
+            success.append("w2")
+
+    e1 = Event()
+    e2 = Event()
+
+    p = NullConnectionPool(dsn, max_size=1)
+    p.wait()
+    success: List[str] = []
+
+    t1 = Thread(target=w1)
+    t1.start()
+    # Wait until w1 has received a connection
+    e1.wait()
+
+    t2 = Thread(target=w2)
+    t2.start()
+    # Wait until w2 is in the queue
+    while not p._waiting:
+        sleep(0)
+
+    p.close(0)
+
+    # Wait for the workers to finish
+    e2.set()
+    t1.join()
+    t2.join()
+    assert len(success) == 2
+
+
+def test_open_explicit(dsn):
+    p = NullConnectionPool(dsn, open=False)
+    assert p.closed
+    with pytest.raises(PoolClosed, match="is not open yet"):
+        p.getconn()
+
+    with pytest.raises(PoolClosed):
+        with p.connection():
+            pass
+
+    p.open()
+    try:
+        assert not p.closed
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+    finally:
+        p.close()
+
+    with pytest.raises(PoolClosed, match="is already closed"):
+        p.getconn()
+
+
+def test_open_context(dsn):
+    p = NullConnectionPool(dsn, open=False)
+    assert p.closed
+
+    with p:
+        assert not p.closed
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+    assert p.closed
+
+
+def test_open_no_op(dsn):
+    p = NullConnectionPool(dsn)
+    try:
+        assert not p.closed
+        p.open()
+        assert not p.closed
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+    finally:
+        p.close()
+
+
+def test_reopen(dsn):
+    p = NullConnectionPool(dsn)
+    with p.connection() as conn:
+        conn.execute("select 1")
+    p.close()
+    assert p._sched_runner is None
+    assert not p._workers
+
+    with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+        p.open()
+
+
+@pytest.mark.parametrize(
+    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+def test_bad_resize(dsn, min_size, max_size):
+    with NullConnectionPool() as p:
+        with pytest.raises(ValueError):
+            p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_max_lifetime(dsn):
+    pids = []
+
+    def worker(p):
+        with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            sleep(0.1)
+
+    ts = []
+    with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+        for i in range(5):
+            ts.append(Thread(target=worker, args=(p,)))
+            ts[-1].start()
+
+        for t in ts:
+            t.join()
+
+    assert pids[0] == pids[1] != pids[4], pids
+
+
+def test_check(dsn):
+    with NullConnectionPool(dsn) as p:
+        # No-op
+        p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_measures(dsn):
+    def worker(n):
+        with p.connection() as conn:
+            conn.execute("select pg_sleep(0.2)")
+
+    with NullConnectionPool(dsn, max_size=4) as p:
+        p.wait(2.0)
+
+        stats = p.get_stats()
+        assert stats["pool_min"] == 0
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 0
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 0
+
+        ts = [Thread(target=worker, args=(i,)) for i in range(3)]
+        for t in ts:
+            t.start()
+        sleep(0.1)
+        stats = p.get_stats()
+        for t in ts:
+            t.join()
+        assert stats["pool_min"] == 0
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 3
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 0
+
+        p.wait(2.0)
+        ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+        for t in ts:
+            t.start()
+        sleep(0.1)
+        stats = p.get_stats()
+        for t in ts:
+            t.join()
+        assert stats["pool_min"] == 0
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 4
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_usage(dsn, retries):
+    def worker(n):
+        try:
+            with p.connection(timeout=0.3) as conn:
+                conn.execute("select pg_sleep(0.2)")
+        except PoolTimeout:
+            pass
+
+    for retry in retries:
+        with retry:
+            with NullConnectionPool(dsn, max_size=3) as p:
+                p.wait(2.0)
+
+                ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+                for t in ts:
+                    t.start()
+                for t in ts:
+                    t.join()
+                stats = p.get_stats()
+                assert stats["requests_num"] == 7
+                assert stats["requests_queued"] == 4
+                assert 850 <= stats["requests_wait_ms"] <= 950
+                assert stats["requests_errors"] == 1
+                assert 1150 <= stats["usage_ms"] <= 1250
+                assert stats.get("returns_bad", 0) == 0
+
+                with p.connection() as conn:
+                    conn.close()
+                p.wait()
+                stats = p.pop_stats()
+                assert stats["requests_num"] == 8
+                assert stats["returns_bad"] == 1
+                with p.connection():
+                    pass
+                assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+    proxy.start()
+    delay_connection(monkeypatch, 0.2)
+    with NullConnectionPool(proxy.client_dsn, max_size=3) as p:
+        p.wait()
+        stats = p.get_stats()
+        assert stats["connections_num"] == 1
+        assert stats.get("connections_errors", 0) == 0
+        assert stats.get("connections_lost", 0) == 0
+        assert 200 <= stats["connections_ms"] < 300
diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py
new file mode 100644 (file)
index 0000000..9175307
--- /dev/null
@@ -0,0 +1,864 @@
+import sys
+import asyncio
+import logging
+from time import time
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver  # noqa: F401  # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import create_task
+from .test_pool_async import delay_connection
+
+pytestmark = [
+    pytest.mark.asyncio,
+    pytest.mark.skipif(
+        sys.version_info < (3, 7),
+        reason="async pool not supported before Python 3.7",
+    ),
+]
+
+try:
+    from psycopg_pool import AsyncNullConnectionPool  # noqa: F401
+    from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+    pass
+
+
+async def test_defaults(dsn):
+    async with AsyncNullConnectionPool(dsn) as p:
+        assert p.min_size == p.max_size == 0
+        assert p.timeout == 30
+        assert p.max_idle == 10 * 60
+        assert p.max_lifetime == 60 * 60
+        assert p.num_workers == 3
+
+
+async def test_min_size_max_size(dsn):
+    async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p:
+        assert p.min_size == 0
+        assert p.max_size == 2
+
+
+@pytest.mark.parametrize(
+    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+async def test_bad_size(dsn, min_size, max_size):
+    with pytest.raises(ValueError):
+        AsyncNullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+async def test_connection_class(dsn):
+    class MyConn(psycopg.AsyncConnection[Any]):
+        pass
+
+    async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p:
+        async with p.connection() as conn:
+            assert isinstance(conn, MyConn)
+
+
+async def test_kwargs(dsn):
+    async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+        async with p.connection() as conn:
+            assert conn.autocommit
+
+
+async def test_its_no_pool_at_all(dsn):
+    async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+        async with p.connection() as conn:
+            cur = await conn.execute("select pg_backend_pid()")
+            (pid1,) = await cur.fetchone()  # type: ignore[misc]
+
+            async with p.connection() as conn2:
+                cur = await conn2.execute("select pg_backend_pid()")
+                (pid2,) = await cur.fetchone()  # type: ignore[misc]
+
+        async with p.connection() as conn:
+            assert conn.pgconn.backend_pid not in (pid1, pid2)
+
+
+async def test_context(dsn):
+    async with AsyncNullConnectionPool(dsn) as p:
+        assert not p.closed
+    assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_wait_ready(dsn, monkeypatch):
+    delay_connection(monkeypatch, 0.2)
+    with pytest.raises(PoolTimeout):
+        async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+            await p.wait(0.1)
+
+    async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+        await p.wait(0.4)
+
+
+async def test_wait_closed(dsn):
+    async with AsyncNullConnectionPool(dsn) as p:
+        pass
+
+    with pytest.raises(PoolClosed):
+        await p.wait()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+    with pytest.raises(PoolTimeout):
+        async with AsyncNullConnectionPool(
+            proxy.client_dsn, num_workers=1
+        ) as p:
+            await p.wait(0.2)
+
+    async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+        await asyncio.sleep(0.5)
+        assert not p._pool
+        proxy.start()
+
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+
+
+async def test_configure(dsn):
+    inits = 0
+
+    async def configure(conn):
+        nonlocal inits
+        inits += 1
+        async with conn.transaction():
+            await conn.execute("set default_transaction_read_only to on")
+
+    async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+        async with p.connection() as conn:
+            assert inits == 1
+            res = await conn.execute("show default_transaction_read_only")
+            assert (await res.fetchone())[0] == "on"  # type: ignore[index]
+
+        async with p.connection() as conn:
+            assert inits == 2
+            res = await conn.execute("show default_transaction_read_only")
+            assert (await res.fetchone())[0] == "on"  # type: ignore[index]
+            await conn.close()
+
+        async with p.connection() as conn:
+            assert inits == 3
+            res = await conn.execute("show default_transaction_read_only")
+            assert (await res.fetchone())[0] == "on"  # type: ignore[index]
+
+
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    async def configure(conn):
+        await conn.execute("select 1")
+
+    async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+        with pytest.raises(PoolTimeout):
+            await p.wait(timeout=0.5)
+
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_configure_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    async def configure(conn):
+        async with conn.transaction():
+            await conn.execute("WAT")
+
+    async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+        with pytest.raises(PoolTimeout):
+            await p.wait(timeout=0.5)
+
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
+async def test_reset(dsn):
+    resets = 0
+
+    async def setup(conn):
+        async with conn.transaction():
+            await conn.execute("set timezone to '+1:00'")
+
+    async def reset(conn):
+        nonlocal resets
+        resets += 1
+        async with conn.transaction():
+            await conn.execute("set timezone to utc")
+
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            assert resets == 1
+            cur = await conn.execute("show timezone")
+            assert (await cur.fetchone()) == ("UTC",)
+            pids.append(conn.pgconn.backend_pid)
+
+    async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        async with p.connection() as conn:
+
+            # Queue the worker so it will take the same connection a second time
+            # instead of making a new one.
+            t = create_task(worker())
+
+            assert resets == 0
+            await conn.execute("set timezone to '+2:00'")
+            pids.append(conn.pgconn.backend_pid)
+
+        await asyncio.gather(t)
+        await p.wait()
+
+    assert resets == 1
+    assert pids[0] == pids[1]
+
+
+async def test_reset_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    async def reset(conn):
+        await conn.execute("reset all")
+
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+    async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        async with p.connection() as conn:
+
+            t = create_task(worker())
+
+            await conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+        await asyncio.gather(t)
+
+    assert pids[0] != pids[1]
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+async def test_reset_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    async def reset(conn):
+        async with conn.transaction():
+            await conn.execute("WAT")
+
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+    async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        async with p.connection() as conn:
+
+            t = create_task(worker())
+
+            await conn.execute("select 1")
+            pids.append(conn.pgconn.backend_pid)
+
+        await asyncio.gather(t)
+
+    assert pids[0] != pids[1]
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+async def test_no_queue_timeout(deaf_port):
+    async with AsyncNullConnectionPool(
+        kwargs={"host": "localhost", "port": deaf_port}
+    ) as p:
+        with pytest.raises(PoolTimeout):
+            async with p.connection(timeout=1):
+                pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_queue(dsn, retries):
+    async def worker(n):
+        t0 = time()
+        async with p.connection() as conn:
+            cur = await conn.execute(
+                "select pg_backend_pid() from pg_sleep(0.2)"
+            )
+            (pid,) = await cur.fetchone()  # type: ignore[misc]
+        t1 = time()
+        results.append((n, t1 - t0, pid))
+
+    async for retry in retries:
+        with retry:
+            results: List[Tuple[int, float, int]] = []
+            async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+                await p.wait()
+                ts = [create_task(worker(i)) for i in range(6)]
+                await asyncio.gather(*ts)
+
+            times = [item[1] for item in results]
+            want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+            for got, want in zip(times, want_times):
+                assert got == pytest.approx(want, 0.2), times
+
+            assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_size(dsn):
+    async def worker(t, ev=None):
+        try:
+            async with p.connection():
+                if ev:
+                    ev.set()
+                await asyncio.sleep(t)
+        except TooManyRequests as e:
+            errors.append(e)
+        else:
+            success.append(True)
+
+    errors: List[Exception] = []
+    success: List[bool] = []
+
+    async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+        await p.wait()
+        ev = asyncio.Event()
+        create_task(worker(0.3, ev))
+        await ev.wait()
+
+        ts = [create_task(worker(0.1)) for i in range(4)]
+        await asyncio.gather(*ts)
+
+    assert len(success) == 4
+    assert len(errors) == 1
+    assert isinstance(errors[0], TooManyRequests)
+    assert p.name in str(errors[0])
+    assert str(p.max_waiting) in str(errors[0])
+    assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_queue_timeout(dsn, retries):
+    async def worker(n):
+        t0 = time()
+        try:
+            async with p.connection() as conn:
+                cur = await conn.execute(
+                    "select pg_backend_pid() from pg_sleep(0.2)"
+                )
+                (pid,) = await cur.fetchone()  # type: ignore[misc]
+        except PoolTimeout as e:
+            t1 = time()
+            errors.append((n, t1 - t0, e))
+        else:
+            t1 = time()
+            results.append((n, t1 - t0, pid))
+
+    async for retry in retries:
+        with retry:
+            results: List[Tuple[int, float, int]] = []
+            errors: List[Tuple[int, float, Exception]] = []
+
+            async with AsyncNullConnectionPool(
+                dsn, max_size=2, timeout=0.1
+            ) as p:
+                ts = [create_task(worker(i)) for i in range(4)]
+                await asyncio.gather(*ts)
+
+            assert len(results) == 2
+            assert len(errors) == 2
+            for e in errors:
+                assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_dead_client(dsn):
+    async def worker(i, timeout):
+        try:
+            async with p.connection(timeout=timeout) as conn:
+                await conn.execute("select pg_sleep(0.3)")
+                results.append(i)
+        except PoolTimeout:
+            if timeout > 0.2:
+                raise
+
+    async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+        results: List[int] = []
+        ts = [
+            create_task(worker(i, timeout))
+            for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+        ]
+        await asyncio.gather(*ts)
+
+        await asyncio.sleep(0.2)
+        assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_queue_timeout_override(dsn, retries):
+    async def worker(n):
+        t0 = time()
+        timeout = 0.25 if n == 3 else None
+        try:
+            async with p.connection(timeout=timeout) as conn:
+                cur = await conn.execute(
+                    "select pg_backend_pid() from pg_sleep(0.2)"
+                )
+                (pid,) = await cur.fetchone()  # type: ignore[misc]
+        except PoolTimeout as e:
+            t1 = time()
+            errors.append((n, t1 - t0, e))
+        else:
+            t1 = time()
+            results.append((n, t1 - t0, pid))
+
+    async for retry in retries:
+        with retry:
+            results: List[Tuple[int, float, int]] = []
+            errors: List[Tuple[int, float, Exception]] = []
+
+            async with AsyncNullConnectionPool(
+                dsn, max_size=2, timeout=0.1
+            ) as p:
+                ts = [create_task(worker(i)) for i in range(4)]
+                await asyncio.gather(*ts)
+
+            assert len(results) == 3
+            assert len(errors) == 1
+            for e in errors:
+                assert 0.1 < e[1] < 0.15
+
+
+async def test_broken_reconnect(dsn):
+    async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+        async with p.connection() as conn:
+            cur = await conn.execute("select pg_backend_pid()")
+            (pid1,) = await cur.fetchone()  # type: ignore[misc]
+            await conn.close()
+
+        async with p.connection() as conn2:
+            cur = await conn2.execute("select pg_backend_pid()")
+            (pid2,) = await cur.fetchone()  # type: ignore[misc]
+
+    assert pid1 != pid2
+
+
+async def test_intrans_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+            cur = await conn.execute(
+                "select 1 from pg_class where relname = 'test_intrans_rollback'"
+            )
+            assert not await cur.fetchone()
+
+    async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+        conn = await p.getconn()
+
+        # Queue the worker so it will take the connection a second time instead
+        # of making a new one.
+        t = create_task(worker())
+
+        pids.append(conn.pgconn.backend_pid)
+        await conn.execute("create table test_intrans_rollback ()")
+        assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+        await p.putconn(conn)
+        await asyncio.gather(t)
+
+    assert pids[0] == pids[1]
+    assert len(caplog.records) == 1
+    assert "INTRANS" in caplog.records[0].message
+
+
+async def test_inerror_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+    async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+        conn = await p.getconn()
+
+        t = create_task(worker())
+
+        pids.append(conn.pgconn.backend_pid)
+        with pytest.raises(psycopg.ProgrammingError):
+            await conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        await p.putconn(conn)
+        await asyncio.gather(t)
+
+    assert pids[0] == pids[1]
+    assert len(caplog.records) == 1
+    assert "INERROR" in caplog.records[0].message
+
+
+async def test_active_close(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+    async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+        conn = await p.getconn()
+
+        t = create_task(worker())
+
+        pids.append(conn.pgconn.backend_pid)
+        cur = conn.cursor()
+        async with cur.copy(
+            "copy (select * from generate_series(1, 10)) to stdout"
+        ):
+            pass
+        assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+        await p.putconn(conn)
+        await asyncio.gather(t)
+
+    assert pids[0] != pids[1]
+    assert len(caplog.records) == 2
+    assert "ACTIVE" in caplog.records[0].message
+    assert "BAD" in caplog.records[1].message
+
+
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    async def worker():
+        async with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+    async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+        conn = await p.getconn()
+        t = create_task(worker())
+
+        async def bad_rollback():
+            conn.pgconn.finish()
+            await orig_rollback()
+
+        # Make the rollback fail
+        orig_rollback = conn.rollback
+        monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+        pids.append(conn.pgconn.backend_pid)
+        with pytest.raises(psycopg.ProgrammingError):
+            await conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        await p.putconn(conn)
+        await asyncio.gather(t)
+
+    assert pids[0] != pids[1]
+    assert len(caplog.records) == 3
+    assert "INERROR" in caplog.records[0].message
+    assert "OperationalError" in caplog.records[1].message
+    assert "BAD" in caplog.records[2].message
+
+
+async def test_close_no_tasks(dsn):
+    p = AsyncNullConnectionPool(dsn)
+    assert p._sched_runner and not p._sched_runner.done()
+    assert p._workers
+    workers = p._workers[:]
+    for t in workers:
+        assert not t.done()
+
+    await p.close()
+    assert p._sched_runner is None
+    assert not p._workers
+    for t in workers:
+        assert t.done()
+
+
+async def test_putconn_no_pool(dsn):
+    async with AsyncNullConnectionPool(dsn) as p:
+        conn = await psycopg.AsyncConnection.connect(dsn)
+        with pytest.raises(ValueError):
+            await p.putconn(conn)
+
+    await conn.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+    async with AsyncNullConnectionPool(dsn) as p1:
+        async with AsyncNullConnectionPool(dsn) as p2:
+            conn = await p1.getconn()
+            with pytest.raises(ValueError):
+                await p2.putconn(conn)
+
+
+async def test_closed_getconn(dsn):
+    p = AsyncNullConnectionPool(dsn)
+    assert not p.closed
+    async with p.connection():
+        pass
+
+    await p.close()
+    assert p.closed
+
+    with pytest.raises(PoolClosed):
+        async with p.connection():
+            pass
+
+
+async def test_closed_putconn(dsn):
+    p = AsyncNullConnectionPool(dsn)
+
+    async with p.connection() as conn:
+        pass
+    assert conn.closed
+
+    async with p.connection() as conn:
+        await p.close()
+    assert conn.closed
+
+
+async def test_closed_queue(dsn):
+    async def w1():
+        async with p.connection() as conn:
+            e1.set()  # Tell w0 that w1 got a connection
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+            await e2.wait()  # Wait until w0 has tested w2
+        success.append("w1")
+
+    async def w2():
+        try:
+            async with p.connection():
+                pass  # unexpected
+        except PoolClosed:
+            success.append("w2")
+
+    e1 = asyncio.Event()
+    e2 = asyncio.Event()
+
+    p = AsyncNullConnectionPool(dsn, max_size=1)
+    await p.wait()
+    success: List[str] = []
+
+    t1 = create_task(w1())
+    # Wait until w1 has received a connection
+    await e1.wait()
+
+    t2 = create_task(w2())
+    # Wait until w2 is in the queue
+    while not p._waiting:
+        await asyncio.sleep(0)
+
+    await p.close()
+
+    # Wait for the workers to finish
+    e2.set()
+    await asyncio.gather(t1, t2)
+    assert len(success) == 2
+
+
+async def test_open_explicit(dsn):
+    p = AsyncNullConnectionPool(dsn, open=False)
+    assert p.closed
+    with pytest.raises(PoolClosed):
+        await p.getconn()
+
+    with pytest.raises(PoolClosed, match="is not open yet"):
+        async with p.connection():
+            pass
+
+    await p.open()
+    try:
+        assert not p.closed
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+    finally:
+        await p.close()
+
+    with pytest.raises(PoolClosed, match="is already closed"):
+        await p.getconn()
+
+
+async def test_open_context(dsn):
+    p = AsyncNullConnectionPool(dsn, open=False)
+    assert p.closed
+
+    async with p:
+        assert not p.closed
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+    assert p.closed
+
+
+async def test_open_no_op(dsn):
+    p = AsyncNullConnectionPool(dsn)
+    try:
+        assert not p.closed
+        await p.open()
+        assert not p.closed
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+    finally:
+        await p.close()
+
+
+async def test_reopen(dsn):
+    p = AsyncNullConnectionPool(dsn)
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+    await p.close()
+    assert p._sched_runner is None
+
+    with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+        await p.open()
+
+
+@pytest.mark.parametrize(
+    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+async def test_bad_resize(dsn, min_size, max_size):
+    async with AsyncNullConnectionPool() as p:
+        with pytest.raises(ValueError):
+            await p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_max_lifetime(dsn):
+    pids: List[int] = []
+
+    async def worker():
+        async with p.connection() as conn:
+            pids.append(conn.pgconn.backend_pid)
+            await asyncio.sleep(0.1)
+
+    async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+        ts = [create_task(worker()) for i in range(5)]
+        await asyncio.gather(*ts)
+
+    assert pids[0] == pids[1] != pids[4], pids
+
+
+async def test_check(dsn):
+    # no.op
+    async with AsyncNullConnectionPool(dsn) as p:
+        await p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+    async def worker(n):
+        async with p.connection() as conn:
+            await conn.execute("select pg_sleep(0.2)")
+
+    async with AsyncNullConnectionPool(dsn, max_size=4) as p:
+        await p.wait(2.0)
+
+        stats = p.get_stats()
+        assert stats["pool_min"] == 0
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 0
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 0
+
+        ts = [create_task(worker(i)) for i in range(3)]
+        await asyncio.sleep(0.1)
+        stats = p.get_stats()
+        await asyncio.gather(*ts)
+        assert stats["pool_min"] == 0
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 3
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 0
+
+        await p.wait(2.0)
+        ts = [create_task(worker(i)) for i in range(7)]
+        await asyncio.sleep(0.1)
+        stats = p.get_stats()
+        await asyncio.gather(*ts)
+        assert stats["pool_min"] == 0
+        assert stats["pool_max"] == 4
+        assert stats["pool_size"] == 4
+        assert stats["pool_available"] == 0
+        assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn, retries):
+    async def worker(n):
+        try:
+            async with p.connection(timeout=0.3) as conn:
+                await conn.execute("select pg_sleep(0.2)")
+        except PoolTimeout:
+            pass
+
+    async for retry in retries:
+        with retry:
+            async with AsyncNullConnectionPool(dsn, max_size=3) as p:
+                await p.wait(2.0)
+
+                ts = [create_task(worker(i)) for i in range(7)]
+                await asyncio.gather(*ts)
+                stats = p.get_stats()
+                assert stats["requests_num"] == 7
+                assert stats["requests_queued"] == 4
+                assert 850 <= stats["requests_wait_ms"] <= 950
+                assert stats["requests_errors"] == 1
+                assert 1150 <= stats["usage_ms"] <= 1250
+                assert stats.get("returns_bad", 0) == 0
+
+                async with p.connection() as conn:
+                    await conn.close()
+                await p.wait()
+                stats = p.pop_stats()
+                assert stats["requests_num"] == 8
+                assert stats["returns_bad"] == 1
+                async with p.connection():
+                    pass
+                assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+    proxy.start()
+    delay_connection(monkeypatch, 0.2)
+    async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p:
+        await p.wait()
+        stats = p.get_stats()
+        assert stats["connections_num"] == 1
+        assert stats.get("connections_errors", 0) == 0
+        assert stats.get("connections_lost", 0) == 0
+        assert 200 <= stats["connections_ms"] < 300
index c24759a31d45e5fc451a10d96c84264de91e206e..17cac2bcb608c3582612254a561aff768e42c0e0 100644 (file)
@@ -488,7 +488,7 @@ def test_intrans_rollback(dsn, caplog):
         with p.connection() as conn2:
             assert conn2.pgconn.backend_pid == pid
             assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
-            assert not conn.execute(
+            assert not conn2.execute(
                 "select 1 from pg_class where relname = 'test_intrans_rollback'"
             ).fetchone()