]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix async pool cancellation handoff race 1277/head
authorIlia Ablamonov <ilia@flamefork.ru>
Tue, 10 Mar 2026 15:46:40 +0000 (16:46 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 19 Mar 2026 11:03:51 +0000 (12:03 +0100)
Fixes #1275

docs/news_pool.rst
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool_async.py

index 5b5fcea79680f7830a06f8cb9056c4468e3f6296..fc1e798b86dc471715c7fa4639ab435f696e1f58 100644 (file)
@@ -7,6 +7,16 @@
 ``psycopg_pool`` release notes
 ==============================
 
+Future releases
+---------------
+
+psycopg_pool 3.3.1 (unreleased)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Fix residual race condition catching `~asyncio.CancelledError` on connection
+  (:ticket:`#1275`).
+
+
 Current release
 ---------------
 
index 245907f67a05450978b99fe3981ea4e6024e6cd0..156f496cc37327241ff85c79b170af8cc041a0b0 100644 (file)
@@ -261,6 +261,8 @@ class ConnectionPool(Generic[CT], BasePool):
             try:
                 conn = pos.wait(timeout=timeout)
             except CLIENT_EXCEPTIONS:
+                if pos.conn:
+                    self.run_task(ReturnConnection(self, pos.conn, from_getconn=True))
                 self._stats[self._REQUESTS_ERRORS] += 1
                 raise
             finally:
@@ -894,11 +896,11 @@ class WaitingClient(Generic[CT]):
                 except CLIENT_EXCEPTIONS as ex:
                     self.error = ex
 
-        if self.conn:
-            return self.conn
-        else:
-            assert self.error
+        if self.error:
             raise self.error
+        else:
+            assert self.conn
+            return self.conn
 
     def set(self, conn: CT) -> bool:
         """Signal the client waiting that a connection is ready.
index 254e1367d17eb95c38a7f1c81f5ff97c5818b857..5ae43844cfbdfa1ef22e786dc702adc25eeac195 100644 (file)
@@ -298,6 +298,8 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             try:
                 conn = await pos.wait(timeout=timeout)
             except CLIENT_EXCEPTIONS:
+                if pos.conn:
+                    self.run_task(ReturnConnection(self, pos.conn, from_getconn=True))
                 self._stats[self._REQUESTS_ERRORS] += 1
                 raise
             finally:
@@ -957,11 +959,11 @@ class WaitingClient(Generic[ACT]):
                 except CLIENT_EXCEPTIONS as ex:
                     self.error = ex
 
-        if self.conn:
-            return self.conn
-        else:
-            assert self.error
+        if self.error:
             raise self.error
+        else:
+            assert self.conn
+            return self.conn
 
     async def set(self, conn: ACT) -> bool:
         """Signal the client waiting that a connection is ready.
index 82d89650b50c0ffbbe671cce51c8fb783c363afb..bd9915ccbd7d645f39b586d628341c6d1ca33eb7 100644 (file)
@@ -996,6 +996,69 @@ async def test_cancellation_in_queue(dsn):
             assert await cur.fetchone() == (1,)
 
 
+@skip_sync
+@pytest.mark.crdb_skip("backend pid")
+async def test_cancelled_waiter_assigned_conn_is_reclaimed(dsn, monkeypatch):
+    from asyncio import CancelledError
+
+    from psycopg_pool.pool_async import WaitingClient
+
+    from .test_pool_common_async import ensure_waiting
+
+    assigned = AEvent()
+    release = AEvent()
+
+    async def set_blocked(self, conn):
+        async with self._cond:
+            if self.conn or self.error:
+                return False
+
+            self.conn = conn
+            assigned.set()
+            await release.wait()
+            self._cond.notify_all()
+            return True
+
+    monkeypatch.setattr(WaitingClient, "set", set_blocked)
+
+    async with pool.AsyncConnectionPool(dsn, min_size=1, max_size=1, timeout=1) as p:
+        await p.wait()
+
+        held_conn = await p.getconn()
+        held_pid = held_conn.info.backend_pid
+        waiter = spawn(p.getconn)
+        await ensure_waiting(p)
+
+        putter = spawn(p.putconn, args=(held_conn,))
+        await assigned.wait()
+
+        waiter.cancel()
+        release.set()
+
+        try:
+            unexpected_conn = await waiter
+        except CancelledError:
+            pass
+        else:
+            await p.putconn(unexpected_conn)
+            pytest.fail("cancelled waiter returned a connection instead of raising")
+
+        await gather(putter)
+
+        stats = p.get_stats()
+        assert stats["pool_available"] == 1
+        assert stats.get("requests_waiting", 0) == 0
+        assert stats["requests_errors"] == 1
+
+        reclaimed_conn = await p.getconn()
+        try:
+            assert reclaimed_conn.info.backend_pid == held_pid
+            cur = await reclaimed_conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+        finally:
+            await p.putconn(reclaimed_conn)
+
+
 @pytest.mark.slow
 @pytest.mark.timing
 async def test_check_backoff(dsn, caplog, monkeypatch):