]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): fix handling of errors in queued async tasks
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 16 Mar 2023 21:07:05 +0000 (22:07 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 16 Mar 2023 21:21:41 +0000 (22:21 +0100)
Failing to do so, cancelled tasks still in the queue end up consuming
a connection without a chance of returning it, depleting the pool.

Close #509

docs/news_pool.rst
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_null_pool_async.py
tests/pool/test_pool_async.py

index 3335b10843c9370a7b81fee082fcfb882f98139c..d63d96d8b1d50764b59a3027fffc71969e70ec87 100644 (file)
@@ -7,6 +7,16 @@
 ``psycopg_pool`` release notes
 ==============================
 
+Future releases
+---------------
+
+psycopg_pool 3.1.7 (unreleased)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Fix handling of tasks cancelled while waiting in async pool queue
+  (:ticket:`#503`).
+
+
 Current release
 ---------------
 
index 1cffcce68c61e874c3d7823fe61f76c1fee70807..74a30dd2835b97a2ca43c5453fcad2a88ddef8e1 100644 (file)
@@ -640,7 +640,7 @@ class AsyncClient:
 
     def __init__(self) -> None:
         self.conn: Optional[AsyncConnection[Any]] = None
-        self.error: Optional[Exception] = None
+        self.error: Optional[BaseException] = None
 
         # The AsyncClient behaves in a way similar to an Event, but we need
         # to notify reliably the flagger that the waiter has "accepted" the
@@ -662,6 +662,8 @@ class AsyncClient:
                     self.error = PoolTimeout(
                         f"couldn't get a connection after {timeout} sec"
                     )
+                except BaseException as ex:
+                    self.error = ex
 
         if self.conn:
             return self.conn
index fea47fbf45af566cfd1c27793ccfb8bf9129f6d7..f3f16726bcea29ea513cf3e933b32acad2b24d6f 100644 (file)
@@ -839,3 +839,61 @@ async def test_stats_connect(dsn, proxy, monkeypatch):
         assert stats.get("connections_errors", 0) == 0
         assert stats.get("connections_lost", 0) == 0
         assert 200 <= stats["connections_ms"] < 300
+
+
+async def test_cancellation_in_queue(dsn):
+    # https://github.com/psycopg/psycopg/issues/509
+
+    nconns = 3
+
+    async with AsyncNullConnectionPool(
+        dsn, min_size=0, max_size=nconns, timeout=1
+    ) as p:
+        await p.wait()
+
+        got_conns = []
+        ev = asyncio.Event()
+
+        async def worker(i):
+            try:
+                logging.info("worker %s started", i)
+                nonlocal got_conns
+
+                async with p.connection() as conn:
+                    logging.info("worker %s got conn", i)
+                    cur = await conn.execute("select 1")
+                    assert (await cur.fetchone()) == (1,)
+
+                    got_conns.append(conn)
+                    if len(got_conns) >= nconns:
+                        ev.set()
+
+                    while True:
+                        await asyncio.sleep(10)
+
+            except BaseException as ex:
+                logging.info("worker %s stopped: %r", i, ex)
+                raise
+
+        # Start tasks taking up all the connections and getting in the queue
+        tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)]
+
+        # wait until the pool has served all the connections and clients are queued.
+        await ev.wait()
+        for i in range(10):
+            if p.get_stats().get("requests_queued", 0):
+                break
+            else:
+                await asyncio.sleep(0.1)
+        else:
+            pytest.fail("no client got in the queue")
+
+        [task.cancel() for task in reversed(tasks)]
+        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0)
+
+        stats = p.get_stats()
+        assert stats.get("requests_waiting", 0) == 0
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
index 1f16ae2f3d73c57c347db3ca4842107c116a9ba0..668643ee7bc70daa812092c050392e321c11c140 100644 (file)
@@ -1176,6 +1176,63 @@ async def test_debug_deadlock(dsn):
         logger.setLevel(old_level)
 
 
+async def test_cancellation_in_queue(dsn):
+    # https://github.com/psycopg/psycopg/issues/509
+
+    nconns = 3
+
+    async with pool.AsyncConnectionPool(dsn, min_size=nconns, timeout=1) as p:
+        await p.wait()
+
+        got_conns = []
+        ev = asyncio.Event()
+
+        async def worker(i):
+            try:
+                logging.info("worker %s started", i)
+                nonlocal got_conns
+
+                async with p.connection() as conn:
+                    logging.info("worker %s got conn", i)
+                    cur = await conn.execute("select 1")
+                    assert (await cur.fetchone()) == (1,)
+
+                    got_conns.append(conn)
+                    if len(got_conns) >= nconns:
+                        ev.set()
+
+                    while True:
+                        await asyncio.sleep(10)
+
+            except BaseException as ex:
+                logging.info("worker %s stopped: %r", i, ex)
+                raise
+
+        # Start tasks taking up all the connections and getting in the queue
+        tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)]
+
+        # wait until the pool has served all the connections and clients are queued.
+        await ev.wait()
+        for i in range(10):
+            if p.get_stats().get("requests_queued", 0):
+                break
+            else:
+                await asyncio.sleep(0.1)
+        else:
+            pytest.fail("no client got in the queue")
+
+        [task.cancel() for task in reversed(tasks)]
+        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0)
+
+        stats = p.get_stats()
+        assert stats["pool_available"] == 3
+        assert stats.get("requests_waiting", 0) == 0
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds