From: Daniele Varrazzo Date: Mon, 3 Jan 2022 19:20:29 +0000 (+0100) Subject: Add NullPool and AsyncNullPool X-Git-Tag: pool-3.1~21^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=44f97ccd78121ef685da37f1ead84fa0d374978e;p=thirdparty%2Fpsycopg.git Add NullPool and AsyncNullPool Close #148 --- diff --git a/psycopg_pool/psycopg_pool/__init__.py b/psycopg_pool/psycopg_pool/__init__.py index 49b035b31..e4d975fed 100644 --- a/psycopg_pool/psycopg_pool/__init__.py +++ b/psycopg_pool/psycopg_pool/__init__.py @@ -6,12 +6,16 @@ psycopg connection pool package from .pool import ConnectionPool from .pool_async import AsyncConnectionPool +from .null_pool import NullConnectionPool +from .null_pool_async import AsyncNullConnectionPool from .errors import PoolClosed, PoolTimeout, TooManyRequests from .version import __version__ as __version__ # noqa: F401 __all__ = [ "AsyncConnectionPool", + "AsyncNullConnectionPool", "ConnectionPool", + "NullConnectionPool", "PoolClosed", "PoolTimeout", "TooManyRequests", diff --git a/psycopg_pool/psycopg_pool/_compat.py b/psycopg_pool/psycopg_pool/_compat.py index f666e677c..c1b14f2fe 100644 --- a/psycopg_pool/psycopg_pool/_compat.py +++ b/psycopg_pool/psycopg_pool/_compat.py @@ -6,7 +6,9 @@ compatibility functions for different Python versions import sys import asyncio -from typing import Any, Awaitable, Generator, Optional, Union, TypeVar +from typing import Any, Awaitable, Generator, Optional, Union, Type, TypeVar + +import psycopg.errors as e T = TypeVar("T") FutureT = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]] @@ -35,3 +37,14 @@ __all__ = [ "Task", "create_task", ] + +# Workaround for psycopg < 3.0.8. +# Timeout on NullPool connection mignt not work correctly. +try: + ConnectionTimeout: Type[e.OperationalError] = e.ConnectionTimeout +except AttributeError: + + class DummyConnectionTimeout(e.OperationalError): + pass + + ConnectionTimeout = DummyConnectionTimeout diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index 7c9d96223..1e3187b55 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -121,11 +121,11 @@ class BasePool(Generic[ConnectionType]): def _check_size( self, min_size: int, max_size: Optional[int] ) -> Tuple[int, int]: - if min_size < 0: - raise ValueError("min_size cannot be negative") - if max_size is None: max_size = min_size + + if min_size < 0: + raise ValueError("min_size cannot be negative") if max_size < min_size: raise ValueError("max_size must be greater or equal than min_size") if min_size == max_size == 0: diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py new file mode 100644 index 000000000..58823cb2e --- /dev/null +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -0,0 +1,198 @@ +""" +Psycopg null connection pools +""" + +# Copyright (C) 2022 The Psycopg Team + +import logging +import threading +from time import monotonic +from typing import Any, Optional, Tuple + +from psycopg import Connection +from psycopg.pq import TransactionStatus + +from .pool import ConnectionPool, WaitingClient +from .pool import AddConnection, ReturnConnection +from .errors import PoolTimeout, TooManyRequests +from ._compat import ConnectionTimeout + +logger = logging.getLogger("psycopg.pool") + + +class _BaseNullConnectionPool: + def __init__( + self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any + ): + super().__init__( # type: ignore[call-arg] + conninfo, *args, min_size=min_size, **kwargs + ) + + def _check_size( + self, min_size: int, max_size: Optional[int] + ) -> Tuple[int, int]: + if max_size is None: + max_size = min_size + + if min_size != 0: + raise ValueError("null pools must have min_size = 0") + if max_size < min_size: + raise ValueError("max_size must be greater or equal than min_size") + + return min_size, max_size + + def _start_initial_tasks(self) -> None: + # Null pools don't have background tasks to fill connections + # or to grow/shrink. + return + + +class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): + def wait(self, timeout: float = 30.0) -> None: + """ + Create a connection for test. + + Calling this function will verify that the connectivity with the + database works as expected. However the connection will not be stored + in the pool. + + Raise `PoolTimeout` if not ready within *timeout* sec. + """ + self._check_open_getconn() + + with self._lock: + assert not self._pool_full_event + self._pool_full_event = threading.Event() + + logger.info("waiting for pool %r initialization", self.name) + self.run_task(AddConnection(self)) + if not self._pool_full_event.wait(timeout): + self.close() # stop all the threads + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) + + with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + def getconn(self, timeout: Optional[float] = None) -> Connection[Any]: + logger.info("connection requested from %r", self.name) + self._stats[self._REQUESTS_NUM] += 1 + + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. + with self._lock: + self._check_open_getconn() + + pos: Optional[WaitingClient] = None + if self.max_size == 0 or self._nconns < self.max_size: + # Create a new connection for the client + try: + conn = self._connect(timeout=timeout) + except ConnectionTimeout as ex: + raise PoolTimeout(str(ex)) from None + self._nconns += 1 + else: + if self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has aleady" + f" {len(self._waiting)} requests waiting" + ) + + # No connection available: put the client in the waiting queue + t0 = monotonic() + pos = WaitingClient() + self._waiting.append(pos) + self._stats[self._REQUESTS_QUEUED] += 1 + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if pos: + if timeout is None: + timeout = self.timeout + try: + conn = pos.wait(timeout=timeout) + except Exception: + self._stats[self._REQUESTS_ERRORS] += 1 + raise + finally: + t1 = monotonic() + self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0)) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + conn._pool = self + logger.info("connection given by %r", self.name) + return conn + + def putconn(self, conn: Connection[Any]) -> None: + # Quick check to discard the wrong connection + self._check_pool_putconn(conn) + + logger.info("returning connection to %r", self.name) + + # Close the connection if no client is waiting for it, or if the pool + # is closed. For extra refcare remove the pool reference from it. + # Maintain the stats. + with self._lock: + if self._closed or not self._waiting: + conn._pool = None + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + conn.close() + self._nconns -= 1 + return + + # Use a worker to perform eventual maintenance work in a separate thread + if self._reset: + self.run_task(ReturnConnection(self, conn)) + else: + self._return_connection(conn) + + def resize(self, min_size: int, max_size: Optional[int] = None) -> None: + min_size, max_size = self._check_size(min_size, max_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + with self._lock: + self._min_size = min_size + self._max_size = max_size + + def check(self) -> None: + """No-op, as the pool doesn't have connections in its state.""" + pass + + def _add_to_pool(self, conn: Connection[Any]) -> None: + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if pos.set(conn): + break + else: + # No client waiting for a connection: close the connection + conn.close() + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is ready. + if self._pool_full_event: + self._pool_full_event.set() + else: + # The connection created by wait shoudn't decrease the + # count of the number of connection used. + self._nconns -= 1 diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py new file mode 100644 index 000000000..690122310 --- /dev/null +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -0,0 +1,168 @@ +""" +psycopg asynchronous null connection pool +""" + +# Copyright (C) 2022 The Psycopg Team + +import asyncio +import logging +from time import monotonic +from typing import Any, Optional + +from psycopg.pq import TransactionStatus +from psycopg.connection_async import AsyncConnection + +from .errors import PoolTimeout, TooManyRequests +from ._compat import ConnectionTimeout +from .null_pool import _BaseNullConnectionPool +from .pool_async import AsyncConnectionPool, AsyncClient +from .pool_async import AddConnection, ReturnConnection + +logger = logging.getLogger("psycopg.pool") + + +class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): + async def wait(self, timeout: float = 30.0) -> None: + self._check_open_getconn() + + async with self._lock: + assert not self._pool_full_event + self._pool_full_event = asyncio.Event() + + logger.info("waiting for pool %r initialization", self.name) + self.run_task(AddConnection(self)) + try: + await asyncio.wait_for(self._pool_full_event.wait(), timeout) + except asyncio.TimeoutError: + await self.close() # stop all the tasks + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) from None + + async with self._lock: + assert self._pool_full_event + self._pool_full_event = None + + logger.info("pool %r is ready to use", self.name) + + async def getconn( + self, timeout: Optional[float] = None + ) -> AsyncConnection[Any]: + logger.info("connection requested from %r", self.name) + self._stats[self._REQUESTS_NUM] += 1 + + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. + async with self._lock: + self._check_open_getconn() + + pos: Optional[AsyncClient] = None + if self.max_size == 0 or self._nconns < self.max_size: + # Create a new connection for the client + try: + conn = await self._connect(timeout=timeout) + except ConnectionTimeout as ex: + raise PoolTimeout(str(ex)) from None + self._nconns += 1 + else: + if self.max_waiting and len(self._waiting) >= self.max_waiting: + self._stats[self._REQUESTS_ERRORS] += 1 + raise TooManyRequests( + f"the pool {self.name!r} has aleady" + f" {len(self._waiting)} requests waiting" + ) + + # No connection available: put the client in the waiting queue + t0 = monotonic() + pos = AsyncClient() + self._waiting.append(pos) + self._stats[self._REQUESTS_QUEUED] += 1 + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if pos: + if timeout is None: + timeout = self.timeout + try: + conn = await pos.wait(timeout=timeout) + except Exception: + self._stats[self._REQUESTS_ERRORS] += 1 + raise + finally: + t1 = monotonic() + self._stats[self._REQUESTS_WAIT_MS] += int(1000.0 * (t1 - t0)) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + conn._pool = self + logger.info("connection given by %r", self.name) + return conn + + async def putconn(self, conn: AsyncConnection[Any]) -> None: + # Quick check to discard the wrong connection + self._check_pool_putconn(conn) + + logger.info("returning connection to %r", self.name) + + # Close the connection if no client is waiting for it, or if the pool + # is closed. For extra refcare remove the pool reference from it. + # Maintain the stats. + async with self._lock: + if self._closed or not self._waiting: + conn._pool = None + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + await conn.close() + self._nconns -= 1 + return + + # Use a worker to perform eventual maintenance work in a separate task + if self._reset: + self.run_task(ReturnConnection(self, conn)) + else: + await self._return_connection(conn) + + async def resize( + self, min_size: int, max_size: Optional[int] = None + ) -> None: + min_size, max_size = self._check_size(min_size, max_size) + + logger.info( + "resizing %r to min_size=%s max_size=%s", + self.name, + min_size, + max_size, + ) + async with self._lock: + self._min_size = min_size + self._max_size = max_size + + async def check(self) -> None: + pass + + async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + async with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if await pos.set(conn): + break + else: + # No client waiting for a connection: close the connection + await conn.close() + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is ready. + if self._pool_full_event: + self._pool_full_event.set() + else: + # The connection created by wait shoudn't decrease the + # count of the number of connection used. + self._nconns -= 1 diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index aa1dd20b3..06683ba90 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -454,13 +454,17 @@ class ConnectionPool(BasePool[Connection[Any]]): ex, ) - def _connect(self) -> Connection[Any]: + def _connect(self, timeout: Optional[float] = None) -> Connection[Any]: """Return a new connection configured for the pool.""" self._stats[self._CONNECTIONS_NUM] += 1 + kwargs = self.kwargs + if timeout: + kwargs = kwargs.copy() + kwargs["connect_timeout"] = max(round(timeout), 1) t0 = monotonic() try: conn: Connection[Any] - conn = self.connection_class.connect(self.conninfo, **self.kwargs) + conn = self.connection_class.connect(self.conninfo, **kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 8a4391ceb..7fe772ebc 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -368,15 +368,18 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): ex, ) - async def _connect(self) -> AsyncConnection[Any]: - """Return a new connection configured for the pool.""" + async def _connect( + self, timeout: Optional[float] = None + ) -> AsyncConnection[Any]: self._stats[self._CONNECTIONS_NUM] += 1 + kwargs = self.kwargs + if timeout: + kwargs = kwargs.copy() + kwargs["connect_timeout"] = max(round(timeout), 1) t0 = monotonic() try: conn: AsyncConnection[Any] - conn = await self.connection_class.connect( - self.conninfo, **self.kwargs - ) + conn = await self.connection_class.connect(self.conninfo, **kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py new file mode 100644 index 000000000..9747c5e19 --- /dev/null +++ b/tests/pool/test_null_pool.py @@ -0,0 +1,898 @@ +import logging +from time import sleep, time +from threading import Thread, Event +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus + +from .test_pool import delay_connection + +try: + from psycopg_pool import NullConnectionPool + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +def test_defaults(dsn): + with NullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +def test_min_size_max_size(dsn): + with NullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize( + "min_size, max_size", [(1, None), (-1, None), (0, -2)] +) +def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + NullConnectionPool(min_size=min_size, max_size=max_size) + + +def test_connection_class(dsn): + class MyConn(psycopg.Connection[Any]): + pass + + with NullConnectionPool(dsn, connection_class=MyConn) as p: + with p.connection() as conn: + assert isinstance(conn, MyConn) + + +def test_kwargs(dsn): + with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + with p.connection() as conn: + assert conn.autocommit + + +def test_its_no_pool_at_all(dsn): + with NullConnectionPool(dsn, max_size=2) as p: + with p.connection() as conn: + with conn.execute("select pg_backend_pid()") as cur: + (pid1,) = cur.fetchone() # type: ignore[misc] + + with p.connection() as conn2: + with conn2.execute("select pg_backend_pid()") as cur: + (pid2,) = cur.fetchone() # type: ignore[misc] + + with p.connection() as conn: + assert conn.pgconn.backend_pid not in (pid1, pid2) + + +def test_context(dsn): + with NullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.1) + + with NullConnectionPool(dsn, num_workers=1) as p: + p.wait(0.4) + + +def test_wait_closed(dsn): + with NullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + p.wait() + + +@pytest.mark.slow +def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + p.wait(0.2) + + with NullConnectionPool(proxy.client_dsn, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() + + with p.connection() as conn: + conn.execute("select 1") + + +def test_configure(dsn): + inits = 0 + + def configure(conn): + nonlocal inits + inits += 1 + with conn.transaction(): + conn.execute("set default_transaction_read_only to on") + + with NullConnectionPool(dsn, configure=configure) as p: + with p.connection() as conn: + assert inits == 1 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + with p.connection() as conn: + assert inits == 2 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + conn.close() + + with p.connection() as conn: + assert inits == 3 + res = conn.execute("show default_transaction_read_only") + assert res.fetchone()[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + conn.execute("select 1") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def configure(conn): + with conn.transaction(): + conn.execute("WAT") + + with NullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + pids = [] + + def worker(): + with p.connection() as conn: + assert resets == 1 + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + pids.append(conn.pgconn.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = Thread(target=worker) + t.start() + + assert resets == 0 + conn.execute("set timezone to '+2:00'") + pids.append(conn.pgconn.backend_pid) + + t.join() + p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + conn.execute("reset all") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + + conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + pids = [] + + def worker(): + with p.connection() as conn: + conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + with NullConnectionPool(dsn, max_size=1, reset=reset) as p: + with p.connection() as conn: + + t = Thread(target=worker) + t.start() + + conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + t.join() + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +def test_no_queue_timeout(deaf_port): + with NullConnectionPool( + kwargs={"host": "localhost", "port": deaf_port} + ) as p: + with pytest.raises(PoolTimeout): + with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +def test_queue(dsn, retries): + def worker(n): + t0 = time() + with p.connection() as conn: + (pid,) = conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ).fetchone() # type: ignore[misc] + t1 = time() + results.append((n, t1 - t0, pid)) + + for retry in retries: + with retry: + results: List[Tuple[int, float, int]] = [] + with NullConnectionPool(dsn, max_size=2) as p: + p.wait() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + for t in ts: + t.start() + for t in ts: + t.join() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +def test_queue_size(dsn): + def worker(t, ev=None): + try: + with p.connection(): + if ev: + ev.set() + sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + p.wait() + ev = Event() + t = Thread(target=worker, args=(0.3, ev)) + t.start() + ev.wait() + + ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +def test_queue_timeout(dsn, retries): + def worker(n): + t0 = time() + try: + with p.connection() as conn: + (pid,) = conn.execute( # type: ignore[misc] + "select pg_backend_pid() from pg_sleep(0.2)" + ).fetchone() + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + for retry in retries: + with retry: + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +def test_dead_client(dsn): + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + results: List[int] = [] + + with NullConnectionPool(dsn, max_size=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + for t in ts: + t.start() + for t in ts: + t.join() + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +def test_queue_timeout_override(dsn, retries): + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout=timeout) as conn: + (pid,) = conn.execute( # type: ignore[misc] + "select pg_backend_pid() from pg_sleep(0.2)" + ).fetchone() + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + for retry in retries: + with retry: + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + for t in ts: + t.start() + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +def test_broken_reconnect(dsn): + with NullConnectionPool(dsn, max_size=1) as p: + with p.connection() as conn: + with conn.execute("select pg_backend_pid()") as cur: + (pid1,) = cur.fetchone() # type: ignore[misc] + conn.close() + + with p.connection() as conn2: + with conn2.execute("select pg_backend_pid()") as cur: + (pid2,) = cur.fetchone() # type: ignore[misc] + + assert pid1 != pid2 + + +def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + assert not conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + + pids.append(conn.pgconn.backend_pid) + conn.execute("create table test_intrans_rollback ()") + assert conn.pgconn.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = Thread(target=worker) + t.start() + + pids.append(conn.pgconn.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(): + with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + t = Thread(target=worker) + t.start() + + pids.append(conn.pgconn.backend_pid) + cur = conn.cursor() + with cur.copy("copy (select * from generate_series(1, 10)) to stdout"): + pass + assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + + with NullConnectionPool(dsn, max_size=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + t = Thread(target=worker, args=(p,)) + t.start() + + pids.append(conn.pgconn.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + t.join() + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +def test_close_no_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner and p._sched_runner.is_alive() + workers = p._workers[:] + assert workers + for t in workers: + assert t.is_alive() + + p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert not t.is_alive() + + +def test_putconn_no_pool(dsn): + with NullConnectionPool(dsn) as p: + conn = psycopg.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) + + conn.close() + + +def test_putconn_wrong_pool(dsn): + with NullConnectionPool(dsn) as p1: + with NullConnectionPool(dsn) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) + + +@pytest.mark.slow +def test_del_stop_threads(dsn): + p = NullConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + sleep(0.1) + for t in ts: + assert not t.is_alive() + + +def test_closed_getconn(dsn): + p = NullConnectionPool(dsn) + assert not p.closed + with p.connection(): + pass + + p.close() + assert p.closed + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + +def test_closed_putconn(dsn): + p = NullConnectionPool(dsn) + + with p.connection() as conn: + pass + assert conn.closed + + with p.connection() as conn: + p.close() + assert conn.closed + + +def test_closed_queue(dsn): + def w1(): + with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + def w2(): + try: + with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = Event() + e2 = Event() + + p = NullConnectionPool(dsn, max_size=1) + p.wait() + success: List[str] = [] + + t1 = Thread(target=w1) + t1.start() + # Wait until w1 has received a connection + e1.wait() + + t2 = Thread(target=w2) + t2.start() + # Wait until w2 is in the queue + while not p._waiting: + sleep(0) + + p.close(0) + + # Wait for the workers to finish + e2.set() + t1.join() + t2.join() + assert len(success) == 2 + + +def test_open_explicit(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed, match="is not open yet"): + p.getconn() + + with pytest.raises(PoolClosed): + with p.connection(): + pass + + p.open() + try: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + p.getconn() + + +def test_open_context(dsn): + p = NullConnectionPool(dsn, open=False) + assert p.closed + + with p: + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + assert p.closed + + +def test_open_no_op(dsn): + p = NullConnectionPool(dsn) + try: + assert not p.closed + p.open() + assert not p.closed + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + finally: + p.close() + + +def test_reopen(dsn): + p = NullConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + p.open() + + +@pytest.mark.parametrize( + "min_size, max_size", [(1, None), (-1, None), (0, -2)] +) +def test_bad_resize(dsn, min_size, max_size): + with NullConnectionPool() as p: + with pytest.raises(ValueError): + p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +def test_max_lifetime(dsn): + pids = [] + + def worker(p): + with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + sleep(0.1) + + ts = [] + with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + for i in range(5): + ts.append(Thread(target=worker, args=(p,))) + ts[-1].start() + + for t in ts: + t.join() + + assert pids[0] == pids[1] != pids[4], pids + + +def test_check(dsn): + with NullConnectionPool(dsn) as p: + # No-op + p.check() + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_measures(dsn): + def worker(n): + with p.connection() as conn: + conn.execute("select pg_sleep(0.2)") + + with NullConnectionPool(dsn, max_size=4) as p: + p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [Thread(target=worker, args=(i,)) for i in range(3)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + p.wait(2.0) + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + sleep(0.1) + stats = p.get_stats() + for t in ts: + t.join() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +def test_stats_usage(dsn, retries): + def worker(n): + try: + with p.connection(timeout=0.3) as conn: + conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + for retry in retries: + with retry: + with NullConnectionPool(dsn, max_size=3) as p: + p.wait(2.0) + + ts = [Thread(target=worker, args=(i,)) for i in range(7)] + for t in ts: + t.start() + for t in ts: + t.join() + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + with p.connection() as conn: + conn.close() + p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + with NullConnectionPool(proxy.client_dsn, max_size=3) as p: + p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py new file mode 100644 index 000000000..9175307e7 --- /dev/null +++ b/tests/pool/test_null_pool_async.py @@ -0,0 +1,864 @@ +import sys +import asyncio +import logging +from time import time +from typing import Any, List, Tuple + +import pytest +from packaging.version import parse as ver # noqa: F401 # used in skipif + +import psycopg +from psycopg.pq import TransactionStatus +from psycopg._compat import create_task +from .test_pool_async import delay_connection + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.skipif( + sys.version_info < (3, 7), + reason="async pool not supported before Python 3.7", + ), +] + +try: + from psycopg_pool import AsyncNullConnectionPool # noqa: F401 + from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests +except ImportError: + pass + + +async def test_defaults(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert p.min_size == p.max_size == 0 + assert p.timeout == 30 + assert p.max_idle == 10 * 60 + assert p.max_lifetime == 60 * 60 + assert p.num_workers == 3 + + +async def test_min_size_max_size(dsn): + async with AsyncNullConnectionPool(dsn, min_size=0, max_size=2) as p: + assert p.min_size == 0 + assert p.max_size == 2 + + +@pytest.mark.parametrize( + "min_size, max_size", [(1, None), (-1, None), (0, -2)] +) +async def test_bad_size(dsn, min_size, max_size): + with pytest.raises(ValueError): + AsyncNullConnectionPool(min_size=min_size, max_size=max_size) + + +async def test_connection_class(dsn): + class MyConn(psycopg.AsyncConnection[Any]): + pass + + async with AsyncNullConnectionPool(dsn, connection_class=MyConn) as p: + async with p.connection() as conn: + assert isinstance(conn, MyConn) + + +async def test_kwargs(dsn): + async with AsyncNullConnectionPool(dsn, kwargs={"autocommit": True}) as p: + async with p.connection() as conn: + assert conn.autocommit + + +async def test_its_no_pool_at_all(dsn): + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + async with p.connection() as conn: + cur = await conn.execute("select pg_backend_pid()") + (pid1,) = await cur.fetchone() # type: ignore[misc] + + async with p.connection() as conn2: + cur = await conn2.execute("select pg_backend_pid()") + (pid2,) = await cur.fetchone() # type: ignore[misc] + + async with p.connection() as conn: + assert conn.pgconn.backend_pid not in (pid1, pid2) + + +async def test_context(dsn): + async with AsyncNullConnectionPool(dsn) as p: + assert not p.closed + assert p.closed + + +@pytest.mark.slow +@pytest.mark.timing +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.2) + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.1) + + async with AsyncNullConnectionPool(dsn, num_workers=1) as p: + await p.wait(0.4) + + +async def test_wait_closed(dsn): + async with AsyncNullConnectionPool(dsn) as p: + pass + + with pytest.raises(PoolClosed): + await p.wait() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(PoolTimeout): + async with AsyncNullConnectionPool( + proxy.client_dsn, num_workers=1 + ) as p: + await p.wait(0.2) + + async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + +async def test_configure(dsn): + inits = 0 + + async def configure(conn): + nonlocal inits + inits += 1 + async with conn.transaction(): + await conn.execute("set default_transaction_read_only to on") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + async with p.connection() as conn: + assert inits == 1 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + async with p.connection() as conn: + assert inits == 2 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + await conn.close() + + async with p.connection() as conn: + assert inits == 3 + res = await conn.execute("show default_transaction_read_only") + assert (await res.fetchone())[0] == "on" # type: ignore[index] + + +@pytest.mark.slow +async def test_configure_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + await conn.execute("select 1") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +@pytest.mark.slow +async def test_configure_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def configure(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with AsyncNullConnectionPool(dsn, configure=configure) as p: + with pytest.raises(PoolTimeout): + await p.wait(timeout=0.5) + + assert caplog.records + assert "WAT" in caplog.records[0].message + + +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + pids = [] + + async def worker(): + async with p.connection() as conn: + assert resets == 1 + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + pids.append(conn.pgconn.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + # Queue the worker so it will take the same connection a second time + # instead of making a new one. + t = create_task(worker()) + + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + pids.append(conn.pgconn.backend_pid) + + await asyncio.gather(t) + await p.wait() + + assert resets == 1 + assert pids[0] == pids[1] + + +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + await conn.execute("reset all") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + + await conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + pids = [] + + async def worker(): + async with p.connection() as conn: + await conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + async with AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p: + async with p.connection() as conn: + + t = create_task(worker()) + + await conn.execute("select 1") + pids.append(conn.pgconn.backend_pid) + + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert caplog.records + assert "WAT" in caplog.records[0].message + + +@pytest.mark.slow +@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')") +async def test_no_queue_timeout(deaf_port): + async with AsyncNullConnectionPool( + kwargs={"host": "localhost", "port": deaf_port} + ) as p: + with pytest.raises(PoolTimeout): + async with p.connection(timeout=1): + pass + + +@pytest.mark.slow +@pytest.mark.timing +async def test_queue(dsn, retries): + async def worker(n): + t0 = time() + async with p.connection() as conn: + cur = await conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ) + (pid,) = await cur.fetchone() # type: ignore[misc] + t1 = time() + results.append((n, t1 - t0, pid)) + + async for retry in retries: + with retry: + results: List[Tuple[int, float, int]] = [] + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + await p.wait() + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_size(dsn): + async def worker(t, ev=None): + try: + async with p.connection(): + if ev: + ev.set() + await asyncio.sleep(t) + except TooManyRequests as e: + errors.append(e) + else: + success.append(True) + + errors: List[Exception] = [] + success: List[bool] = [] + + async with AsyncNullConnectionPool(dsn, max_size=1, max_waiting=3) as p: + await p.wait() + ev = asyncio.Event() + create_task(worker(0.3, ev)) + await ev.wait() + + ts = [create_task(worker(0.1)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(success) == 4 + assert len(errors) == 1 + assert isinstance(errors[0], TooManyRequests) + assert p.name in str(errors[0]) + assert str(p.max_waiting) in str(errors[0]) + assert p.get_stats()["requests_errors"] == 1 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_queue_timeout(dsn, retries): + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + cur = await conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ) + (pid,) = await cur.fetchone() # type: ignore[misc] + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + async for retry in retries: + with retry: + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool( + dsn, max_size=2, timeout=0.1 + ) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_dead_client(dsn): + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except PoolTimeout: + if timeout > 0.2: + raise + + async with AsyncNullConnectionPool(dsn, max_size=2) as p: + results: List[int] = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_queue_timeout_override(dsn, retries): + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + cur = await conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ) + (pid,) = await cur.fetchone() # type: ignore[misc] + except PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + async for retry in retries: + with retry: + results: List[Tuple[int, float, int]] = [] + errors: List[Tuple[int, float, Exception]] = [] + + async with AsyncNullConnectionPool( + dsn, max_size=2, timeout=0.1 + ) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +async def test_broken_reconnect(dsn): + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + async with p.connection() as conn: + cur = await conn.execute("select pg_backend_pid()") + (pid1,) = await cur.fetchone() # type: ignore[misc] + await conn.close() + + async with p.connection() as conn2: + cur = await conn2.execute("select pg_backend_pid()") + (pid2,) = await cur.fetchone() # type: ignore[misc] + + assert pid1 != pid2 + + +async def test_intrans_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + cur = await conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + # Queue the worker so it will take the connection a second time instead + # of making a new one. + t = create_task(worker()) + + pids.append(conn.pgconn.backend_pid) + await conn.execute("create table test_intrans_rollback ()") + assert conn.pgconn.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INTRANS" in caplog.records[0].message + + +async def test_inerror_rollback(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + + pids.append(conn.pgconn.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] == pids[1] + assert len(caplog.records) == 1 + assert "INERROR" in caplog.records[0].message + + +async def test_active_close(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + + t = create_task(worker()) + + pids.append(conn.pgconn.backend_pid) + cur = conn.cursor() + async with cur.copy( + "copy (select * from generate_series(1, 10)) to stdout" + ): + pass + assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 2 + assert "ACTIVE" in caplog.records[0].message + assert "BAD" in caplog.records[1].message + + +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + caplog.set_level(logging.WARNING, logger="psycopg.pool") + pids = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + assert conn.pgconn.transaction_status == TransactionStatus.IDLE + + async with AsyncNullConnectionPool(dsn, max_size=1) as p: + conn = await p.getconn() + t = create_task(worker()) + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pids.append(conn.pgconn.backend_pid) + with pytest.raises(psycopg.ProgrammingError): + await conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + await asyncio.gather(t) + + assert pids[0] != pids[1] + assert len(caplog.records) == 3 + assert "INERROR" in caplog.records[0].message + assert "OperationalError" in caplog.records[1].message + assert "BAD" in caplog.records[2].message + + +async def test_close_no_tasks(dsn): + p = AsyncNullConnectionPool(dsn) + assert p._sched_runner and not p._sched_runner.done() + assert p._workers + workers = p._workers[:] + for t in workers: + assert not t.done() + + await p.close() + assert p._sched_runner is None + assert not p._workers + for t in workers: + assert t.done() + + +async def test_putconn_no_pool(dsn): + async with AsyncNullConnectionPool(dsn) as p: + conn = await psycopg.AsyncConnection.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + + await conn.close() + + +async def test_putconn_wrong_pool(dsn): + async with AsyncNullConnectionPool(dsn) as p1: + async with AsyncNullConnectionPool(dsn) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + + +async def test_closed_getconn(dsn): + p = AsyncNullConnectionPool(dsn) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = AsyncNullConnectionPool(dsn) + + async with p.connection() as conn: + pass + assert conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +async def test_closed_queue(dsn): + async def w1(): + async with p.connection() as conn: + e1.set() # Tell w0 that w1 got a connection + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + await e2.wait() # Wait until w0 has tested w2 + success.append("w1") + + async def w2(): + try: + async with p.connection(): + pass # unexpected + except PoolClosed: + success.append("w2") + + e1 = asyncio.Event() + e2 = asyncio.Event() + + p = AsyncNullConnectionPool(dsn, max_size=1) + await p.wait() + success: List[str] = [] + + t1 = create_task(w1()) + # Wait until w1 has received a connection + await e1.wait() + + t2 = create_task(w2()) + # Wait until w2 is in the queue + while not p._waiting: + await asyncio.sleep(0) + + await p.close() + + # Wait for the workers to finish + e2.set() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +async def test_open_explicit(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + with pytest.raises(PoolClosed): + await p.getconn() + + with pytest.raises(PoolClosed, match="is not open yet"): + async with p.connection(): + pass + + await p.open() + try: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + with pytest.raises(PoolClosed, match="is already closed"): + await p.getconn() + + +async def test_open_context(dsn): + p = AsyncNullConnectionPool(dsn, open=False) + assert p.closed + + async with p: + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + assert p.closed + + +async def test_open_no_op(dsn): + p = AsyncNullConnectionPool(dsn) + try: + assert not p.closed + await p.open() + assert not p.closed + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + finally: + await p.close() + + +async def test_reopen(dsn): + p = AsyncNullConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + + with pytest.raises(psycopg.OperationalError, match="cannot be reused"): + await p.open() + + +@pytest.mark.parametrize( + "min_size, max_size", [(1, None), (-1, None), (0, -2)] +) +async def test_bad_resize(dsn, min_size, max_size): + async with AsyncNullConnectionPool() as p: + with pytest.raises(ValueError): + await p.resize(min_size=min_size, max_size=max_size) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_max_lifetime(dsn): + pids: List[int] = [] + + async def worker(): + async with p.connection() as conn: + pids.append(conn.pgconn.backend_pid) + await asyncio.sleep(0.1) + + async with AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p: + ts = [create_task(worker()) for i in range(5)] + await asyncio.gather(*ts) + + assert pids[0] == pids[1] != pids[4], pids + + +async def test_check(dsn): + # no.op + async with AsyncNullConnectionPool(dsn) as p: + await p.check() + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_measures(dsn): + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.2)") + + async with AsyncNullConnectionPool(dsn, max_size=4) as p: + await p.wait(2.0) + + stats = p.get_stats() + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 0 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + ts = [create_task(worker(i)) for i in range(3)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 3 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 0 + + await p.wait(2.0) + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.sleep(0.1) + stats = p.get_stats() + await asyncio.gather(*ts) + assert stats["pool_min"] == 0 + assert stats["pool_max"] == 4 + assert stats["pool_size"] == 4 + assert stats["pool_available"] == 0 + assert stats["requests_waiting"] == 3 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_stats_usage(dsn, retries): + async def worker(n): + try: + async with p.connection(timeout=0.3) as conn: + await conn.execute("select pg_sleep(0.2)") + except PoolTimeout: + pass + + async for retry in retries: + with retry: + async with AsyncNullConnectionPool(dsn, max_size=3) as p: + await p.wait(2.0) + + ts = [create_task(worker(i)) for i in range(7)] + await asyncio.gather(*ts) + stats = p.get_stats() + assert stats["requests_num"] == 7 + assert stats["requests_queued"] == 4 + assert 850 <= stats["requests_wait_ms"] <= 950 + assert stats["requests_errors"] == 1 + assert 1150 <= stats["usage_ms"] <= 1250 + assert stats.get("returns_bad", 0) == 0 + + async with p.connection() as conn: + await conn.close() + await p.wait() + stats = p.pop_stats() + assert stats["requests_num"] == 8 + assert stats["returns_bad"] == 1 + async with p.connection(): + pass + assert p.get_stats()["requests_num"] == 1 + + +@pytest.mark.slow +async def test_stats_connect(dsn, proxy, monkeypatch): + proxy.start() + delay_connection(monkeypatch, 0.2) + async with AsyncNullConnectionPool(proxy.client_dsn, max_size=3) as p: + await p.wait() + stats = p.get_stats() + assert stats["connections_num"] == 1 + assert stats.get("connections_errors", 0) == 0 + assert stats.get("connections_lost", 0) == 0 + assert 200 <= stats["connections_ms"] < 300 diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index c24759a31..17cac2bcb 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -488,7 +488,7 @@ def test_intrans_rollback(dsn, caplog): with p.connection() as conn2: assert conn2.pgconn.backend_pid == pid assert conn2.pgconn.transaction_status == TransactionStatus.IDLE - assert not conn.execute( + assert not conn2.execute( "select 1 from pg_class where relname = 'test_intrans_rollback'" ).fetchone()