From: Daniele Varrazzo Date: Sat, 27 Feb 2021 01:09:44 +0000 (+0100) Subject: Add async connection pool X-Git-Tag: 3.0.dev0~87^2~42 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=76a85c5595aefb8bc2cddc5a4e165b8a42c161a1;p=thirdparty%2Fpsycopg.git Add async connection pool --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 8ed35f7f9..3997061e3 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -46,7 +46,7 @@ execute: Callable[["PGconn"], PQGen[List["PGresult"]]] if TYPE_CHECKING: from .pq.proto import PGconn, PGresult - from .pool import ConnectionPool + from .pool.base import BasePool if pq.__impl__ == "c": from psycopg3_c import _psycopg3 @@ -127,7 +127,7 @@ class BaseConnection(AdaptContext): # Attribute is only set if the connection is from a pool so we can tell # apart a connection in the pool too (when _pool = None) - self._pool: Optional["ConnectionPool"] + self._pool: Optional["BasePool[Any]"] def __del__(self) -> None: # If fails on connection we might not have this attribute yet @@ -644,7 +644,9 @@ class AsyncConnection(BaseConnection): else: await self.commit() - await self.close() + # Close the connection only if it doesn't belong to a pool. + if not getattr(self, "_pool", None): + await self.close() async def close(self) -> None: self.pgconn.finish() diff --git a/psycopg3/psycopg3/pool/__init__.py b/psycopg3/psycopg3/pool/__init__.py index 327dcfc86..91f349693 100644 --- a/psycopg3/psycopg3/pool/__init__.py +++ b/psycopg3/psycopg3/pool/__init__.py @@ -5,6 +5,12 @@ psycopg3 connection pool package # Copyright (C) 2021 The Psycopg Team from .pool import ConnectionPool +from .async_pool import AsyncConnectionPool from .errors import PoolClosed, PoolTimeout -__all__ = ["ConnectionPool", "PoolClosed", "PoolTimeout"] +__all__ = [ + "AsyncConnectionPool", + "ConnectionPool", + "PoolClosed", + "PoolTimeout", +] diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py new file mode 100644 index 000000000..528ffa4ad --- /dev/null +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -0,0 +1,444 @@ +""" +psycopg3 synchronous connection pool +""" + +# Copyright (C) 2021 The Psycopg Team + +import asyncio +import logging +from time import monotonic +from typing import Any, Awaitable, Callable, Deque, AsyncIterator, Optional +from contextlib import asynccontextmanager +from collections import deque + +from ..pq import TransactionStatus +from ..connection import AsyncConnection + +from . import tasks +from .base import ConnectionAttempt, BasePool +from .errors import PoolClosed, PoolTimeout + +logger = logging.getLogger(__name__) + + +class AsyncConnectionPool(BasePool[AsyncConnection]): + def __init__( + self, + conninfo: str = "", + configure: Optional[ + Callable[[AsyncConnection], Awaitable[None]] + ] = None, + **kwargs: Any, + ): + self._configure = configure + + self._lock = asyncio.Lock() + self._waiting: Deque["AsyncClient"] = deque() + + # to notify that the pool is full + self._pool_full_event: Optional[asyncio.Event] = None + + self.loop = asyncio.get_event_loop() + + super().__init__(conninfo, **kwargs) + + async def wait_ready(self, timeout: float = 30.0) -> None: + """ + Wait for the pool to be full after init. + + Raise `PoolTimeout` if not ready within *timeout* sec. + """ + async with self._lock: + assert not self._pool_full_event + if len(self._pool) >= self._nconns: + return + self._pool_full_event = asyncio.Event() + + try: + await asyncio.wait_for(self._pool_full_event.wait(), timeout) + except asyncio.TimeoutError: + await self.close() # stop all the threads + raise PoolTimeout( + f"pool initialization incomplete after {timeout} sec" + ) + + async with self._lock: + self._pool_full_event = None + + @asynccontextmanager + async def connection( + self, timeout: Optional[float] = None + ) -> AsyncIterator[AsyncConnection]: + """Context manager to obtain a connection from the pool. + + Returned the connection immediately if available, otherwise wait up to + *timeout* or `self.timeout` and throw `PoolTimeout` if a connection is + not available in time. + + Upon context exit, return the connection to the pool. Apply the normal + connection context behaviour (commit/rollback the transaction in case + of success/error). If the connection is no more in working state + replace it with a new one. + """ + conn = await self.getconn(timeout=timeout) + try: + async with conn: + yield conn + finally: + await self.putconn(conn) + + async def getconn( + self, timeout: Optional[float] = None + ) -> AsyncConnection: + """Obtain a contection from the pool. + + You should preferrably use `connection()`. Use this function only if + it is not possible to use the connection as context manager. + + After using this function you *must* call a corresponding `putconn()`: + failing to do so will deplete the pool. A depleted pool is a sad pool: + you don't want a depleted pool. + """ + logger.info("connection requested to %r", self.name) + # Critical section: decide here if there's a connection ready + # or if the client needs to wait. + async with self._lock: + if self._closed: + raise PoolClosed(f"the pool {self.name!r} is closed") + + pos: Optional[AsyncClient] = None + if self._pool: + # Take a connection ready out of the pool + conn = self._pool.popleft() + if len(self._pool) < self._nconns_min: + self._nconns_min = len(self._pool) + else: + # No connection available: put the client in the waiting queue + pos = AsyncClient() + self._waiting.append(pos) + + # If there is space for the pool to grow, let's do it + if self._nconns < self.maxconn: + self._nconns += 1 + logger.info( + "growing pool %r to %s", self.name, self._nconns + ) + self.run_task(tasks.AddConnection(self)) + + # If we are in the waiting queue, wait to be assigned a connection + # (outside the critical section, so only the waiting client is locked) + if pos: + if timeout is None: + timeout = self.timeout + conn = await pos.wait(timeout=timeout) + + # Tell the connection it belongs to a pool to avoid closing on __exit__ + # Note that this property shouldn't be set while the connection is in + # the pool, to avoid to create a reference loop. + conn._pool = self + logger.info("connection given by %r", self.name) + return conn + + async def putconn(self, conn: AsyncConnection) -> None: + """Return a connection to the loving hands of its pool. + + Use this function only paired with a `getconn()`. You don't need to use + it if you use the much more comfortable `connection()` context manager. + """ + # Quick check to discard the wrong connection + pool = getattr(conn, "_pool", None) + if pool is not self: + if pool: + msg = f"it comes from pool {pool.name!r}" + else: + msg = "it doesn't come from any pool" + raise ValueError( + f"can't return connection to pool {self.name!r}, {msg}: {conn}" + ) + + logger.info("returning connection to %r", self.name) + + # If the pool is closed just close the connection instead of returning + # it to the pool. For extra refcare remove the pool reference from it. + if self._closed: + conn._pool = None + await conn.close() + return + + # Use a worker to perform eventual maintenance work in a separate thread + self.run_task(tasks.ReturnConnection(self, conn)) + + async def close(self, timeout: float = 1.0) -> None: + """Close the pool and make it unavailable to new clients. + + All the waiting and future client will fail to acquire a connection + with a `PoolClosed` exception. Currently used connections will not be + closed until returned to the pool. + + Wait *timeout* for threads to terminate their job, if positive. + """ + if self._closed: + return + + async with self._lock: + self._closed = True + logger.debug("pool %r closed", self.name) + + # Take waiting client and pool connections out of the state + waiting = list(self._waiting) + self._waiting.clear() + pool = list(self._pool) + self._pool.clear() + + # Now that the flag _closed is set, getconn will fail immediately, + # putconn will just close the returned connection. + + # Stop the scheduler + self._sched.enter(0, None) + + # Stop the worker threads + for i in range(len(self._workers)): + self.run_task(tasks.StopWorker(self)) + + # Signal to eventual clients in the queue that business is closed. + for pos in waiting: + await pos.fail(PoolClosed(f"the pool {self.name!r} is closed")) + + # Close the connections still in the pool + for conn in pool: + await conn.close() + + # Wait for the worker threads to terminate + if timeout > 0: + loop = asyncio.get_running_loop() + for t in [self._sched_runner] + self._workers: + if not t.is_alive(): + continue + await loop.run_in_executor(None, lambda: t.join(timeout)) + if t.is_alive(): + logger.warning( + "couldn't stop thread %s in pool %r within %s seconds", + t, + self.name, + timeout, + ) + + async def configure(self, conn: AsyncConnection) -> None: + """Configure a connection after creation.""" + if self._configure: + await self._configure(conn) + + def reconnect_failed(self) -> None: + """ + Called when reconnection failed for longer than `reconnect_timeout`. + """ + self._reconnect_failed(self) + + async def _connect(self) -> AsyncConnection: + """Return a new connection configured for the pool.""" + conn = await AsyncConnection.connect(self.conninfo, **self.kwargs) + await self.configure(conn) + conn._pool = self + return conn + + async def _add_connection( + self, attempt: Optional[ConnectionAttempt] + ) -> None: + """Try to connect and add the connection to the pool. + + If failed, reschedule a new attempt in the future for a few times, then + give up, decrease the pool connections number and call + `self.reconnect_failed()`. + + """ + now = monotonic() + if not attempt: + attempt = ConnectionAttempt( + reconnect_timeout=self.reconnect_timeout + ) + + try: + conn = await self._connect() + except Exception as e: + logger.warning(f"error connecting in {self.name!r}: {e}") + if attempt.time_to_give_up(now): + logger.warning( + "reconnection attempt in pool %r failed after %s sec", + self.name, + self.reconnect_timeout, + ) + async with self._lock: + self._nconns -= 1 + self.reconnect_failed() + else: + attempt.update_delay(now) + self.schedule_task( + tasks.AddConnection(self, attempt), attempt.delay + ) + else: + await self._add_to_pool(conn) + + async def _return_connection(self, conn: AsyncConnection) -> None: + """ + Return a connection to the pool after usage. + """ + await self._reset_connection(conn) + if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + # Connection no more in working state: create a new one. + logger.warning("discarding closed connection: %s", conn) + self.run_task(tasks.AddConnection(self)) + else: + await self._add_to_pool(conn) + + async def _add_to_pool(self, conn: AsyncConnection) -> None: + """ + Add a connection to the pool. + + The connection can be a fresh one or one already used in the pool. + + If a client is already waiting for a connection pass it on, otherwise + put it back into the pool + """ + # Remove the pool reference from the connection before returning it + # to the state, to avoid to create a reference loop. + # Also disable the warning for open connection in conn.__del__ + conn._pool = None + + pos: Optional[AsyncClient] = None + + # Critical section: if there is a client waiting give it the connection + # otherwise put it back into the pool. + async with self._lock: + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. + pos = self._waiting.popleft() + if await pos.set(conn): + break + else: + # No client waiting for a connection: put it back into the pool + self._pool.append(conn) + + # If we have been asked to wait for pool init, notify the + # waiter if the pool is full. + if self._pool_full_event and len(self._pool) >= self._nconns: + self._pool_full_event.set() + + async def _reset_connection(self, conn: AsyncConnection) -> None: + """ + Bring a connection to IDLE state or close it. + """ + status = conn.pgconn.transaction_status + if status == TransactionStatus.IDLE: + return + + if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): + # Connection returned with an active transaction + logger.warning("rolling back returned connection: %s", conn) + try: + await conn.rollback() + except Exception as e: + logger.warning( + "rollback failed: %s: %s. Discarding connection %s", + e.__class__.__name__, + e, + conn, + ) + await conn.close() + + elif status == TransactionStatus.ACTIVE: + # Connection returned during an operation. Bad... just close it. + logger.warning("closing returned connection: %s", conn) + await conn.close() + + async def _shrink_pool(self) -> None: + to_close: Optional[AsyncConnection] = None + + async with self._lock: + # Reset the min number of connections used + nconns_min = self._nconns_min + self._nconns_min = len(self._pool) + + # If the pool can shrink and connections were unused, drop one + if self._nconns > self.minconn and nconns_min > 0: + to_close = self._pool.popleft() + self._nconns -= 1 + + if to_close: + logger.info( + "shrinking pool %r to %s because %s unused connections" + " in the last %s sec", + self.name, + self._nconns, + nconns_min, + self.max_idle, + ) + await to_close.close() + + +class AsyncClient: + """A position in a queue for a client waiting for a connection.""" + + __slots__ = ("conn", "error", "_cond") + + def __init__(self) -> None: + self.conn: Optional[AsyncConnection] = None + self.error: Optional[Exception] = None + + # The AsyncClient behaves in a way similar to an Event, but we need + # to notify reliably the flagger that the waiter has "accepted" the + # message and it hasn't timed out yet, otherwise the pool may give a + # connection to a client that has already timed out getconn(), which + # will be lost. + self._cond = asyncio.Condition() + + async def wait(self, timeout: float) -> AsyncConnection: + """Wait for a connection to be set and return it. + + Raise an exception if the wait times out or if fail() is called. + """ + async with self._cond: + if not (self.conn or self.error): + try: + await asyncio.wait_for(self._cond.wait(), timeout) + except asyncio.TimeoutError: + self.error = PoolTimeout( + f"couldn't get a connection after {timeout} sec" + ) + + if self.conn: + return self.conn + else: + assert self.error + raise self.error + + async def set(self, conn: AsyncConnection) -> bool: + """Signal the client waiting that a connection is ready. + + Return True if the client has "accepted" the connection, False + otherwise (typically because wait() has timed out). + """ + async with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + self._cond.notify_all() + return True + + async def fail(self, error: Exception) -> bool: + """Signal the client that, alas, they won't have a connection today. + + Return True if the client has "accepted" the error, False otherwise + (typically because wait() has timed out). + """ + async with self._cond: + if self.conn or self.error: + return False + + self.error = error + self._cond.notify_all() + return True + + +tasks.AsyncConnectionPool = AsyncConnectionPool # type: ignore diff --git a/psycopg3/psycopg3/pool/base.py b/psycopg3/psycopg3/pool/base.py index 4d279cb87..0f74ab597 100644 --- a/psycopg3/psycopg3/pool/base.py +++ b/psycopg3/psycopg3/pool/base.py @@ -31,7 +31,6 @@ class BasePool(Generic[ConnectionType]): conninfo: str = "", *, kwargs: Optional[Dict[str, Any]] = None, - configure: Optional[Callable[[ConnectionType], None]] = None, minconn: int = 4, maxconn: Optional[int] = None, name: Optional[str] = None, @@ -52,15 +51,13 @@ class BasePool(Generic[ConnectionType]): ) if not name: num = BasePool._num_pool = BasePool._num_pool + 1 - name = f"pool-{num - 1}" + name = f"pool-{num}" if num_workers < 1: raise ValueError("num_workers must be at least 1") self.conninfo = conninfo self.kwargs: Dict[str, Any] = kwargs or {} - self._configure: Callable[[ConnectionType], None] - self._configure = configure or (lambda conn: None) self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None] self._reconnect_failed = reconnect_failed or (lambda pool: None) self.name = name @@ -86,12 +83,15 @@ class BasePool(Generic[ConnectionType]): self._workers: List[threading.Thread] = [] for i in range(num_workers): t = threading.Thread( - target=self.worker, args=(self._tasks,), daemon=True + target=self.worker, + args=(self._tasks,), + name=f"{self.name}-worker-{i}", + daemon=True, ) self._workers.append(t) self._sched_runner = threading.Thread( - target=self._sched.run, daemon=True + target=self._sched.run, name=f"{self.name}-scheduler", daemon=True ) # _close should be the last property to be set in the state @@ -145,7 +145,7 @@ class BasePool(Generic[ConnectionType]): self._tasks.put_nowait(task) def schedule_task( - self, task: tasks.MaintenanceTask[ConnectionType], delay: float + self, task: tasks.MaintenanceTask[Any], delay: float ) -> None: """Run a maintenance task in a worker thread in the future.""" self._sched.enter(delay, task.tick) @@ -162,7 +162,7 @@ class BasePool(Generic[ConnectionType]): # Don't make all the workers time out at the same moment timeout = WORKER_TIMEOUT * (0.9 + 0.1 * random.random()) while True: - # Use a timeout to make the wait unterruptable + # Use a timeout to make the wait interruptable try: task = q.get(timeout=timeout) except Empty: @@ -177,6 +177,10 @@ class BasePool(Generic[ConnectionType]): ) if isinstance(task, tasks.StopWorker): + logger.debug( + "terminating working thread %s", + threading.current_thread().name, + ) return diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index dfdd7c8df..b6d0e9257 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -4,10 +4,10 @@ psycopg3 synchronous connection pool # Copyright (C) 2021 The Psycopg Team -import time import logging import threading -from typing import Any, Deque, Iterator, Optional +from time import monotonic +from typing import Any, Callable, Deque, Iterator, Optional from contextlib import contextmanager from collections import deque @@ -22,7 +22,15 @@ logger = logging.getLogger(__name__) class ConnectionPool(BasePool[Connection]): - def __init__(self, conninfo: str = "", **kwargs: Any): + def __init__( + self, + conninfo: str = "", + configure: Optional[Callable[[Connection], None]] = None, + **kwargs: Any, + ): + self._configure: Callable[[Connection], None] + self._configure = configure or (lambda conn: None) + self._lock = threading.RLock() self._waiting: Deque["WaitingClient"] = deque() @@ -167,6 +175,7 @@ class ConnectionPool(BasePool[Connection]): with self._lock: self._closed = True + logger.debug("pool %r closed", self.name) # Take waiting client and pool connections out of the state waiting = list(self._waiting) @@ -231,7 +240,7 @@ class ConnectionPool(BasePool[Connection]): `self.reconnect_failed()`. """ - now = time.monotonic() + now = monotonic() if not attempt: attempt = ConnectionAttempt( reconnect_timeout=self.reconnect_timeout @@ -295,7 +304,6 @@ class ConnectionPool(BasePool[Connection]): pos = self._waiting.popleft() if pos.set(conn): break - else: # No client waiting for a connection: put it back into the pool self._pool.append(conn) @@ -371,7 +379,7 @@ class WaitingClient: # message and it hasn't timed out yet, otherwise the pool may give a # connection to a client that has already timed out getconn(), which # will be lost. - self._cond = threading.Condition(threading.Lock()) + self._cond = threading.Condition() def wait(self, timeout: float) -> Connection: """Wait for a connection to be set and return it. @@ -418,3 +426,6 @@ class WaitingClient: self.error = error self._cond.notify_all() return True + + +tasks.ConnectionPool = ConnectionPool # type: ignore diff --git a/psycopg3/psycopg3/pool/tasks.py b/psycopg3/psycopg3/pool/tasks.py index e9fac5ddc..6028c3a88 100644 --- a/psycopg3/psycopg3/pool/tasks.py +++ b/psycopg3/psycopg3/pool/tasks.py @@ -4,16 +4,24 @@ Maintenance tasks for the connection pools. # Copyright (C) 2021 The Psycopg Team +import asyncio import logging +import threading from abc import ABC, abstractmethod -from typing import Any, Generic, Optional, TYPE_CHECKING +from typing import Any, cast, Generic, Optional, Type, TYPE_CHECKING from weakref import ref from ..proto import ConnectionType +from .. import Connection, AsyncConnection if TYPE_CHECKING: + from .pool import ConnectionPool + from .async_pool import AsyncConnectionPool from .base import BasePool, ConnectionAttempt - from ..connection import Connection +else: + # Injected at pool.py and async_pool.py import + ConnectionPool: "Type[BasePool[Connection]]" + AsyncConnectionPool: "Type[BasePool[AsyncConnection]]" logger = logging.getLogger(__name__) @@ -21,9 +29,16 @@ logger = logging.getLogger(__name__) class MaintenanceTask(ABC, Generic[ConnectionType]): """A task to run asynchronously to maintain the pool state.""" + TIMEOUT = 10.0 + def __init__(self, pool: "BasePool[Any]"): + if isinstance(pool, AsyncConnectionPool): + self.event = threading.Event() + self.pool = ref(pool) - logger.debug("task created: %s", self) + logger.debug( + "task created in %s: %s", threading.current_thread().name, self + ) def __repr__(self) -> str: pool = self.pool() @@ -41,8 +56,20 @@ class MaintenanceTask(ABC, Generic[ConnectionType]): # Pool is no more working. Quietly discard the operation. return - logger.debug("task running: %s", self) - self._run(pool) + logger.debug( + "task running in %s: %s", threading.current_thread().name, self + ) + if isinstance(pool, ConnectionPool): + self._run(pool) + elif isinstance(pool, AsyncConnectionPool): + self.event.clear() + asyncio.run_coroutine_threadsafe(self._run_async(pool), pool.loop) + if not self.event.wait(self.TIMEOUT): + logger.warning( + "event %s didn't terminate after %s sec", self.TIMEOUT + ) + else: + logger.error("%s run got %s instead of a pool", self, pool) def tick(self) -> None: """Run the scheduled task @@ -58,16 +85,23 @@ class MaintenanceTask(ABC, Generic[ConnectionType]): pool.run_task(self) @abstractmethod - def _run(self, pool: "BasePool[Any]") -> None: + def _run(self, pool: "ConnectionPool") -> None: ... + @abstractmethod + async def _run_async(self, pool: "AsyncConnectionPool") -> None: + self.event.set() + class StopWorker(MaintenanceTask[ConnectionType]): """Signal the maintenance thread to terminate.""" - def _run(self, pool: "BasePool[Any]") -> None: + def _run(self, pool: "ConnectionPool") -> None: pass + async def _run_async(self, pool: "AsyncConnectionPool") -> None: + await super()._run_async(pool) + class AddConnection(MaintenanceTask[ConnectionType]): def __init__( @@ -78,29 +112,30 @@ class AddConnection(MaintenanceTask[ConnectionType]): super().__init__(pool) self.attempt = attempt - def _run(self, pool: "BasePool[Any]") -> None: - from . import ConnectionPool + def _run(self, pool: "ConnectionPool") -> None: + pool._add_connection(self.attempt) - if isinstance(pool, ConnectionPool): - pool._add_connection(self.attempt) - else: - assert False + async def _run_async(self, pool: "AsyncConnectionPool") -> None: + logger.debug("run async 1") + await pool._add_connection(self.attempt) + logger.debug("run async 2") + await super()._run_async(pool) + logger.debug("run async 3") class ReturnConnection(MaintenanceTask[ConnectionType]): """Clean up and return a connection to the pool.""" - def __init__(self, pool: "BasePool[Any]", conn: "Connection"): + def __init__(self, pool: "BasePool[Any]", conn: "ConnectionType"): super().__init__(pool) self.conn = conn - def _run(self, pool: "BasePool[Any]") -> None: - from . import ConnectionPool + def _run(self, pool: "ConnectionPool") -> None: + pool._return_connection(cast(Connection, self.conn)) - if isinstance(pool, ConnectionPool): - pool._return_connection(self.conn) - else: - assert False + async def _run_async(self, pool: "AsyncConnectionPool") -> None: + await pool._return_connection(cast(AsyncConnection, self.conn)) + await super()._run_async(pool) class ShrinkPool(MaintenanceTask[ConnectionType]): @@ -110,14 +145,13 @@ class ShrinkPool(MaintenanceTask[ConnectionType]): in the pool. """ - def _run(self, pool: "BasePool[Any]") -> None: + def _run(self, pool: "ConnectionPool") -> None: # Reschedule the task now so that in case of any error we don't lose # the periodic run. pool.schedule_task(self, pool.max_idle) + pool._shrink_pool() - from . import ConnectionPool - - if isinstance(pool, ConnectionPool): - pool._shrink_pool() - else: - assert False + async def _run_async(self, pool: "AsyncConnectionPool") -> None: + pool.schedule_task(self, pool.max_idle) + await pool._shrink_pool() + await super()._run_async(pool) diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 87309b8ee..bb48f77c1 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -524,10 +524,10 @@ def test_shrink(dsn, monkeypatch): def test_reconnect(proxy, caplog, monkeypatch): caplog.set_level(logging.WARNING, logger="psycopg3.pool") - assert pool.pool.ConnectionAttempt.INITIAL_DELAY == 1.0 - assert pool.pool.ConnectionAttempt.DELAY_JITTER == 0.1 - monkeypatch.setattr(pool.pool.ConnectionAttempt, "INITIAL_DELAY", 0.1) - monkeypatch.setattr(pool.pool.ConnectionAttempt, "DELAY_JITTER", 0.0) + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) proxy.start() p = pool.ConnectionPool(proxy.client_dsn, minconn=1) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py new file mode 100644 index 000000000..f407360bc --- /dev/null +++ b/tests/pool/test_pool_async.py @@ -0,0 +1,656 @@ +import sys +import asyncio +import logging +import weakref +from time import time +from collections import Counter + +import pytest + +import psycopg3 +from psycopg3 import pool +from psycopg3.pq import TransactionStatus + +create_task = ( + asyncio.create_task + if sys.version_info >= (3, 7) + else asyncio.ensure_future +) + +pytestmark = pytest.mark.asyncio + + +async def test_defaults(dsn): + p = pool.AsyncConnectionPool(dsn) + assert p.minconn == p.maxconn == 4 + assert p.timeout == 30 + assert p.max_idle == 600 + assert p.num_workers == 3 + await p.close() + + +async def test_minconn_maxconn(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=2) + assert p.minconn == p.maxconn == 2 + await p.close() + + p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4) + assert p.minconn == 2 + assert p.maxconn == 4 + await p.close() + + with pytest.raises(ValueError): + pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2) + + +async def test_kwargs(dsn): + p = pool.AsyncConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) + async with p.connection() as conn: + assert conn.autocommit + + await p.close() + + +async def test_its_really_a_pool(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=2) + async with p.connection() as conn: + cur = await conn.execute("select pg_backend_pid()") + (pid1,) = await cur.fetchone() + + async with p.connection() as conn2: + cur = await conn2.execute("select pg_backend_pid()") + (pid2,) = await cur.fetchone() + + async with p.connection() as conn: + assert conn.pgconn.backend_pid in (pid1, pid2) + + await p.close() + + +async def test_connection_not_lost(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=1) + with pytest.raises(ZeroDivisionError): + async with p.connection() as conn: + pid = conn.pgconn.backend_pid + 1 / 0 + + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + + await p.close() + + +@pytest.mark.slow +async def test_concurrent_filling(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + t0 = time() + times = [] + + add_orig = pool.AsyncConnectionPool._add_to_pool + + async def add_time(self, conn): + times.append(time() - t0) + await add_orig(self, conn) + + monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time) + + p = pool.AsyncConnectionPool(dsn, minconn=5, num_workers=2) + await p.wait_ready(5.0) + want_times = [0.1, 0.1, 0.2, 0.2, 0.3] + assert len(times) == len(want_times) + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + await p.close() + + +@pytest.mark.slow +async def test_wait_ready(dsn, monkeypatch): + delay_connection(monkeypatch, 0.1) + with pytest.raises(pool.PoolTimeout): + p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1) + await p.wait_ready(0.3) + + p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1) + await p.wait_ready(0.5) + await p.close() + p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2) + await p.wait_ready(0.3) + await p.wait_ready(0.0001) # idempotent + await p.close() + + +@pytest.mark.slow +async def test_setup_no_timeout(dsn, proxy): + with pytest.raises(pool.PoolTimeout): + p = pool.AsyncConnectionPool( + proxy.client_dsn, minconn=1, num_workers=1 + ) + await p.wait_ready(0.2) + + p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1, num_workers=1) + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() + + async with p.connection() as conn: + await conn.execute("select 1") + + await p.close() + + +@pytest.mark.slow +async def test_queue(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=2) + results = [] + + async def worker(n): + t0 = time() + async with p.connection() as conn: + cur = await conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ) + (pid,) = await cur.fetchone() + t1 = time() + results.append((n, t1 - t0, pid)) + + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + await p.close() + + times = [item[1] for item in results] + want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + assert len(set(r[2] for r in results)) == 2, results + + +@pytest.mark.slow +async def test_queue_timeout(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) + results = [] + errors = [] + + async def worker(n): + t0 = time() + try: + async with p.connection() as conn: + cur = await conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ) + (pid,) = await cur.fetchone() + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + await p.close() + + +@pytest.mark.slow +async def test_dead_client(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=2) + + results = [] + + async def worker(i, timeout): + try: + async with p.connection(timeout=timeout) as conn: + await conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) + + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + await p.close() + + +@pytest.mark.slow +async def test_queue_timeout_override(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) + results = [] + errors = [] + + async def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + async with p.connection(timeout=timeout) as conn: + cur = await conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ) + (pid,) = await cur.fetchone() + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + await p.close() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +async def test_broken_reconnect(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=1) + with pytest.raises(psycopg3.OperationalError): + async with p.connection() as conn: + cur = await conn.execute("select pg_backend_pid()") + (pid1,) = await cur.fetchone() + await conn.close() + + async with p.connection() as conn2: + cur = await conn2.execute("select pg_backend_pid()") + (pid2,) = await cur.fetchone() + + await p.close() + assert pid1 != pid2 + + +async def test_intrans_rollback(dsn, caplog): + p = pool.AsyncConnectionPool(dsn, minconn=1) + conn = await p.getconn() + pid = conn.pgconn.backend_pid + await conn.execute("create table test_intrans_rollback ()") + assert conn.pgconn.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + cur = await conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() + + await p.close() + recs = [ + r + for r in caplog.records + if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING + ] + assert len(recs) == 1 + assert "INTRANS" in recs[0].message + + +async def test_inerror_rollback(dsn, caplog): + p = pool.AsyncConnectionPool(dsn, minconn=1) + conn = await p.getconn() + pid = conn.pgconn.backend_pid + with pytest.raises(psycopg3.ProgrammingError): + await conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + + recs = [ + r + for r in caplog.records + if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING + ] + assert len(recs) == 1 + assert "INERROR" in recs[0].message + + await p.close() + + +async def test_active_close(dsn, caplog): + p = pool.AsyncConnectionPool(dsn, minconn=1) + conn = await p.getconn() + pid = conn.pgconn.backend_pid + cur = conn.cursor() + async with cur.copy( + "copy (select * from generate_series(1, 10)) to stdout" + ): + pass + assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid != pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + + await p.close() + recs = [ + r + for r in caplog.records + if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING + ] + assert len(recs) == 2 + assert "ACTIVE" in recs[0].message + assert "BAD" in recs[1].message + + +async def test_fail_rollback_close(dsn, caplog, monkeypatch): + p = pool.AsyncConnectionPool(dsn, minconn=1) + conn = await p.getconn() + + # Make the rollback fail + orig_rollback = conn.rollback + + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() + + monkeypatch.setattr(conn, "rollback", bad_rollback) + + pid = conn.pgconn.backend_pid + with pytest.raises(psycopg3.ProgrammingError): + await conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid != pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + + await p.close() + + recs = [ + r + for r in caplog.records + if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING + ] + assert len(recs) == 3 + assert "INERROR" in recs[0].message + assert "OperationalError" in recs[1].message + assert "BAD" in recs[2].message + + +async def test_close_no_threads(dsn): + p = pool.AsyncConnectionPool(dsn) + assert p._sched_runner.is_alive() + for t in p._workers: + assert t.is_alive() + + await p.close() + assert not p._sched_runner.is_alive() + for t in p._workers: + assert not t.is_alive() + + +async def test_putconn_no_pool(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=1) + conn = psycopg3.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) + await p.close() + + +async def test_putconn_wrong_pool(dsn): + p1 = pool.AsyncConnectionPool(dsn, minconn=1) + p2 = pool.AsyncConnectionPool(dsn, minconn=1) + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) + await p1.close() + await p2.close() + + +async def test_del_no_warning(dsn, recwarn): + p = pool.AsyncConnectionPool(dsn, minconn=2) + async with p.connection() as conn: + await conn.execute("select 1") + + await p.wait_ready() + ref = weakref.ref(p) + del p + assert not ref() + assert not recwarn + + +@pytest.mark.slow +async def test_del_stop_threads(dsn): + p = pool.AsyncConnectionPool(dsn) + ts = [p._sched_runner] + p._workers + del p + await asyncio.sleep(0.2) + for t in ts: + assert not t.is_alive() + + +async def test_closed_getconn(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=1) + assert not p.closed + async with p.connection(): + pass + + await p.close() + assert p.closed + + with pytest.raises(pool.PoolClosed): + async with p.connection(): + pass + + +async def test_closed_putconn(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=1) + + async with p.connection() as conn: + pass + assert not conn.closed + + async with p.connection() as conn: + await p.close() + assert conn.closed + + +@pytest.mark.slow +async def test_closed_queue(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=1) + success = [] + + async def w1(): + async with p.connection() as conn: + res = await conn.execute("select 1 from pg_sleep(0.2)") + assert await res.fetchone() == (1,) + success.append("w1") + + async def w2(): + with pytest.raises(pool.PoolClosed): + async with p.connection(): + pass + success.append("w2") + + t1 = create_task(w1()) + await asyncio.sleep(0.1) + t2 = create_task(w2()) + await p.close() + await asyncio.gather(t1, t2) + assert len(success) == 2 + + +@pytest.mark.slow +async def test_grow(dsn, monkeypatch): + p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, num_workers=3) + await p.wait_ready(5.0) + delay_connection(monkeypatch, 0.1) + ts = [] + results = [] + + async def worker(n): + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1 from pg_sleep(0.2)") + t1 = time() + results.append((n, t1 - t0)) + + ts = [create_task(worker(i)) for i in range(6)] + await asyncio.gather(*ts) + await p.close() + + want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4] + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.2), times + + +@pytest.mark.slow +async def test_shrink(dsn, monkeypatch): + + from psycopg3.pool.tasks import ShrinkPool + + orig_run = ShrinkPool._run_async + results = [] + + async def run_async_hacked(self, pool): + n0 = pool._nconns + await orig_run(self, pool) + n1 = pool._nconns + results.append((n0, n1)) + + monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked) + + p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2) + await p.wait_ready(5.0) + assert p.max_idle == 0.2 + + async def worker(n): + async with p.connection() as conn: + await conn.execute("select pg_sleep(0.1)") + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + await asyncio.sleep(1) + await p.close() + assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] + + +@pytest.mark.slow +async def test_reconnect(proxy, caplog, monkeypatch): + assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0 + assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1 + monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1) + monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) + + proxy.start() + p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1) + await p.wait_ready(2.0) + proxy.stop() + + with pytest.raises(psycopg3.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + await asyncio.sleep(1.0) + proxy.start() + await p.wait_ready() + + async with p.connection() as conn: + await conn.execute("select 1") + + await p.close() + + recs = [ + r + for r in caplog.records + if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING + ] + assert "BAD" in recs[0].message + times = [rec.created for rec in recs] + assert times[1] - times[0] < 0.05 + deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)] + assert len(deltas) == 3 + want = 0.1 + for delta in deltas: + assert delta == pytest.approx(want, 0.05), deltas + want *= 2 + + +@pytest.mark.slow +async def test_reconnect_failure(proxy): + proxy.start() + + t1 = None + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + p = pool.AsyncConnectionPool( + proxy.client_dsn, + name="this-one", + minconn=1, + reconnect_timeout=1.0, + reconnect_failed=failed, + ) + await p.wait_ready(2.0) + proxy.stop() + + with pytest.raises(psycopg3.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") + + t0 = time() + await asyncio.sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 + + proxy.start() + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 + await p.close() + + +@pytest.mark.slow +async def test_uniform_use(dsn): + p = pool.AsyncConnectionPool(dsn, minconn=4) + counts = Counter() + for i in range(8): + async with p.connection() as conn: + await asyncio.sleep(0.1) + counts[id(conn)] += 1 + + await p.close() + assert len(counts) == 4 + assert set(counts.values()) == set([2]) + + +def delay_connection(monkeypatch, sec): + """ + Return a _connect_gen function delayed by the amount of seconds + """ + connect_orig = psycopg3.AsyncConnection.connect + + async def connect_delay(*args, **kwargs): + t0 = time() + rv = await connect_orig(*args, **kwargs) + t1 = time() + await asyncio.sleep(sec - (t1 - t0)) + return rv + + monkeypatch.setattr(psycopg3.AsyncConnection, "connect", connect_delay)