From: Ilia Ablamonov Date: Tue, 10 Mar 2026 15:46:40 +0000 (+0100) Subject: Fix async pool cancellation handoff race X-Git-Tag: pool-3.3.1~12^2 X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1277%2Fhead;p=thirdparty%2Fpsycopg.git Fix async pool cancellation handoff race Fixes #1275 --- diff --git a/docs/news_pool.rst b/docs/news_pool.rst index 5b5fcea79..fc1e798b8 100644 --- a/docs/news_pool.rst +++ b/docs/news_pool.rst @@ -7,6 +7,16 @@ ``psycopg_pool`` release notes ============================== +Future releases +--------------- + +psycopg_pool 3.3.1 (unreleased) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Fix residual race condition catching `~asyncio.CancelledError` on connection + (:ticket:`#1275`). + + Current release --------------- diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 245907f67..156f496cc 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -261,6 +261,8 @@ class ConnectionPool(Generic[CT], BasePool): try: conn = pos.wait(timeout=timeout) except CLIENT_EXCEPTIONS: + if pos.conn: + self.run_task(ReturnConnection(self, pos.conn, from_getconn=True)) self._stats[self._REQUESTS_ERRORS] += 1 raise finally: @@ -894,11 +896,11 @@ class WaitingClient(Generic[CT]): except CLIENT_EXCEPTIONS as ex: self.error = ex - if self.conn: - return self.conn - else: - assert self.error + if self.error: raise self.error + else: + assert self.conn + return self.conn def set(self, conn: CT) -> bool: """Signal the client waiting that a connection is ready. diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 254e1367d..5ae43844c 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -298,6 +298,8 @@ class AsyncConnectionPool(Generic[ACT], BasePool): try: conn = await pos.wait(timeout=timeout) except CLIENT_EXCEPTIONS: + if pos.conn: + self.run_task(ReturnConnection(self, pos.conn, from_getconn=True)) self._stats[self._REQUESTS_ERRORS] += 1 raise finally: @@ -957,11 +959,11 @@ class WaitingClient(Generic[ACT]): except CLIENT_EXCEPTIONS as ex: self.error = ex - if self.conn: - return self.conn - else: - assert self.error + if self.error: raise self.error + else: + assert self.conn + return self.conn async def set(self, conn: ACT) -> bool: """Signal the client waiting that a connection is ready. diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 82d89650b..bd9915ccb 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -996,6 +996,69 @@ async def test_cancellation_in_queue(dsn): assert await cur.fetchone() == (1,) +@skip_sync +@pytest.mark.crdb_skip("backend pid") +async def test_cancelled_waiter_assigned_conn_is_reclaimed(dsn, monkeypatch): + from asyncio import CancelledError + + from psycopg_pool.pool_async import WaitingClient + + from .test_pool_common_async import ensure_waiting + + assigned = AEvent() + release = AEvent() + + async def set_blocked(self, conn): + async with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + assigned.set() + await release.wait() + self._cond.notify_all() + return True + + monkeypatch.setattr(WaitingClient, "set", set_blocked) + + async with pool.AsyncConnectionPool(dsn, min_size=1, max_size=1, timeout=1) as p: + await p.wait() + + held_conn = await p.getconn() + held_pid = held_conn.info.backend_pid + waiter = spawn(p.getconn) + await ensure_waiting(p) + + putter = spawn(p.putconn, args=(held_conn,)) + await assigned.wait() + + waiter.cancel() + release.set() + + try: + unexpected_conn = await waiter + except CancelledError: + pass + else: + await p.putconn(unexpected_conn) + pytest.fail("cancelled waiter returned a connection instead of raising") + + await gather(putter) + + stats = p.get_stats() + assert stats["pool_available"] == 1 + assert stats.get("requests_waiting", 0) == 0 + assert stats["requests_errors"] == 1 + + reclaimed_conn = await p.getconn() + try: + assert reclaimed_conn.info.backend_pid == held_pid + cur = await reclaimed_conn.execute("select 1") + assert await cur.fetchone() == (1,) + finally: + await p.putconn(reclaimed_conn) + + @pytest.mark.slow @pytest.mark.timing async def test_check_backoff(dsn, caplog, monkeypatch):