git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add pool.max_waiting
author     Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Fri, 12 Mar 2021 14:26:18 +0000 (15:26 +0100)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Fri, 12 Mar 2021 14:50:10 +0000 (15:50 +0100)
docs/api/pool.rst
psycopg3/psycopg3/pool/__init__.py
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/base.py
psycopg3/psycopg3/pool/errors.py
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

diff --git a/docs/api/pool.rst b/docs/api/pool.rst
index 1d6983d6b4c580aadf1e8bfcbf6e5c401b6bdc9f..6cefb955a3c7f8c12573f6838732d5d24fca7845 100644
@@ -86,6 +86,12 @@ The `!ConnectionPool` class
                    the *timeout* default. Default: 30 seconds.
    :type timeout: `!float`
 
+   :param max_waiting: Maximum number of requests that can be queued to the
+                       pool. Further requests will fail, raising
+                       `TooManyRequests`. Specifying 0 (the default) means no
+                       upper bound.
+   :type max_waiting: `!int`
+
    :param max_lifetime: The maximum lifetime of a connection in the pool, in
                         seconds. Connections used for longer get closed and
                         replaced by a new one. The amount is reduced by a
@@ -178,6 +184,10 @@ The `!ConnectionPool` class
 
    Subclass of `~psycopg3.OperationalError`
 
+.. autoclass:: TooManyRequests()
+
+   Subclass of `~psycopg3.OperationalError`
+
 
 The `!AsyncConnectionPool` class
 --------------------------------
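
A minimal usage sketch of the new parameter (not part of the commit; the DSN,
pool sizes, and the fallback reaction are illustrative assumptions). Once
max_waiting clients are already queued, a further connection() request fails
immediately with TooManyRequests instead of joining the queue::

    # Hedged sketch, assuming a local test database; not from this commit.
    from psycopg3 import pool

    with pool.ConnectionPool("dbname=test", minconn=1, maxconn=4, max_waiting=10) as p:
        try:
            with p.connection() as conn:
                cur = conn.cursor()
                cur.execute("SELECT 1")
        except pool.TooManyRequests:
            # max_waiting clients were already queued: shed this request
            # instead of letting the waiting queue grow without bound.
            pass
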
diff --git a/psycopg3/psycopg3/pool/__init__.py b/psycopg3/psycopg3/pool/__init__.py
index 91f3496934b228bae3bcf1efcf9f32ac7038bb5c..4eeddd885b4b72443fd7c59cb3e074628e63a39b 100644
@@ -6,11 +6,12 @@ psycopg3 connection pool package
 
 from .pool import ConnectionPool
 from .async_pool import AsyncConnectionPool
-from .errors import PoolClosed, PoolTimeout
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
 
 __all__ = [
     "AsyncConnectionPool",
     "ConnectionPool",
     "PoolClosed",
     "PoolTimeout",
+    "TooManyRequests",
 ]
diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py
index e268fcd81a5c3825a8a254b9c912cfd6ea9bdd1d..f901ebc824eeadbb2816ccf2573d5bf6a6f1aff1 100644
@@ -22,7 +22,7 @@ from ..utils.compat import asynccontextmanager, create_task
 
 from .base import ConnectionAttempt, BasePool
 from .sched import AsyncScheduler
-from .errors import PoolClosed, PoolTimeout
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
 
 logger = logging.getLogger(__name__)
 
@@ -118,7 +118,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
     async def getconn(
         self, timeout: Optional[float] = None
     ) -> AsyncConnection:
-        logger.info("connection requested to %r", self.name)
+        logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
@@ -133,6 +133,12 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                 if len(self._pool) < self._nconns_min:
                     self._nconns_min = len(self._pool)
             else:
+                if self.max_waiting and len(self._waiting) >= self.max_waiting:
+                    raise TooManyRequests(
+                        f"the pool {self.name!r} has already"
+                        f" {len(self._waiting)} requests waiting"
+                    )
+
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
                 pos = AsyncClient()
diff --git a/psycopg3/psycopg3/pool/base.py b/psycopg3/psycopg3/pool/base.py
index 4fba7862e4781a7f7bd60d19a9d9b366a958ed7a..03d9d860c477b3bc79b65b9329917e351c95f325 100644
@@ -50,6 +50,7 @@ class BasePool(Generic[ConnectionType]):
         maxconn: Optional[int] = None,
         name: Optional[str] = None,
         timeout: float = 30.0,
+        max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
@@ -77,6 +78,7 @@ class BasePool(Generic[ConnectionType]):
         self._minconn = minconn
         self._maxconn = maxconn
         self.timeout = timeout
+        self.max_waiting = max_waiting
         self.reconnect_timeout = reconnect_timeout
         self.max_lifetime = max_lifetime
         self.max_idle = max_idle
diff --git a/psycopg3/psycopg3/pool/errors.py b/psycopg3/psycopg3/pool/errors.py
index 12f8fa64a8b81b7754b70d620a6370f624aa2729..23eef69d0f87321a28f338fdfdb5afc13d790e06 100644
@@ -17,3 +17,9 @@ class PoolTimeout(e.OperationalError):
     """The pool couldn't provide a connection in acceptable time."""
 
     __module__ = "psycopg3.pool"
+
+
+class TooManyRequests(e.OperationalError):
+    """Too many requests in the queue waiting for a connection from the pool."""
+
+    __module__ = "psycopg3.pool"
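
Because the new exception derives from psycopg3.OperationalError, callers that
already catch the broader class keep working. A small hedged illustration of
the hierarchy (not part of the commit)::

    import psycopg3
    from psycopg3 import pool

    # TooManyRequests slots into the usual DBAPI exception hierarchy.
    assert issubclass(pool.TooManyRequests, psycopg3.OperationalError)

    try:
        raise pool.TooManyRequests("the pool 'p1' has already 3 requests waiting")
    except psycopg3.OperationalError as exc:
        print(f"pool overload: {exc}")
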
diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py
index d6e0860f051e4c591d9c84639e790be3c15752c0..ae58d36bfb9674f2516e06c7f8f519c6b9b80463 100644
@@ -21,7 +21,7 @@ from ..connection import Connection
 
 from .base import ConnectionAttempt, BasePool
 from .sched import Scheduler
-from .errors import PoolClosed, PoolTimeout
+from .errors import PoolClosed, PoolTimeout, TooManyRequests
 
 logger = logging.getLogger(__name__)
 
@@ -132,8 +132,8 @@ class ConnectionPool(BasePool[Connection]):
         """Context manager to obtain a connection from the pool.
 
         Return the connection immediately if available, otherwise wait up to
-        *timeout* or `self.timeout` and throw `PoolTimeout` if a connection is
-        not available in time.
+        *timeout* or `self.timeout` seconds and throw `PoolTimeout` if a
+        connection is not available in time.
 
         Upon context exit, return the connection to the pool. Apply the normal
         :ref:`connection context behaviour <with-connection>` (commit/rollback
@@ -161,7 +161,7 @@ class ConnectionPool(BasePool[Connection]):
         failing to do so will deplete the pool. A depleted pool is a sad pool:
         you don't want a depleted pool.
         """
-        logger.info("connection requested to %r", self.name)
+        logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
@@ -176,6 +176,12 @@ class ConnectionPool(BasePool[Connection]):
                 if len(self._pool) < self._nconns_min:
                     self._nconns_min = len(self._pool)
             else:
+                if self.max_waiting and len(self._waiting) >= self.max_waiting:
+                    raise TooManyRequests(
+                        f"the pool {self.name!r} has already"
+                        f" {len(self._waiting)} requests waiting"
+                    )
+
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
                 pos = WaitingClient()
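
Caller-side, the two failure modes differ: TooManyRequests is raised up front,
before the client is queued, while PoolTimeout means the client did wait in
the queue but no connection became available in time. A hedged retry sketch
(not part of the commit; the query, timeout and backoff values are
illustrative)::

    import time
    from psycopg3 import pool

    def run_with_backoff(p: pool.ConnectionPool, retries: int = 3) -> None:
        for attempt in range(retries):
            try:
                with p.connection(timeout=5.0) as conn:
                    conn.cursor().execute("SELECT 1")
                return
            except pool.TooManyRequests:
                # The waiting queue is already full: back off briefly and retry.
                time.sleep(0.1 * (attempt + 1))
        raise RuntimeError(f"pool {p.name!r} still overloaded after {retries} attempts")
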
diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py
index 8be90ca3710afa95ac0939b5c6149756829790da..64a728aa7f8541ab185e18ed42f3548f2412c7ae 100644
@@ -2,7 +2,7 @@ import sys
 import logging
 import weakref
 from time import sleep, time
-from threading import Thread
+from threading import Thread, Event
 from collections import Counter
 
 import pytest
@@ -291,6 +291,40 @@ def test_queue(dsn, retries):
             assert len(set(r[2] for r in results)) == 2, results
 
 
+@pytest.mark.slow
+def test_queue_size(dsn):
+    def worker(t, ev=None):
+        try:
+            with p.connection():
+                if ev:
+                    ev.set()
+                sleep(t)
+        except pool.TooManyRequests as e:
+            errors.append(e)
+        else:
+            success.append(True)
+
+    errors = []
+    success = []
+
+    with pool.ConnectionPool(dsn, minconn=1, max_waiting=3) as p:
+        p.wait()
+        ev = Event()
+        t = Thread(target=worker, args=(0.3, ev))
+        t.start()
+        ev.wait()
+
+        ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
+        [t.start() for t in ts]
+        [t.join() for t in ts]
+
+    assert len(success) == 4
+    assert len(errors) == 1
+    assert isinstance(errors[0], pool.TooManyRequests)
+    assert p.name in str(errors[0])
+    assert str(p.max_waiting) in str(errors[0])
+
+
 @pytest.mark.slow
 def test_queue_timeout(dsn):
     def worker(n):
diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py
index 59159496a17bc00e08c7fa9c645abd7000cd000e..3c92e7ba32a86dbe83c50c146707a64952ddf6fc 100644
@@ -309,6 +309,38 @@ async def test_queue(dsn, retries):
             assert len(set(r[2] for r in results)) == 2, results
 
 
+@pytest.mark.slow
+async def test_queue_size(dsn):
+    async def worker(t, ev=None):
+        try:
+            async with p.connection():
+                if ev:
+                    ev.set()
+                await asyncio.sleep(t)
+        except pool.TooManyRequests as e:
+            errors.append(e)
+        else:
+            success.append(True)
+
+    errors = []
+    success = []
+
+    async with pool.AsyncConnectionPool(dsn, minconn=1, max_waiting=3) as p:
+        await p.wait()
+        ev = asyncio.Event()
+        create_task(worker(0.3, ev))
+        await ev.wait()
+
+        ts = [create_task(worker(0.1)) for i in range(4)]
+        await asyncio.gather(*ts)
+
+    assert len(success) == 4
+    assert len(errors) == 1
+    assert isinstance(errors[0], pool.TooManyRequests)
+    assert p.name in str(errors[0])
+    assert str(p.max_waiting) in str(errors[0])
+
+
 @pytest.mark.slow
 async def test_queue_timeout(dsn):
     async def worker(n):