if TYPE_CHECKING:
from .pq.proto import PGconn, PGresult
- from .pool import ConnectionPool
+ from .pool.base import BasePool
if pq.__impl__ == "c":
from psycopg3_c import _psycopg3
# Attribute is only set if the connection is from a pool so we can tell
# apart a connection in the pool too (when _pool = None)
- self._pool: Optional["ConnectionPool"]
+ self._pool: Optional["BasePool[Any]"]
def __del__(self) -> None:
# If fails on connection we might not have this attribute yet
else:
await self.commit()
- await self.close()
+ # Close the connection only if it doesn't belong to a pool.
+ if not getattr(self, "_pool", None):
+ await self.close()
async def close(self) -> None:
self.pgconn.finish()
# Copyright (C) 2021 The Psycopg Team
from .pool import ConnectionPool
+from .async_pool import AsyncConnectionPool
from .errors import PoolClosed, PoolTimeout
-__all__ = ["ConnectionPool", "PoolClosed", "PoolTimeout"]
+__all__ = [
+ "AsyncConnectionPool",
+ "ConnectionPool",
+ "PoolClosed",
+ "PoolTimeout",
+]
--- /dev/null
+"""
+psycopg3 synchronous connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import asyncio
+import logging
+from time import monotonic
+from typing import Any, Awaitable, Callable, Deque, AsyncIterator, Optional
+from contextlib import asynccontextmanager
+from collections import deque
+
+from ..pq import TransactionStatus
+from ..connection import AsyncConnection
+
+from . import tasks
+from .base import ConnectionAttempt, BasePool
+from .errors import PoolClosed, PoolTimeout
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncConnectionPool(BasePool[AsyncConnection]):
+ def __init__(
+ self,
+ conninfo: str = "",
+ configure: Optional[
+ Callable[[AsyncConnection], Awaitable[None]]
+ ] = None,
+ **kwargs: Any,
+ ):
+ self._configure = configure
+
+ self._lock = asyncio.Lock()
+ self._waiting: Deque["AsyncClient"] = deque()
+
+ # to notify that the pool is full
+ self._pool_full_event: Optional[asyncio.Event] = None
+
+ self.loop = asyncio.get_event_loop()
+
+ super().__init__(conninfo, **kwargs)
+
+ async def wait_ready(self, timeout: float = 30.0) -> None:
+ """
+ Wait for the pool to be full after init.
+
+ Raise `PoolTimeout` if not ready within *timeout* sec.
+ """
+ async with self._lock:
+ assert not self._pool_full_event
+ if len(self._pool) >= self._nconns:
+ return
+ self._pool_full_event = asyncio.Event()
+
+ try:
+ await asyncio.wait_for(self._pool_full_event.wait(), timeout)
+ except asyncio.TimeoutError:
+ await self.close() # stop all the threads
+ raise PoolTimeout(
+ f"pool initialization incomplete after {timeout} sec"
+ )
+
+ async with self._lock:
+ self._pool_full_event = None
+
+ @asynccontextmanager
+ async def connection(
+ self, timeout: Optional[float] = None
+ ) -> AsyncIterator[AsyncConnection]:
+ """Context manager to obtain a connection from the pool.
+
+ Returned the connection immediately if available, otherwise wait up to
+ *timeout* or `self.timeout` and throw `PoolTimeout` if a connection is
+ not available in time.
+
+ Upon context exit, return the connection to the pool. Apply the normal
+ connection context behaviour (commit/rollback the transaction in case
+ of success/error). If the connection is no more in working state
+ replace it with a new one.
+ """
+ conn = await self.getconn(timeout=timeout)
+ try:
+ async with conn:
+ yield conn
+ finally:
+ await self.putconn(conn)
+
+ async def getconn(
+ self, timeout: Optional[float] = None
+ ) -> AsyncConnection:
+ """Obtain a contection from the pool.
+
+ You should preferrably use `connection()`. Use this function only if
+ it is not possible to use the connection as context manager.
+
+ After using this function you *must* call a corresponding `putconn()`:
+ failing to do so will deplete the pool. A depleted pool is a sad pool:
+ you don't want a depleted pool.
+ """
+ logger.info("connection requested to %r", self.name)
+ # Critical section: decide here if there's a connection ready
+ # or if the client needs to wait.
+ async with self._lock:
+ if self._closed:
+ raise PoolClosed(f"the pool {self.name!r} is closed")
+
+ pos: Optional[AsyncClient] = None
+ if self._pool:
+ # Take a connection ready out of the pool
+ conn = self._pool.popleft()
+ if len(self._pool) < self._nconns_min:
+ self._nconns_min = len(self._pool)
+ else:
+ # No connection available: put the client in the waiting queue
+ pos = AsyncClient()
+ self._waiting.append(pos)
+
+ # If there is space for the pool to grow, let's do it
+ if self._nconns < self.maxconn:
+ self._nconns += 1
+ logger.info(
+ "growing pool %r to %s", self.name, self._nconns
+ )
+ self.run_task(tasks.AddConnection(self))
+
+ # If we are in the waiting queue, wait to be assigned a connection
+ # (outside the critical section, so only the waiting client is locked)
+ if pos:
+ if timeout is None:
+ timeout = self.timeout
+ conn = await pos.wait(timeout=timeout)
+
+ # Tell the connection it belongs to a pool to avoid closing on __exit__
+ # Note that this property shouldn't be set while the connection is in
+ # the pool, to avoid to create a reference loop.
+ conn._pool = self
+ logger.info("connection given by %r", self.name)
+ return conn
+
+ async def putconn(self, conn: AsyncConnection) -> None:
+ """Return a connection to the loving hands of its pool.
+
+ Use this function only paired with a `getconn()`. You don't need to use
+ it if you use the much more comfortable `connection()` context manager.
+ """
+ # Quick check to discard the wrong connection
+ pool = getattr(conn, "_pool", None)
+ if pool is not self:
+ if pool:
+ msg = f"it comes from pool {pool.name!r}"
+ else:
+ msg = "it doesn't come from any pool"
+ raise ValueError(
+ f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+ )
+
+ logger.info("returning connection to %r", self.name)
+
+ # If the pool is closed just close the connection instead of returning
+ # it to the pool. For extra refcare remove the pool reference from it.
+ if self._closed:
+ conn._pool = None
+ await conn.close()
+ return
+
+ # Use a worker to perform eventual maintenance work in a separate thread
+ self.run_task(tasks.ReturnConnection(self, conn))
+
+ async def close(self, timeout: float = 1.0) -> None:
+ """Close the pool and make it unavailable to new clients.
+
+ All the waiting and future client will fail to acquire a connection
+ with a `PoolClosed` exception. Currently used connections will not be
+ closed until returned to the pool.
+
+ Wait *timeout* for threads to terminate their job, if positive.
+ """
+ if self._closed:
+ return
+
+ async with self._lock:
+ self._closed = True
+ logger.debug("pool %r closed", self.name)
+
+ # Take waiting client and pool connections out of the state
+ waiting = list(self._waiting)
+ self._waiting.clear()
+ pool = list(self._pool)
+ self._pool.clear()
+
+ # Now that the flag _closed is set, getconn will fail immediately,
+ # putconn will just close the returned connection.
+
+ # Stop the scheduler
+ self._sched.enter(0, None)
+
+ # Stop the worker threads
+ for i in range(len(self._workers)):
+ self.run_task(tasks.StopWorker(self))
+
+ # Signal to eventual clients in the queue that business is closed.
+ for pos in waiting:
+ await pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+ # Close the connections still in the pool
+ for conn in pool:
+ await conn.close()
+
+ # Wait for the worker threads to terminate
+ if timeout > 0:
+ loop = asyncio.get_running_loop()
+ for t in [self._sched_runner] + self._workers:
+ if not t.is_alive():
+ continue
+ await loop.run_in_executor(None, lambda: t.join(timeout))
+ if t.is_alive():
+ logger.warning(
+ "couldn't stop thread %s in pool %r within %s seconds",
+ t,
+ self.name,
+ timeout,
+ )
+
+ async def configure(self, conn: AsyncConnection) -> None:
+ """Configure a connection after creation."""
+ if self._configure:
+ await self._configure(conn)
+
+ def reconnect_failed(self) -> None:
+ """
+ Called when reconnection failed for longer than `reconnect_timeout`.
+ """
+ self._reconnect_failed(self)
+
+ async def _connect(self) -> AsyncConnection:
+ """Return a new connection configured for the pool."""
+ conn = await AsyncConnection.connect(self.conninfo, **self.kwargs)
+ await self.configure(conn)
+ conn._pool = self
+ return conn
+
+ async def _add_connection(
+ self, attempt: Optional[ConnectionAttempt]
+ ) -> None:
+ """Try to connect and add the connection to the pool.
+
+ If failed, reschedule a new attempt in the future for a few times, then
+ give up, decrease the pool connections number and call
+ `self.reconnect_failed()`.
+
+ """
+ now = monotonic()
+ if not attempt:
+ attempt = ConnectionAttempt(
+ reconnect_timeout=self.reconnect_timeout
+ )
+
+ try:
+ conn = await self._connect()
+ except Exception as e:
+ logger.warning(f"error connecting in {self.name!r}: {e}")
+ if attempt.time_to_give_up(now):
+ logger.warning(
+ "reconnection attempt in pool %r failed after %s sec",
+ self.name,
+ self.reconnect_timeout,
+ )
+ async with self._lock:
+ self._nconns -= 1
+ self.reconnect_failed()
+ else:
+ attempt.update_delay(now)
+ self.schedule_task(
+ tasks.AddConnection(self, attempt), attempt.delay
+ )
+ else:
+ await self._add_to_pool(conn)
+
+ async def _return_connection(self, conn: AsyncConnection) -> None:
+ """
+ Return a connection to the pool after usage.
+ """
+ await self._reset_connection(conn)
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ # Connection no more in working state: create a new one.
+ logger.warning("discarding closed connection: %s", conn)
+ self.run_task(tasks.AddConnection(self))
+ else:
+ await self._add_to_pool(conn)
+
+ async def _add_to_pool(self, conn: AsyncConnection) -> None:
+ """
+ Add a connection to the pool.
+
+ The connection can be a fresh one or one already used in the pool.
+
+ If a client is already waiting for a connection pass it on, otherwise
+ put it back into the pool
+ """
+ # Remove the pool reference from the connection before returning it
+ # to the state, to avoid to create a reference loop.
+ # Also disable the warning for open connection in conn.__del__
+ conn._pool = None
+
+ pos: Optional[AsyncClient] = None
+
+ # Critical section: if there is a client waiting give it the connection
+ # otherwise put it back into the pool.
+ async with self._lock:
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
+ pos = self._waiting.popleft()
+ if await pos.set(conn):
+ break
+ else:
+ # No client waiting for a connection: put it back into the pool
+ self._pool.append(conn)
+
+ # If we have been asked to wait for pool init, notify the
+ # waiter if the pool is full.
+ if self._pool_full_event and len(self._pool) >= self._nconns:
+ self._pool_full_event.set()
+
+ async def _reset_connection(self, conn: AsyncConnection) -> None:
+ """
+ Bring a connection to IDLE state or close it.
+ """
+ status = conn.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ return
+
+ if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ # Connection returned with an active transaction
+ logger.warning("rolling back returned connection: %s", conn)
+ try:
+ await conn.rollback()
+ except Exception as e:
+ logger.warning(
+ "rollback failed: %s: %s. Discarding connection %s",
+ e.__class__.__name__,
+ e,
+ conn,
+ )
+ await conn.close()
+
+ elif status == TransactionStatus.ACTIVE:
+ # Connection returned during an operation. Bad... just close it.
+ logger.warning("closing returned connection: %s", conn)
+ await conn.close()
+
+ async def _shrink_pool(self) -> None:
+ to_close: Optional[AsyncConnection] = None
+
+ async with self._lock:
+ # Reset the min number of connections used
+ nconns_min = self._nconns_min
+ self._nconns_min = len(self._pool)
+
+ # If the pool can shrink and connections were unused, drop one
+ if self._nconns > self.minconn and nconns_min > 0:
+ to_close = self._pool.popleft()
+ self._nconns -= 1
+
+ if to_close:
+ logger.info(
+ "shrinking pool %r to %s because %s unused connections"
+ " in the last %s sec",
+ self.name,
+ self._nconns,
+ nconns_min,
+ self.max_idle,
+ )
+ await to_close.close()
+
+
+class AsyncClient:
+ """A position in a queue for a client waiting for a connection."""
+
+ __slots__ = ("conn", "error", "_cond")
+
+ def __init__(self) -> None:
+ self.conn: Optional[AsyncConnection] = None
+ self.error: Optional[Exception] = None
+
+ # The AsyncClient behaves in a way similar to an Event, but we need
+ # to notify reliably the flagger that the waiter has "accepted" the
+ # message and it hasn't timed out yet, otherwise the pool may give a
+ # connection to a client that has already timed out getconn(), which
+ # will be lost.
+ self._cond = asyncio.Condition()
+
+ async def wait(self, timeout: float) -> AsyncConnection:
+ """Wait for a connection to be set and return it.
+
+ Raise an exception if the wait times out or if fail() is called.
+ """
+ async with self._cond:
+ if not (self.conn or self.error):
+ try:
+ await asyncio.wait_for(self._cond.wait(), timeout)
+ except asyncio.TimeoutError:
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout} sec"
+ )
+
+ if self.conn:
+ return self.conn
+ else:
+ assert self.error
+ raise self.error
+
+ async def set(self, conn: AsyncConnection) -> bool:
+ """Signal the client waiting that a connection is ready.
+
+ Return True if the client has "accepted" the connection, False
+ otherwise (typically because wait() has timed out).
+ """
+ async with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.conn = conn
+ self._cond.notify_all()
+ return True
+
+ async def fail(self, error: Exception) -> bool:
+ """Signal the client that, alas, they won't have a connection today.
+
+ Return True if the client has "accepted" the error, False otherwise
+ (typically because wait() has timed out).
+ """
+ async with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.error = error
+ self._cond.notify_all()
+ return True
+
+
+tasks.AsyncConnectionPool = AsyncConnectionPool # type: ignore
conninfo: str = "",
*,
kwargs: Optional[Dict[str, Any]] = None,
- configure: Optional[Callable[[ConnectionType], None]] = None,
minconn: int = 4,
maxconn: Optional[int] = None,
name: Optional[str] = None,
)
if not name:
num = BasePool._num_pool = BasePool._num_pool + 1
- name = f"pool-{num - 1}"
+ name = f"pool-{num}"
if num_workers < 1:
raise ValueError("num_workers must be at least 1")
self.conninfo = conninfo
self.kwargs: Dict[str, Any] = kwargs or {}
- self._configure: Callable[[ConnectionType], None]
- self._configure = configure or (lambda conn: None)
self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None]
self._reconnect_failed = reconnect_failed or (lambda pool: None)
self.name = name
self._workers: List[threading.Thread] = []
for i in range(num_workers):
t = threading.Thread(
- target=self.worker, args=(self._tasks,), daemon=True
+ target=self.worker,
+ args=(self._tasks,),
+ name=f"{self.name}-worker-{i}",
+ daemon=True,
)
self._workers.append(t)
self._sched_runner = threading.Thread(
- target=self._sched.run, daemon=True
+ target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
)
# _close should be the last property to be set in the state
self._tasks.put_nowait(task)
def schedule_task(
- self, task: tasks.MaintenanceTask[ConnectionType], delay: float
+ self, task: tasks.MaintenanceTask[Any], delay: float
) -> None:
"""Run a maintenance task in a worker thread in the future."""
self._sched.enter(delay, task.tick)
# Don't make all the workers time out at the same moment
timeout = WORKER_TIMEOUT * (0.9 + 0.1 * random.random())
while True:
- # Use a timeout to make the wait unterruptable
+ # Use a timeout to make the wait interruptable
try:
task = q.get(timeout=timeout)
except Empty:
)
if isinstance(task, tasks.StopWorker):
+ logger.debug(
+ "terminating working thread %s",
+ threading.current_thread().name,
+ )
return
# Copyright (C) 2021 The Psycopg Team
-import time
import logging
import threading
-from typing import Any, Deque, Iterator, Optional
+from time import monotonic
+from typing import Any, Callable, Deque, Iterator, Optional
from contextlib import contextmanager
from collections import deque
class ConnectionPool(BasePool[Connection]):
- def __init__(self, conninfo: str = "", **kwargs: Any):
+ def __init__(
+ self,
+ conninfo: str = "",
+ configure: Optional[Callable[[Connection], None]] = None,
+ **kwargs: Any,
+ ):
+ self._configure: Callable[[Connection], None]
+ self._configure = configure or (lambda conn: None)
+
self._lock = threading.RLock()
self._waiting: Deque["WaitingClient"] = deque()
with self._lock:
self._closed = True
+ logger.debug("pool %r closed", self.name)
# Take waiting client and pool connections out of the state
waiting = list(self._waiting)
`self.reconnect_failed()`.
"""
- now = time.monotonic()
+ now = monotonic()
if not attempt:
attempt = ConnectionAttempt(
reconnect_timeout=self.reconnect_timeout
pos = self._waiting.popleft()
if pos.set(conn):
break
-
else:
# No client waiting for a connection: put it back into the pool
self._pool.append(conn)
# message and it hasn't timed out yet, otherwise the pool may give a
# connection to a client that has already timed out getconn(), which
# will be lost.
- self._cond = threading.Condition(threading.Lock())
+ self._cond = threading.Condition()
def wait(self, timeout: float) -> Connection:
"""Wait for a connection to be set and return it.
self.error = error
self._cond.notify_all()
return True
+
+
+tasks.ConnectionPool = ConnectionPool # type: ignore
# Copyright (C) 2021 The Psycopg Team
+import asyncio
import logging
+import threading
from abc import ABC, abstractmethod
-from typing import Any, Generic, Optional, TYPE_CHECKING
+from typing import Any, cast, Generic, Optional, Type, TYPE_CHECKING
from weakref import ref
from ..proto import ConnectionType
+from .. import Connection, AsyncConnection
if TYPE_CHECKING:
+ from .pool import ConnectionPool
+ from .async_pool import AsyncConnectionPool
from .base import BasePool, ConnectionAttempt
- from ..connection import Connection
+else:
+ # Injected at pool.py and async_pool.py import
+ ConnectionPool: "Type[BasePool[Connection]]"
+ AsyncConnectionPool: "Type[BasePool[AsyncConnection]]"
logger = logging.getLogger(__name__)
class MaintenanceTask(ABC, Generic[ConnectionType]):
"""A task to run asynchronously to maintain the pool state."""
+ TIMEOUT = 10.0
+
def __init__(self, pool: "BasePool[Any]"):
+ if isinstance(pool, AsyncConnectionPool):
+ self.event = threading.Event()
+
self.pool = ref(pool)
- logger.debug("task created: %s", self)
+ logger.debug(
+ "task created in %s: %s", threading.current_thread().name, self
+ )
def __repr__(self) -> str:
pool = self.pool()
# Pool is no more working. Quietly discard the operation.
return
- logger.debug("task running: %s", self)
- self._run(pool)
+ logger.debug(
+ "task running in %s: %s", threading.current_thread().name, self
+ )
+ if isinstance(pool, ConnectionPool):
+ self._run(pool)
+ elif isinstance(pool, AsyncConnectionPool):
+ self.event.clear()
+ asyncio.run_coroutine_threadsafe(self._run_async(pool), pool.loop)
+ if not self.event.wait(self.TIMEOUT):
+ logger.warning(
+ "event %s didn't terminate after %s sec", self.TIMEOUT
+ )
+ else:
+ logger.error("%s run got %s instead of a pool", self, pool)
def tick(self) -> None:
"""Run the scheduled task
pool.run_task(self)
@abstractmethod
- def _run(self, pool: "BasePool[Any]") -> None:
+ def _run(self, pool: "ConnectionPool") -> None:
...
+ @abstractmethod
+ async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+ self.event.set()
+
class StopWorker(MaintenanceTask[ConnectionType]):
"""Signal the maintenance thread to terminate."""
- def _run(self, pool: "BasePool[Any]") -> None:
+ def _run(self, pool: "ConnectionPool") -> None:
pass
+ async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+ await super()._run_async(pool)
+
class AddConnection(MaintenanceTask[ConnectionType]):
def __init__(
super().__init__(pool)
self.attempt = attempt
- def _run(self, pool: "BasePool[Any]") -> None:
- from . import ConnectionPool
+ def _run(self, pool: "ConnectionPool") -> None:
+ pool._add_connection(self.attempt)
- if isinstance(pool, ConnectionPool):
- pool._add_connection(self.attempt)
- else:
- assert False
+ async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+ logger.debug("run async 1")
+ await pool._add_connection(self.attempt)
+ logger.debug("run async 2")
+ await super()._run_async(pool)
+ logger.debug("run async 3")
class ReturnConnection(MaintenanceTask[ConnectionType]):
"""Clean up and return a connection to the pool."""
- def __init__(self, pool: "BasePool[Any]", conn: "Connection"):
+ def __init__(self, pool: "BasePool[Any]", conn: "ConnectionType"):
super().__init__(pool)
self.conn = conn
- def _run(self, pool: "BasePool[Any]") -> None:
- from . import ConnectionPool
+ def _run(self, pool: "ConnectionPool") -> None:
+ pool._return_connection(cast(Connection, self.conn))
- if isinstance(pool, ConnectionPool):
- pool._return_connection(self.conn)
- else:
- assert False
+ async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+ await pool._return_connection(cast(AsyncConnection, self.conn))
+ await super()._run_async(pool)
class ShrinkPool(MaintenanceTask[ConnectionType]):
in the pool.
"""
- def _run(self, pool: "BasePool[Any]") -> None:
+ def _run(self, pool: "ConnectionPool") -> None:
# Reschedule the task now so that in case of any error we don't lose
# the periodic run.
pool.schedule_task(self, pool.max_idle)
+ pool._shrink_pool()
- from . import ConnectionPool
-
- if isinstance(pool, ConnectionPool):
- pool._shrink_pool()
- else:
- assert False
+ async def _run_async(self, pool: "AsyncConnectionPool") -> None:
+ pool.schedule_task(self, pool.max_idle)
+ await pool._shrink_pool()
+ await super()._run_async(pool)
def test_reconnect(proxy, caplog, monkeypatch):
caplog.set_level(logging.WARNING, logger="psycopg3.pool")
- assert pool.pool.ConnectionAttempt.INITIAL_DELAY == 1.0
- assert pool.pool.ConnectionAttempt.DELAY_JITTER == 0.1
- monkeypatch.setattr(pool.pool.ConnectionAttempt, "INITIAL_DELAY", 0.1)
- monkeypatch.setattr(pool.pool.ConnectionAttempt, "DELAY_JITTER", 0.0)
+ assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+ assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
proxy.start()
p = pool.ConnectionPool(proxy.client_dsn, minconn=1)
--- /dev/null
+import sys
+import asyncio
+import logging
+import weakref
+from time import time
+from collections import Counter
+
+import pytest
+
+import psycopg3
+from psycopg3 import pool
+from psycopg3.pq import TransactionStatus
+
+create_task = (
+ asyncio.create_task
+ if sys.version_info >= (3, 7)
+ else asyncio.ensure_future
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+async def test_defaults(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ assert p.minconn == p.maxconn == 4
+ assert p.timeout == 30
+ assert p.max_idle == 600
+ assert p.num_workers == 3
+ await p.close()
+
+
+async def test_minconn_maxconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2)
+ assert p.minconn == p.maxconn == 2
+ await p.close()
+
+ p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4)
+ assert p.minconn == 2
+ assert p.maxconn == 4
+ await p.close()
+
+ with pytest.raises(ValueError):
+ pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2)
+
+
+async def test_kwargs(dsn):
+ p = pool.AsyncConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1)
+ async with p.connection() as conn:
+ assert conn.autocommit
+
+ await p.close()
+
+
+async def test_its_really_a_pool(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2)
+ async with p.connection() as conn:
+ cur = await conn.execute("select pg_backend_pid()")
+ (pid1,) = await cur.fetchone()
+
+ async with p.connection() as conn2:
+ cur = await conn2.execute("select pg_backend_pid()")
+ (pid2,) = await cur.fetchone()
+
+ async with p.connection() as conn:
+ assert conn.pgconn.backend_pid in (pid1, pid2)
+
+ await p.close()
+
+
+async def test_connection_not_lost(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ with pytest.raises(ZeroDivisionError):
+ async with p.connection() as conn:
+ pid = conn.pgconn.backend_pid
+ 1 / 0
+
+ async with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid == pid
+
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_concurrent_filling(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ t0 = time()
+ times = []
+
+ add_orig = pool.AsyncConnectionPool._add_to_pool
+
+ async def add_time(self, conn):
+ times.append(time() - t0)
+ await add_orig(self, conn)
+
+ monkeypatch.setattr(pool.AsyncConnectionPool, "_add_to_pool", add_time)
+
+ p = pool.AsyncConnectionPool(dsn, minconn=5, num_workers=2)
+ await p.wait_ready(5.0)
+ want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+ assert len(times) == len(want_times)
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_wait_ready(dsn, monkeypatch):
+ delay_connection(monkeypatch, 0.1)
+ with pytest.raises(pool.PoolTimeout):
+ p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1)
+ await p.wait_ready(0.3)
+
+ p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1)
+ await p.wait_ready(0.5)
+ await p.close()
+ p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2)
+ await p.wait_ready(0.3)
+ await p.wait_ready(0.0001) # idempotent
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_setup_no_timeout(dsn, proxy):
+ with pytest.raises(pool.PoolTimeout):
+ p = pool.AsyncConnectionPool(
+ proxy.client_dsn, minconn=1, num_workers=1
+ )
+ await p.wait_ready(0.2)
+
+ p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1, num_workers=1)
+ await asyncio.sleep(0.5)
+ assert not p._pool
+ proxy.start()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_queue(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2)
+ results = []
+
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ cur = await conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ )
+ (pid,) = await cur.fetchone()
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+ await p.close()
+
+ times = [item[1] for item in results]
+ want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+ assert len(set(r[2] for r in results)) == 2, results
+
+
+@pytest.mark.slow
+async def test_queue_timeout(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1)
+ results = []
+ errors = []
+
+ async def worker(n):
+ t0 = time()
+ try:
+ async with p.connection() as conn:
+ cur = await conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ )
+ (pid,) = await cur.fetchone()
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ assert len(results) == 2
+ assert len(errors) == 2
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_dead_client(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2)
+
+ results = []
+
+ async def worker(i, timeout):
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ await conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except pool.PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ ts = [
+ create_task(worker(i, timeout))
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+ ]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+ assert len(p._pool) == 2 # no connection was lost
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_queue_timeout_override(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1)
+ results = []
+ errors = []
+
+ async def worker(n):
+ t0 = time()
+ timeout = 0.25 if n == 3 else None
+ try:
+ async with p.connection(timeout=timeout) as conn:
+ cur = await conn.execute(
+ "select pg_backend_pid() from pg_sleep(0.2)"
+ )
+ (pid,) = await cur.fetchone()
+ except pool.PoolTimeout as e:
+ t1 = time()
+ errors.append((n, t1 - t0, e))
+ else:
+ t1 = time()
+ results.append((n, t1 - t0, pid))
+
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+ await p.close()
+
+ assert len(results) == 3
+ assert len(errors) == 1
+ for e in errors:
+ assert 0.1 < e[1] < 0.15
+
+
+async def test_broken_reconnect(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ with pytest.raises(psycopg3.OperationalError):
+ async with p.connection() as conn:
+ cur = await conn.execute("select pg_backend_pid()")
+ (pid1,) = await cur.fetchone()
+ await conn.close()
+
+ async with p.connection() as conn2:
+ cur = await conn2.execute("select pg_backend_pid()")
+ (pid2,) = await cur.fetchone()
+
+ await p.close()
+ assert pid1 != pid2
+
+
+async def test_intrans_rollback(dsn, caplog):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ conn = await p.getconn()
+ pid = conn.pgconn.backend_pid
+ await conn.execute("create table test_intrans_rollback ()")
+ assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid == pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+ cur = await conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ )
+ assert not await cur.fetchone()
+
+ await p.close()
+ recs = [
+ r
+ for r in caplog.records
+ if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+ ]
+ assert len(recs) == 1
+ assert "INTRANS" in recs[0].message
+
+
+async def test_inerror_rollback(dsn, caplog):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ conn = await p.getconn()
+ pid = conn.pgconn.backend_pid
+ with pytest.raises(psycopg3.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid == pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+ recs = [
+ r
+ for r in caplog.records
+ if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+ ]
+ assert len(recs) == 1
+ assert "INERROR" in recs[0].message
+
+ await p.close()
+
+
+async def test_active_close(dsn, caplog):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ conn = await p.getconn()
+ pid = conn.pgconn.backend_pid
+ cur = conn.cursor()
+ async with cur.copy(
+ "copy (select * from generate_series(1, 10)) to stdout"
+ ):
+ pass
+ assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid != pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+ await p.close()
+ recs = [
+ r
+ for r in caplog.records
+ if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+ ]
+ assert len(recs) == 2
+ assert "ACTIVE" in recs[0].message
+ assert "BAD" in recs[1].message
+
+
+async def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ conn = await p.getconn()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+
+ async def bad_rollback():
+ conn.pgconn.finish()
+ await orig_rollback()
+
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pid = conn.pgconn.backend_pid
+ with pytest.raises(psycopg3.ProgrammingError):
+ await conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ await p.putconn(conn)
+
+ async with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid != pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+ await p.close()
+
+ recs = [
+ r
+ for r in caplog.records
+ if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+ ]
+ assert len(recs) == 3
+ assert "INERROR" in recs[0].message
+ assert "OperationalError" in recs[1].message
+ assert "BAD" in recs[2].message
+
+
+async def test_close_no_threads(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ assert p._sched_runner.is_alive()
+ for t in p._workers:
+ assert t.is_alive()
+
+ await p.close()
+ assert not p._sched_runner.is_alive()
+ for t in p._workers:
+ assert not t.is_alive()
+
+
+async def test_putconn_no_pool(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ conn = psycopg3.connect(dsn)
+ with pytest.raises(ValueError):
+ await p.putconn(conn)
+ await p.close()
+
+
+async def test_putconn_wrong_pool(dsn):
+ p1 = pool.AsyncConnectionPool(dsn, minconn=1)
+ p2 = pool.AsyncConnectionPool(dsn, minconn=1)
+ conn = await p1.getconn()
+ with pytest.raises(ValueError):
+ await p2.putconn(conn)
+ await p1.close()
+ await p2.close()
+
+
+async def test_del_no_warning(dsn, recwarn):
+ p = pool.AsyncConnectionPool(dsn, minconn=2)
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ await p.wait_ready()
+ ref = weakref.ref(p)
+ del p
+ assert not ref()
+ assert not recwarn
+
+
+@pytest.mark.slow
+async def test_del_stop_threads(dsn):
+ p = pool.AsyncConnectionPool(dsn)
+ ts = [p._sched_runner] + p._workers
+ del p
+ await asyncio.sleep(0.2)
+ for t in ts:
+ assert not t.is_alive()
+
+
+async def test_closed_getconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ assert not p.closed
+ async with p.connection():
+ pass
+
+ await p.close()
+ assert p.closed
+
+ with pytest.raises(pool.PoolClosed):
+ async with p.connection():
+ pass
+
+
+async def test_closed_putconn(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+
+ async with p.connection() as conn:
+ pass
+ assert not conn.closed
+
+ async with p.connection() as conn:
+ await p.close()
+ assert conn.closed
+
+
+@pytest.mark.slow
+async def test_closed_queue(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=1)
+ success = []
+
+ async def w1():
+ async with p.connection() as conn:
+ res = await conn.execute("select 1 from pg_sleep(0.2)")
+ assert await res.fetchone() == (1,)
+ success.append("w1")
+
+ async def w2():
+ with pytest.raises(pool.PoolClosed):
+ async with p.connection():
+ pass
+ success.append("w2")
+
+ t1 = create_task(w1())
+ await asyncio.sleep(0.1)
+ t2 = create_task(w2())
+ await p.close()
+ await asyncio.gather(t1, t2)
+ assert len(success) == 2
+
+
+@pytest.mark.slow
+async def test_grow(dsn, monkeypatch):
+ p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, num_workers=3)
+ await p.wait_ready(5.0)
+ delay_connection(monkeypatch, 0.1)
+ ts = []
+ results = []
+
+ async def worker(n):
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select 1 from pg_sleep(0.2)")
+ t1 = time()
+ results.append((n, t1 - t0))
+
+ ts = [create_task(worker(i)) for i in range(6)]
+ await asyncio.gather(*ts)
+ await p.close()
+
+ want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
+ times = [item[1] for item in results]
+ for got, want in zip(times, want_times):
+ assert got == pytest.approx(want, 0.2), times
+
+
+@pytest.mark.slow
+async def test_shrink(dsn, monkeypatch):
+
+ from psycopg3.pool.tasks import ShrinkPool
+
+ orig_run = ShrinkPool._run_async
+ results = []
+
+ async def run_async_hacked(self, pool):
+ n0 = pool._nconns
+ await orig_run(self, pool)
+ n1 = pool._nconns
+ results.append((n0, n1))
+
+ monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked)
+
+ p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2)
+ await p.wait_ready(5.0)
+ assert p.max_idle == 0.2
+
+ async def worker(n):
+ async with p.connection() as conn:
+ await conn.execute("select pg_sleep(0.1)")
+
+ ts = [create_task(worker(i)) for i in range(4)]
+ await asyncio.gather(*ts)
+
+ await asyncio.sleep(1)
+ await p.close()
+ assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
+
+
+@pytest.mark.slow
+async def test_reconnect(proxy, caplog, monkeypatch):
+ assert pool.base.ConnectionAttempt.INITIAL_DELAY == 1.0
+ assert pool.base.ConnectionAttempt.DELAY_JITTER == 0.1
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+ monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
+
+ proxy.start()
+ p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1)
+ await p.wait_ready(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg3.OperationalError):
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ await asyncio.sleep(1.0)
+ proxy.start()
+ await p.wait_ready()
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ await p.close()
+
+ recs = [
+ r
+ for r in caplog.records
+ if r.name.startswith("psycopg3") and r.levelno >= logging.WARNING
+ ]
+ assert "BAD" in recs[0].message
+ times = [rec.created for rec in recs]
+ assert times[1] - times[0] < 0.05
+ deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+ assert len(deltas) == 3
+ want = 0.1
+ for delta in deltas:
+ assert delta == pytest.approx(want, 0.05), deltas
+ want *= 2
+
+
+@pytest.mark.slow
+async def test_reconnect_failure(proxy):
+ proxy.start()
+
+ t1 = None
+
+ def failed(pool):
+ assert pool.name == "this-one"
+ nonlocal t1
+ t1 = time()
+
+ p = pool.AsyncConnectionPool(
+ proxy.client_dsn,
+ name="this-one",
+ minconn=1,
+ reconnect_timeout=1.0,
+ reconnect_failed=failed,
+ )
+ await p.wait_ready(2.0)
+ proxy.stop()
+
+ with pytest.raises(psycopg3.OperationalError):
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+
+ t0 = time()
+ await asyncio.sleep(1.5)
+ assert t1
+ assert t1 - t0 == pytest.approx(1.0, 0.1)
+ assert p._nconns == 0
+
+ proxy.start()
+ t0 = time()
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ t1 = time()
+ assert t1 - t0 < 0.2
+ await p.close()
+
+
+@pytest.mark.slow
+async def test_uniform_use(dsn):
+ p = pool.AsyncConnectionPool(dsn, minconn=4)
+ counts = Counter()
+ for i in range(8):
+ async with p.connection() as conn:
+ await asyncio.sleep(0.1)
+ counts[id(conn)] += 1
+
+ await p.close()
+ assert len(counts) == 4
+ assert set(counts.values()) == set([2])
+
+
+def delay_connection(monkeypatch, sec):
+ """
+ Return a _connect_gen function delayed by the amount of seconds
+ """
+ connect_orig = psycopg3.AsyncConnection.connect
+
+ async def connect_delay(*args, **kwargs):
+ t0 = time()
+ rv = await connect_orig(*args, **kwargs)
+ t1 = time()
+ await asyncio.sleep(sec - (t1 - t0))
+ return rv
+
+ monkeypatch.setattr(psycopg3.AsyncConnection, "connect", connect_delay)