From: Daniele Varrazzo Date: Thu, 16 Mar 2023 21:07:05 +0000 (+0100) Subject: fix(pool): fix handling of errors in queued async tasks X-Git-Tag: pool-3.1.7~2^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d87562392b4be36bb35b2b975041fcb0dfac11da;p=thirdparty%2Fpsycopg.git fix(pool): fix handling of errors in queued async tasks Failing to do so, cancelled tasks still in the queue end up consuming a connection without a chance of returning it, depleting the pool. Close #509 --- diff --git a/docs/news_pool.rst b/docs/news_pool.rst index 3335b1084..d63d96d8b 100644 --- a/docs/news_pool.rst +++ b/docs/news_pool.rst @@ -7,6 +7,16 @@ ``psycopg_pool`` release notes ============================== +Future releases +--------------- + +psycopg_pool 3.1.7 (unreleased) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Fix handling of tasks cancelled while waiting in async pool queue + (:ticket:`#503`). + + Current release --------------- diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 1cffcce68..74a30dd28 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -640,7 +640,7 @@ class AsyncClient: def __init__(self) -> None: self.conn: Optional[AsyncConnection[Any]] = None - self.error: Optional[Exception] = None + self.error: Optional[BaseException] = None # The AsyncClient behaves in a way similar to an Event, but we need # to notify reliably the flagger that the waiter has "accepted" the @@ -662,6 +662,8 @@ class AsyncClient: self.error = PoolTimeout( f"couldn't get a connection after {timeout} sec" ) + except BaseException as ex: + self.error = ex if self.conn: return self.conn diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py index fea47fbf4..f3f16726b 100644 --- a/tests/pool/test_null_pool_async.py +++ b/tests/pool/test_null_pool_async.py @@ -839,3 +839,61 @@ async def test_stats_connect(dsn, proxy, monkeypatch): assert stats.get("connections_errors", 0) == 0 assert stats.get("connections_lost", 0) == 0 assert 200 <= stats["connections_ms"] < 300 + + +async def test_cancellation_in_queue(dsn): + # https://github.com/psycopg/psycopg/issues/509 + + nconns = 3 + + async with AsyncNullConnectionPool( + dsn, min_size=0, max_size=nconns, timeout=1 + ) as p: + await p.wait() + + got_conns = [] + ev = asyncio.Event() + + async def worker(i): + try: + logging.info("worker %s started", i) + nonlocal got_conns + + async with p.connection() as conn: + logging.info("worker %s got conn", i) + cur = await conn.execute("select 1") + assert (await cur.fetchone()) == (1,) + + got_conns.append(conn) + if len(got_conns) >= nconns: + ev.set() + + while True: + await asyncio.sleep(10) + + except BaseException as ex: + logging.info("worker %s stopped: %r", i, ex) + raise + + # Start tasks taking up all the connections and getting in the queue + tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)] + + # wait until the pool has served all the connections and clients are queued. + await ev.wait() + for i in range(10): + if p.get_stats().get("requests_queued", 0): + break + else: + await asyncio.sleep(0.1) + else: + pytest.fail("no client got in the queue") + + [task.cancel() for task in reversed(tasks)] + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0) + + stats = p.get_stats() + assert stats.get("requests_waiting", 0) == 0 + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 1f16ae2f3..668643ee7 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -1176,6 +1176,63 @@ async def test_debug_deadlock(dsn): logger.setLevel(old_level) +async def test_cancellation_in_queue(dsn): + # https://github.com/psycopg/psycopg/issues/509 + + nconns = 3 + + async with pool.AsyncConnectionPool(dsn, min_size=nconns, timeout=1) as p: + await p.wait() + + got_conns = [] + ev = asyncio.Event() + + async def worker(i): + try: + logging.info("worker %s started", i) + nonlocal got_conns + + async with p.connection() as conn: + logging.info("worker %s got conn", i) + cur = await conn.execute("select 1") + assert (await cur.fetchone()) == (1,) + + got_conns.append(conn) + if len(got_conns) >= nconns: + ev.set() + + while True: + await asyncio.sleep(10) + + except BaseException as ex: + logging.info("worker %s stopped: %r", i, ex) + raise + + # Start tasks taking up all the connections and getting in the queue + tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)] + + # wait until the pool has served all the connections and clients are queued. + await ev.wait() + for i in range(10): + if p.get_stats().get("requests_queued", 0): + break + else: + await asyncio.sleep(0.1) + else: + pytest.fail("no client got in the queue") + + [task.cancel() for task in reversed(tasks)] + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0) + + stats = p.get_stats() + assert stats["pool_available"] == 3 + assert stats.get("requests_waiting", 0) == 0 + + async with p.connection() as conn: + cur = await conn.execute("select 1") + assert await cur.fetchone() == (1,) + + def delay_connection(monkeypatch, sec): """ Return a _connect_gen function delayed by the amount of seconds