from .pool import ConnectionPool
from .pool_async import AsyncConnectionPool
+from .null_pool import NullConnectionPool
+from .null_pool_async import AsyncNullConnectionPool
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from .version import __version__ as __version__ # noqa: F401
__all__ = [
"AsyncConnectionPool",
+ "AsyncNullConnectionPool",
"ConnectionPool",
+ "NullConnectionPool",
"PoolClosed",
"PoolTimeout",
"TooManyRequests",
import sys
import asyncio
-from typing import Any, Awaitable, Generator, Optional, Union, TypeVar
+from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar
+
+import psycopg.errors as e
T = TypeVar("T")
FutureT = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
"Task",
"create_task",
]
+
+# Workaround for psycopg < 3.0.8.
+# Timeout on NullPool connection mignt not work correctly.
+try:
+ ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout
+except AttributeError:
+
+ class DummyConnectionTimeout(e.OperationalError):
+ pass
+
+ ConnectionTimeout = DummyConnectionTimeout
def _check_size(
self, min_size: int, max_size: Optional[int]
) -> Tuple[int, int]:
- if min_size < 0:
- raise ValueError("min_size cannot be negative")
-
if max_size is None:
max_size = min_size
+
+ if min_size < 0:
+ raise ValueError("min_size cannot be negative")
if max_size < min_size:
raise ValueError("max_size must be greater or equal than min_size")
if min_size == max_size == 0:
--- /dev/null
+"""
+Psycopg null connection pools
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import logging
+import threading
+from time import monotonic
+from typing import Any, Optional, Tuple
+
+from psycopg import Connection
+from psycopg.pq import TransactionStatus
+
+from .pool import ConnectionPool, WaitingClient
+from .pool import AddConnection, ReturnConnection
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class _BaseNullConnectionPool:
+ def __init__(
+ self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any
+ ):
+ super().__init__( # type: ignore[call-arg]
+ conninfo, *args, min_size=min_size, **kwargs
+ )
+
+ def _check_size(
+ self, min_size: int, max_size: Optional[int]
+ ) -> Tuple[int, int]:
+ if max_size is None:
+ max_size = min_size
+
+ if min_size != 0:
+ raise ValueError("null pools must have min_size = 0")
+ if max_size < min_size:
+ raise ValueError("max_size must be greater or equal than min_size")
+
+ return min_size, max_size
+
+ def _start_initial_tasks(self) -> None:
+ # Null pools don't have background tasks to fill connections
+ # or to grow/shrink.
+ return
+
+
+class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+ def wait(self, timeout: float = 30.0) -> None:
+ """
+ Create a connection for test.
+
+ Calling this function will verify that the connectivity with the
+ database works as expected. However the connection will not be stored
+ in the pool.
+
+ Raise `PoolTimeout` if not ready within *timeout* sec.
+ """
+ self._check_open_getconn()
+
+ with self._lock:
+ assert not self._pool_full_event
+ self._pool_full_event = threading.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ self.run_task(AddConnection(self))
+ if not self._pool_full_event.wait(timeout):
+ self.close() # stop all the threads
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ )
+
+ with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
+ logger.info("connection requested from %r", self.name)
+ self._stats[self._REQUESTS_NUM] += 1
+
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ with self._lock:
+ self._check_open_getconn()
+
+ pos: Optional[WaitingClient] = None
+ if self.max_size == 0 or self._nconns < self.max_size:
+ # Create a new connection for the client
+ try:
+ conn = self._connect(timeout=timeout)
+ except ConnectionTimeout as ex:
+ raise PoolTimeout(str(ex)) from None
+ self._nconns += 1
+ else:
+ if self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has aleady"
+ f" {len(self._waiting)} requests waiting"
+ )
+
+ # No connection available: put the client in the waiting queue
+ t0 = monotonic()
+ pos = WaitingClient()
+ self._waiting.append(pos)
+ self._stats[self._REQUESTS_QUEUED] += 1
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if pos:
+ if timeout is None:
+ timeout = self.timeout
+ try:
+ conn = pos.wait(timeout=timeout)
+ except Exception:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise
+ finally:
+ t1 = monotonic()
+ self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ def putconn(self, conn: Connection[Any]) -> None:
+ # Quick check to discard the wrong connection
+ self._check_pool_putconn(conn)
+
+ logger.info("returning connection to %r", self.name)
+
+ # Close the connection if no client is waiting for it, or if the pool
+ # is closed. For extra refcare remove the pool reference from it.
+ # Maintain the stats.
+ with self._lock:
+ if self._closed or not self._waiting:
+ conn._pool = None
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ conn.close()
+ self._nconns -= 1
+ return
+
+ # Use a worker to perform eventual maintenance work in a separate thread
+ if self._reset:
+ self.run_task(ReturnConnection(self, conn))
+ else:
+ self._return_connection(conn)
+
+ def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+
+ def check(self) -> None:
+ """No-op, as the pool doesn't have connections in its state."""
+ pass
+
+ def _add_to_pool(self, conn: Connection[Any]) -> None:
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: close the connection
+ conn.close()
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is ready.
+ if self._pool_full_event:
+ self._pool_full_event.set()
+ else:
+ # The connection created by wait shoudn't decrease the
+ # count of the number of connection used.
+ self._nconns -= 1
--- /dev/null
+"""
+psycopg asynchronous null connection pool
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import asyncio
+import logging
+from time import monotonic
+from typing import Any, Optional
+
+from psycopg.pq import TransactionStatus
+from psycopg.connection_async import AsyncConnection
+
+from .errors import PoolTimeout, TooManyRequests
+from ._compat import ConnectionTimeout
+from .null_pool import _BaseNullConnectionPool
+from .pool_async import AsyncConnectionPool, AsyncClient
+from .pool_async import AddConnection, ReturnConnection
+
+logger = logging.getLogger("psycopg.pool")
+
+
+class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+ async def wait(self, timeout: float = 30.0) -> None:
+ self._check_open_getconn()
+
+ async with self._lock:
+ assert not self._pool_full_event
+ self._pool_full_event = asyncio.Event()
+
+ logger.info("waiting for pool %r initialization", self.name)
+ self.run_task(AddConnection(self))
+ try:
+ await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ await self.close() # stop all the tasks
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ ) from None
+
+ async with self._lock:
+ assert self._pool_full_event
+ self._pool_full_event = None
+
+ logger.info("pool %r is ready to use", self.name)
+
+ async def getconn(
+ self, timeout: Optional[float] = None
+ ) -> AsyncConnection[Any]:
+ logger.info("connection requested from %r", self.name)
+ self._stats[self._REQUESTS_NUM] += 1
+
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ async with self._lock:
+ self._check_open_getconn()
+
+ pos: Optional[AsyncClient] = None
+ if self.max_size == 0 or self._nconns < self.max_size:
+ # Create a new connection for the client
+ try:
+ conn = await self._connect(timeout=timeout)
+ except ConnectionTimeout as ex:
+ raise PoolTimeout(str(ex)) from None
+ self._nconns += 1
+ else:
+ if self.max_waiting and len(self._waiting) >= self.max_waiting:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise TooManyRequests(
+ f"the pool {self.name!r} has aleady"
+ f" {len(self._waiting)} requests waiting"
+ )
+
+ # No connection available: put the client in the waiting queue
+ t0 = monotonic()
+ pos = AsyncClient()
+ self._waiting.append(pos)
+ self._stats[self._REQUESTS_QUEUED] += 1
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if pos:
+ if timeout is None:
+ timeout = self.timeout
+ try:
+ conn = await pos.wait(timeout=timeout)
+ except Exception:
+ self._stats[self._REQUESTS_ERRORS] += 1
+ raise
+ finally:
+ t1 = monotonic()
+ self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0))
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ async def putconn(self, conn: AsyncConnection[Any]) -> None:
+ # Quick check to discard the wrong connection
+ self._check_pool_putconn(conn)
+
+ logger.info("returning connection to %r", self.name)
+
+ # Close the connection if no client is waiting for it, or if the pool
+ # is closed. For extra refcare remove the pool reference from it.
+ # Maintain the stats.
+ async with self._lock:
+ if self._closed or not self._waiting:
+ conn._pool = None
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ self._stats[self._RETURNS_BAD] += 1
+ await conn.close()
+ self._nconns -= 1
+ return
+
+ # Use a worker to perform eventual maintenance work in a separate task
+ if self._reset:
+ self.run_task(ReturnConnection(self, conn))
+ else:
+ await self._return_connection(conn)
+
+ async def resize(
+ self, min_size: int, max_size: Optional[int] = None
+ ) -> None:
+ min_size, max_size = self._check_size(min_size, max_size)
+
+ logger.info(
+ "resizing %r to min_size=%s max_size=%s",
+ self.name,
+ min_size,
+ max_size,
+ )
+ async with self._lock:
+ self._min_size = min_size
+ self._max_size = max_size
+
+ async def check(self) -> None:
+ pass
+
+ async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ async with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if await pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: close the connection
+ await conn.close()
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is ready.
+ if self._pool_full_event:
+ self._pool_full_event.set()
+ else:
+ # The connection created by wait shoudn't decrease the
+ # count of the number of connection used.
+ self._nconns -= 1
ex,
)
- def _connect(self) -> Connection[Any]:
+ def _connect(self, timeout: Optional[float] = None) -> Connection[Any]:
"""Return a new connection configured for the pool."""
self._stats[self._CONNECTIONS_NUM] += 1
+ kwargs = self.kwargs
+ if timeout:
+ kwargs = kwargs.copy()
+ kwargs["connect_timeout"] = max(round(timeout), 1)
t0 = monotonic()
try:
conn: Connection[Any]
- conn = self.connection_class.connect(self.conninfo, **self.kwargs)
+ conn = self.connection_class.connect(self.conninfo, **kwargs)
except Exception:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
ex,
)
- async def _connect(self) -> AsyncConnection[Any]:
- """Return a new connection configured for the pool."""
+ async def _connect(
+ self, timeout: Optional[float] = None
+ ) -> AsyncConnection[Any]:
self._stats[self._CONNECTIONS_NUM] += 1
+ kwargs = self.kwargs
+ if timeout:
+ kwargs = kwargs.copy()
+ kwargs["connect_timeout"] = max(round(timeout), 1)
t0 = monotonic()
try:
conn: AsyncConnection[Any]
- conn = await self.connection_class.connect(
- self.conninfo, **self.kwargs
- )
+ conn = await self.connection_class.connect(self.conninfo, **kwargs)
except Exception:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
--- /dev/null
+import logging
+from time import sleep, time
+from threading import Thread, Event
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver # noqa: F401 # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+
+from .test_pool import delay_connection
+
+try:
+ from psycopg_pool import NullConnectionPool
+ from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+ pass
+
+
+def test_defaults(dsn):
+ with NullConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 0
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+def test_min_size_max_size(dsn):
+ with NullConnectionPool(dsn, min_size=0, max_size=2) as p:
+ assert p.min_size == 0
+ assert p.max_size == 2
+
+
+@pytest.mark.parametrize(
+ "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ NullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+def test_connection_class(dsn):
+ class MyConn(psycopg.Connection[Any]):
+ pass
+
+ with NullConnectionPool(dsn, connection_class=MyConn) as p:
+ with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+def test_kwargs(dsn):
+ with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+ with p.connection() as conn:
+ assert conn.autocommit
+
+
+def test_its_no_pool_at_all(dsn):
+ with NullConnectionPool(dsn, max_size=2) as p:
+ with p.connection() as conn:
+ with conn.execute("select pg_backend_pid()") as cur:
+ (pid1,) = cur.fetchone() # type: ignore[misc]
+
+ with p.connection() as conn2:
+ with conn2.execute("select pg_backend_pid()") as cur:
+ (pid2,) = cur.fetchone() # type: ignore[misc]
+
+ with p.connection() as conn:
+ assert conn.pgconn.backend_pid not in (pid1, pid2)
+
+
+def test_context(dsn):
+ with NullConnectionPool(dsn) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.2)
+ with pytest.raises(PoolTimeout):
+ with NullConnectionPool(dsn, num_workers=1) as p:
+ p.wait(0.1)
+
+ with NullConnectionPool(dsn, num_workers=1) as p:
+ p.wait(0.4)
+
+
+def test_wait_closed(dsn):
+ with NullConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(PoolClosed):
+ p.wait()
+
+
+@pytest.mark.slow
+def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(PoolTimeout):
+ with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ p.wait(0.2)
+
+ with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+
+
+def test_configure(dsn):
+ inits = 0
+
+ def configure(conn):
+ nonlocal inits
+ inits += 1
+ with conn.transaction():
+ conn.execute("set default_transaction_read_only to on")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with p.connection() as conn:
+ assert inits == 1
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+ with p.connection() as conn:
+ assert inits == 2
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+ conn.close()
+
+ with p.connection() as conn:
+ assert inits == 3
+ res = conn.execute("show default_transaction_read_only")
+ assert res.fetchone()[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ conn.execute("select 1")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def configure(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with NullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+def test_reset(dsn):
+ resets = 0
+
+ def setup(conn):
+ with conn.transaction():
+ conn.execute("set timezone to '+1:00'")
+
+ def reset(conn):
+ nonlocal resets
+ resets += 1
+ with conn.transaction():
+ conn.execute("set timezone to utc")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ assert resets == 1
+ with conn.execute("show timezone") as cur:
+ assert cur.fetchone() == ("UTC",)
+ pids.append(conn.pgconn.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ # Queue the worker so it will take the same connection a second time
+ # instead of making a new one.
+ t = Thread(target=worker)
+ t.start()
+
+ assert resets == 0
+ conn.execute("set timezone to '+2:00'")
+ pids.append(conn.pgconn.backend_pid)
+
+ t.join()
+ p.wait()
+
+ assert resets == 1
+ assert pids[0] == pids[1]
+
+
+def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ conn.execute("reset all")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ t = Thread(target=worker)
+ t.start()
+
+ conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ def reset(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ with p.connection() as conn:
+
+ t = Thread(target=worker)
+ t.start()
+
+ conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+def test_no_queue_timeout(deaf_port):
+ with NullConnectionPool(
+ kwargs={"host": "localhost", "port": deaf_port}
+ ) as p:
+ with pytest.raises(PoolTimeout):
+ with p.connection(timeout=1):
+ pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_queue(dsn, retries):
+ def worker(n):
+ t0 = time()
+ with p.connection() as conn:
+ (pid,) = conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ ).fetchone() # type: ignore[misc]
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ for retry in retries:
+ with retry:
+ results: List[Tuple[int, float, int]] = []
+ with NullConnectionPool(dsn, max_size=2) as p:
+ p.wait()
+ ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+def test_queue_size(dsn):
+ def worker(t, ev=None):
+ try:
+ with p.connection():
+ if ev:
+ ev.set()
+ sleep(t)
+ except TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+ p.wait()
+ ev = Event()
+ t = Thread(target=worker, args=(0.3, ev))
+ t.start()
+ ev.wait()
+
+ ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_queue_timeout(dsn, retries):
+ def worker(n):
+ t0 = time()
+ try:
+ with p.connection() as conn:
+ (pid,) = conn.execute( # type: ignore[misc]
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ ).fetchone()
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ for retry in retries:
+ with retry:
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_dead_client(dsn):
+ def worker(i, timeout):
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ results: List[int] = []
+
+ with NullConnectionPool(dsn, max_size=2) as p:
+ ts = [
+ Thread(target=worker, args=(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_queue_timeout_override(dsn, retries):
+ def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ with p.connection(timeout=timeout) as conn:
+ (pid,) = conn.execute( # type: ignore[misc]
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ ).fetchone()
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ for retry in retries:
+ with retry:
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
+ ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+def test_broken_reconnect(dsn):
+ with NullConnectionPool(dsn, max_size=1) as p:
+ with p.connection() as conn:
+ with conn.execute("select pg_backend_pid()") as cur:
+ (pid1,) = cur.fetchone() # type: ignore[misc]
+ conn.close()
+
+ with p.connection() as conn2:
+ with conn2.execute("select pg_backend_pid()") as cur:
+ (pid2,) = cur.fetchone() # type: ignore[misc]
+
+ assert pid1 != pid2
+
+
+def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+ assert not conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ ).fetchone()
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = Thread(target=worker)
+ t.start()
+
+ pids.append(conn.pgconn.backend_pid)
+ conn.execute("create table test_intrans_rollback ()")
+ assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = Thread(target=worker)
+ t.start()
+
+ pids.append(conn.pgconn.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker():
+ with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ t = Thread(target=worker)
+ t.start()
+
+ pids.append(conn.pgconn.backend_pid)
+ cur = conn.cursor()
+ with cur.copy("copy (select * from generate_series(1, 10)) to stdout"):
+ pass
+ assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ def worker(p):
+ with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+ with NullConnectionPool(dsn, max_size=1) as p:
+ conn = p.getconn()
+
+ def bad_rollback():
+ conn.pgconn.finish()
+ orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ t = Thread(target=worker, args=(p,))
+ t.start()
+
+ pids.append(conn.pgconn.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+ t.join()
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+def test_close_no_threads(dsn):
+ p = NullConnectionPool(dsn)
+ assert p._sched_runner and p._sched_runner.is_alive()
+ workers = p._workers[:]
+ assert workers
+ for t in workers:
+ assert t.is_alive()
+
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert not t.is_alive()
+
+
+def test_putconn_no_pool(dsn):
+ with NullConnectionPool(dsn) as p:
+ conn = psycopg.connect(dsn)
+ with pytest.raises(ValueError):
+ p.putconn(conn)
+
+ conn.close()
+
+
+def test_putconn_wrong_pool(dsn):
+ with NullConnectionPool(dsn) as p1:
+ with NullConnectionPool(dsn) as p2:
+ conn = p1.getconn()
+ with pytest.raises(ValueError):
+ p2.putconn(conn)
+
+
+@pytest.mark.slow
+def test_del_stop_threads(dsn):
+ p = NullConnectionPool(dsn)
+ assert p._sched_runner is not None
+ ts = [p._sched_runner] + p._workers
+ del p
+ sleep(0.1)
+ for t in ts:
+ assert not t.is_alive()
+
+
+def test_closed_getconn(dsn):
+ p = NullConnectionPool(dsn)
+ assert not p.closed
+ with p.connection():
+ pass
+
+ p.close()
+ assert p.closed
+
+ with pytest.raises(PoolClosed):
+ with p.connection():
+ pass
+
+
+def test_closed_putconn(dsn):
+ p = NullConnectionPool(dsn)
+
+ with p.connection() as conn:
+ pass
+ assert conn.closed
+
+ with p.connection() as conn:
+ p.close()
+ assert conn.closed
+
+
+def test_closed_queue(dsn):
+ def w1():
+ with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+ e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ def w2():
+ try:
+ with p.connection():
+ pass # unexpected
+ except PoolClosed:
+ success.append("w2")
+
+ e1 = Event()
+ e2 = Event()
+
+ p = NullConnectionPool(dsn, max_size=1)
+ p.wait()
+ success: List[str] = []
+
+ t1 = Thread(target=w1)
+ t1.start()
+ # Wait until w1 has received a connection
+ e1.wait()
+
+ t2 = Thread(target=w2)
+ t2.start()
+ # Wait until w2 is in the queue
+ while not p._waiting:
+ sleep(0)
+
+ p.close(0)
+
+ # Wait for the workers to finish
+ e2.set()
+ t1.join()
+ t2.join()
+ assert len(success) == 2
+
+
+def test_open_explicit(dsn):
+ p = NullConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(PoolClosed, match="is not open yet"):
+ p.getconn()
+
+ with pytest.raises(PoolClosed):
+ with p.connection():
+ pass
+
+ p.open()
+ try:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+ with pytest.raises(PoolClosed, match="is already closed"):
+ p.getconn()
+
+
+def test_open_context(dsn):
+ p = NullConnectionPool(dsn, open=False)
+ assert p.closed
+
+ with p:
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+def test_open_no_op(dsn):
+ p = NullConnectionPool(dsn)
+ try:
+ assert not p.closed
+ p.open()
+ assert not p.closed
+
+ with p.connection() as conn:
+ cur = conn.execute("select 1")
+ assert cur.fetchone() == (1,)
+
+ finally:
+ p.close()
+
+
+def test_reopen(dsn):
+ p = NullConnectionPool(dsn)
+ with p.connection() as conn:
+ conn.execute("select 1")
+ p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ p.open()
+
+
+@pytest.mark.parametrize(
+ "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+def test_bad_resize(dsn, min_size, max_size):
+ with NullConnectionPool() as p:
+ with pytest.raises(ValueError):
+ p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_max_lifetime(dsn):
+ pids = []
+
+ def worker(p):
+ with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ sleep(0.1)
+
+ ts = []
+ with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+ for i in range(5):
+ ts.append(Thread(target=worker, args=(p,)))
+ ts[-1].start()
+
+ for t in ts:
+ t.join()
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+def test_check(dsn):
+ with NullConnectionPool(dsn) as p:
+ # No-op
+ p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_measures(dsn):
+ def worker(n):
+ with p.connection() as conn:
+ conn.execute("select pg_sleep(0.2)")
+
+ with NullConnectionPool(dsn, max_size=4) as p:
+ p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 0
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(3)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ p.wait(2.0)
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ sleep(0.1)
+ stats = p.get_stats()
+ for t in ts:
+ t.join()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_stats_usage(dsn, retries):
+ def worker(n):
+ try:
+ with p.connection(timeout=0.3) as conn:
+ conn.execute("select pg_sleep(0.2)")
+ except PoolTimeout:
+ pass
+
+ for retry in retries:
+ with retry:
+ with NullConnectionPool(dsn, max_size=3) as p:
+ p.wait(2.0)
+
+ ts = [Thread(target=worker, args=(i,)) for i in range(7)]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ with p.connection() as conn:
+ conn.close()
+ p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ with NullConnectionPool(proxy.client_dsn, max_size=3) as p:
+ p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 1
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 200 <= stats["connections_ms"] < 300
--- /dev/null
+import sys
+import asyncio
+import logging
+from time import time
+from typing import Any, List, Tuple
+
+import pytest
+from packaging.version import parse as ver # noqa: F401 # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg._compat import create_task
+from .test_pool_async import delay_connection
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.skipif(
+ sys.version_info < (3, 7),
+ reason="async pool not supported before Python 3.7",
+ ),
+]
+
+try:
+ from psycopg_pool import AsyncNullConnectionPool # noqa: F401
+ from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
+except ImportError:
+ pass
+
+
+async def test_defaults(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ assert p.min_size == p.max_size == 0
+ assert p.timeout == 30
+ assert p.max_idle == 10 * 60
+ assert p.max_lifetime == 60 * 60
+ assert p.num_workers == 3
+
+
+async def test_min_size_max_size(dsn):
+ async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p:
+ assert p.min_size == 0
+ assert p.max_size == 2
+
+
+@pytest.mark.parametrize(
+ "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+async def test_bad_size(dsn, min_size, max_size):
+ with pytest.raises(ValueError):
+ AsyncNullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+async def test_connection_class(dsn):
+ class MyConn(psycopg.AsyncConnection[Any]):
+ pass
+
+ async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p:
+ async with p.connection() as conn:
+ assert isinstance(conn, MyConn)
+
+
+async def test_kwargs(dsn):
+ async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
+ async with p.connection() as conn:
+ assert conn.autocommit
+
+
+async def test_its_no_pool_at_all(dsn):
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ async with p.connection() as conn:
+ cur = await conn.execute("select pg_backend_pid()")
+ (pid1,) = await cur.fetchone() # type: ignore[misc]
+
+ async with p.connection() as conn2:
+ cur = await conn2.execute("select pg_backend_pid()")
+ (pid2,) = await cur.fetchone() # type: ignore[misc]
+
+ async with p.connection() as conn:
+ assert conn.pgconn.backend_pid not in (pid1, pid2)
+
+
+async def test_context(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ assert not p.closed
+ assert p.closed
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.2)
+ with pytest.raises(PoolTimeout):
+ async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+ await p.wait(0.1)
+
+ async with AsyncNullConnectionPool(dsn, num_workers=1) as p:
+ await p.wait(0.4)
+
+
+async def test_wait_closed(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ pass
+
+ with pytest.raises(PoolClosed):
+ await p.wait()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(PoolTimeout):
+ async with AsyncNullConnectionPool(
+ proxy.client_dsn, num_workers=1
+ ) as p:
+ await p.wait(0.2)
+
+ async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
+ await asyncio.sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+
+async def test_configure(dsn):
+ inits = 0
+
+ async def configure(conn):
+ nonlocal inits
+ inits += 1
+ async with conn.transaction():
+ await conn.execute("set default_transaction_read_only to on")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ async with p.connection() as conn:
+ assert inits == 1
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+ async with p.connection() as conn:
+ assert inits == 2
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+ await conn.close()
+
+ async with p.connection() as conn:
+ assert inits == 3
+ res = await conn.execute("show default_transaction_read_only")
+ assert (await res.fetchone())[0] == "on" # type: ignore[index]
+
+
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ await conn.execute("select 1")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_configure_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def configure(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with AsyncNullConnectionPool(dsn, configure=configure) as p:
+ with pytest.raises(PoolTimeout):
+ await p.wait(timeout=0.5)
+
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+async def test_reset(dsn):
+ resets = 0
+
+ async def setup(conn):
+ async with conn.transaction():
+ await conn.execute("set timezone to '+1:00'")
+
+ async def reset(conn):
+ nonlocal resets
+ resets += 1
+ async with conn.transaction():
+ await conn.execute("set timezone to utc")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ assert resets == 1
+ cur = await conn.execute("show timezone")
+ assert (await cur.fetchone()) == ("UTC",)
+ pids.append(conn.pgconn.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ # Queue the worker so it will take the same connection a second time
+ # instead of making a new one.
+ t = create_task(worker())
+
+ assert resets == 0
+ await conn.execute("set timezone to '+2:00'")
+ pids.append(conn.pgconn.backend_pid)
+
+ await asyncio.gather(t)
+ await p.wait()
+
+ assert resets == 1
+ assert pids[0] == pids[1]
+
+
+async def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ await conn.execute("reset all")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ t = create_task(worker())
+
+ await conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+async def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+ async def reset(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
+ async with p.connection() as conn:
+
+ t = create_task(worker())
+
+ await conn.execute("select 1")
+ pids.append(conn.pgconn.backend_pid)
+
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+async def test_no_queue_timeout(deaf_port):
+ async with AsyncNullConnectionPool(
+ kwargs={"host": "localhost", "port": deaf_port}
+ ) as p:
+ with pytest.raises(PoolTimeout):
+ async with p.connection(timeout=1):
+ pass
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_queue(dsn, retries):
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ cur = await conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ )
+ (pid,) = await cur.fetchone() # type: ignore[misc]
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ async for retry in retries:
+ with retry:
+ results: List[Tuple[int, float, int]] = []
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ await p.wait()
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_size(dsn):
+ async def worker(t, ev=None):
+ try:
+ async with p.connection():
+ if ev:
+ ev.set()
+ await asyncio.sleep(t)
+ except TooManyRequests as e:
+ errors.append(e)
+ else:
+ success.append(True)
+
+ errors: List[Exception] = []
+ success: List[bool] = []
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
+ await p.wait()
+ ev = asyncio.Event()
+ create_task(worker(0.3, ev))
+ await ev.wait()
+
+ ts = [create_task(worker(0.1)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(success) == 4
+ assert len(errors) == 1
+ assert isinstance(errors[0], TooManyRequests)
+ assert p.name in str(errors[0])
+ assert str(p.max_waiting) in str(errors[0])
+ assert p.get_stats()["requests_errors"] == 1
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_queue_timeout(dsn, retries):
+ async def worker(n):
+ t0 = time()
+ try:
+ async with p.connection() as conn:
+ cur = await conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ )
+ (pid,) = await cur.fetchone() # type: ignore[misc]
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ async for retry in retries:
+ with retry:
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with AsyncNullConnectionPool(
+ dsn, max_size=2, timeout=0.1
+ ) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_dead_client(dsn):
+ async def worker(i, timeout):
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ async with AsyncNullConnectionPool(dsn, max_size=2) as p:
+ results: List[int] = []
+ ts = [
+ create_task(worker(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_queue_timeout_override(dsn, retries):
+ async def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ cur = await conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ )
+ (pid,) = await cur.fetchone() # type: ignore[misc]
+ except PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ async for retry in retries:
+ with retry:
+ results: List[Tuple[int, float, int]] = []
+ errors: List[Tuple[int, float, Exception]] = []
+
+ async with AsyncNullConnectionPool(
+ dsn, max_size=2, timeout=0.1
+ ) as p:
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+async def test_broken_reconnect(dsn):
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ async with p.connection() as conn:
+ cur = await conn.execute("select pg_backend_pid()")
+ (pid1,) = await cur.fetchone() # type: ignore[misc]
+ await conn.close()
+
+ async with p.connection() as conn2:
+ cur = await conn2.execute("select pg_backend_pid()")
+ (pid2,) = await cur.fetchone() # type: ignore[misc]
+
+ assert pid1 != pid2
+
+
+async def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+ cur = await conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ )
+ assert not await cur.fetchone()
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ # Queue the worker so it will take the connection a second time instead
+ # of making a new one.
+ t = create_task(worker())
+
+ pids.append(conn.pgconn.backend_pid)
+ await conn.execute("create table test_intrans_rollback ()")
+ assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+async def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ t = create_task(worker())
+
+ pids.append(conn.pgconn.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] == pids[1]
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+async def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+
+ t = create_task(worker())
+
+ pids.append(conn.pgconn.backend_pid)
+ cur = conn.cursor()
+ async with cur.copy(
+ "copy (select * from generate_series(1, 10)) to stdout"
+ ):
+ pass
+ assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg.pool")
+ pids = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ assert conn.pgconn.transaction_status == TransactionStatus.IDLE
+
+ async with AsyncNullConnectionPool(dsn, max_size=1) as p:
+ conn = await p.getconn()
+ t = create_task(worker())
+
+ async def bad_rollback():
+ conn.pgconn.finish()
+ await orig_rollback()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pids.append(conn.pgconn.backend_pid)
+ with pytest.raises(psycopg.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+ await asyncio.gather(t)
+
+ assert pids[0] != pids[1]
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+async def test_close_no_tasks(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ assert p._sched_runner and not p._sched_runner.done()
+ assert p._workers
+ workers = p._workers[:]
+ for t in workers:
+ assert not t.done()
+
+ await p.close()
+ assert p._sched_runner is None
+ assert not p._workers
+ for t in workers:
+ assert t.done()
+
+
+async def test_putconn_no_pool(dsn):
+ async with AsyncNullConnectionPool(dsn) as p:
+ conn = await psycopg.AsyncConnection.connect(dsn)
+ with pytest.raises(ValueError):
+ await p.putconn(conn)
+
+ await conn.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+ async with AsyncNullConnectionPool(dsn) as p1:
+ async with AsyncNullConnectionPool(dsn) as p2:
+ conn = await p1.getconn()
+ with pytest.raises(ValueError):
+ await p2.putconn(conn)
+
+
+async def test_closed_getconn(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ assert not p.closed
+ async with p.connection():
+ pass
+
+ await p.close()
+ assert p.closed
+
+ with pytest.raises(PoolClosed):
+ async with p.connection():
+ pass
+
+
+async def test_closed_putconn(dsn):
+ p = AsyncNullConnectionPool(dsn)
+
+ async with p.connection() as conn:
+ pass
+ assert conn.closed
+
+ async with p.connection() as conn:
+ await p.close()
+ assert conn.closed
+
+
+async def test_closed_queue(dsn):
+ async def w1():
+ async with p.connection() as conn:
+ e1.set() # Tell w0 that w1 got a connection
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+ await e2.wait() # Wait until w0 has tested w2
+ success.append("w1")
+
+ async def w2():
+ try:
+ async with p.connection():
+ pass # unexpected
+ except PoolClosed:
+ success.append("w2")
+
+ e1 = asyncio.Event()
+ e2 = asyncio.Event()
+
+ p = AsyncNullConnectionPool(dsn, max_size=1)
+ await p.wait()
+ success: List[str] = []
+
+ t1 = create_task(w1())
+ # Wait until w1 has received a connection
+ await e1.wait()
+
+ t2 = create_task(w2())
+ # Wait until w2 is in the queue
+ while not p._waiting:
+ await asyncio.sleep(0)
+
+ await p.close()
+
+ # Wait for the workers to finish
+ e2.set()
+ await asyncio.gather(t1, t2)
+ assert len(success) == 2
+
+
+async def test_open_explicit(dsn):
+ p = AsyncNullConnectionPool(dsn, open=False)
+ assert p.closed
+ with pytest.raises(PoolClosed):
+ await p.getconn()
+
+ with pytest.raises(PoolClosed, match="is not open yet"):
+ async with p.connection():
+ pass
+
+ await p.open()
+ try:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+ with pytest.raises(PoolClosed, match="is already closed"):
+ await p.getconn()
+
+
+async def test_open_context(dsn):
+ p = AsyncNullConnectionPool(dsn, open=False)
+ assert p.closed
+
+ async with p:
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ assert p.closed
+
+
+async def test_open_no_op(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ try:
+ assert not p.closed
+ await p.open()
+ assert not p.closed
+
+ async with p.connection() as conn:
+ cur = await conn.execute("select 1")
+ assert await cur.fetchone() == (1,)
+
+ finally:
+ await p.close()
+
+
+async def test_reopen(dsn):
+ p = AsyncNullConnectionPool(dsn)
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ await p.close()
+ assert p._sched_runner is None
+
+ with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
+ await p.open()
+
+
+@pytest.mark.parametrize(
+ "min_size, max_size", [(1, None), (-1, None), (0, -2)]
+)
+async def test_bad_resize(dsn, min_size, max_size):
+ async with AsyncNullConnectionPool() as p:
+ with pytest.raises(ValueError):
+ await p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_max_lifetime(dsn):
+ pids: List[int] = []
+
+ async def worker():
+ async with p.connection() as conn:
+ pids.append(conn.pgconn.backend_pid)
+ await asyncio.sleep(0.1)
+
+ async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+ ts = [create_task(worker()) for i in range(5)]
+ await asyncio.gather(*ts)
+
+ assert pids[0] == pids[1] != pids[4], pids
+
+
+async def test_check(dsn):
+ # no.op
+ async with AsyncNullConnectionPool(dsn) as p:
+ await p.check()
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_measures(dsn):
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.2)")
+
+ async with AsyncNullConnectionPool(dsn, max_size=4) as p:
+ await p.wait(2.0)
+
+ stats = p.get_stats()
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 0
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ ts = [create_task(worker(i)) for i in range(3)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 3
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 0
+
+ await p.wait(2.0)
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.sleep(0.1)
+ stats = p.get_stats()
+ await asyncio.gather(*ts)
+ assert stats["pool_min"] == 0
+ assert stats["pool_max"] == 4
+ assert stats["pool_size"] == 4
+ assert stats["pool_available"] == 0
+ assert stats["requests_waiting"] == 3
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_stats_usage(dsn, retries):
+ async def worker(n):
+ try:
+ async with p.connection(timeout=0.3) as conn:
+ await conn.execute("select pg_sleep(0.2)")
+ except PoolTimeout:
+ pass
+
+ async for retry in retries:
+ with retry:
+ async with AsyncNullConnectionPool(dsn, max_size=3) as p:
+ await p.wait(2.0)
+
+ ts = [create_task(worker(i)) for i in range(7)]
+ await asyncio.gather(*ts)
+ stats = p.get_stats()
+ assert stats["requests_num"] == 7
+ assert stats["requests_queued"] == 4
+ assert 850 <= stats["requests_wait_ms"] <= 950
+ assert stats["requests_errors"] == 1
+ assert 1150 <= stats["usage_ms"] <= 1250
+ assert stats.get("returns_bad", 0) == 0
+
+ async with p.connection() as conn:
+ await conn.close()
+ await p.wait()
+ stats = p.pop_stats()
+ assert stats["requests_num"] == 8
+ assert stats["returns_bad"] == 1
+ async with p.connection():
+ pass
+ assert p.get_stats()["requests_num"] == 1
+
+
+@pytest.mark.slow
+async def test_stats_connect(dsn, proxy, monkeypatch):
+ proxy.start()
+ delay_connection(monkeypatch, 0.2)
+ async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p:
+ await p.wait()
+ stats = p.get_stats()
+ assert stats["connections_num"] == 1
+ assert stats.get("connections_errors", 0) == 0
+ assert stats.get("connections_lost", 0) == 0
+ assert 200 <= stats["connections_ms"] < 300
with p.connection() as conn2:
assert conn2.pgconn.backend_pid == pid
assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
- assert not conn.execute(
+ assert not conn2.execute(
"select 1 from pg_class where relname = 'test_intrans_rollback'"
).fetchone()