]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): manage CancelledError in some exception handling path
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 19 Nov 2025 23:51:43 +0000 (00:51 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 21 Nov 2025 16:01:54 +0000 (17:01 +0100)
If a CancelledError was raised during check the connection would have
been lost. The exception would have bubbled up but likely users are
using some framework swallowing it because nobody reporting the "lost
connections" issue actually reported the CancelledError.

Close #1123
Close #1208

docs/news_pool.rst
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py

index dd65a2d98878dadffeca19fa1778b96c505f83e3..18ac1cdadd61a644b2cddb821b8d59db7dd6a1ac 100644 (file)
@@ -19,6 +19,13 @@ psycopg_pool 3.3.0 (unreleased)
   parameters# update (:ticket:`#851`).
 
 
+psycopg_pool 3.2.8 (unreleased)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Don't lose connections if a `~asyncio.CancelledError` is raised in a check
+  (:tickets:`#1123, #1208`)
+
+
 Current release
 ---------------
 
index e4a58c8d9c87516b3449dc7b5232012bd9b31ce3..821a2734bdc1b5dae88ca438c7b7a15d97b03dc4 100644 (file)
@@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
 from typing import Any, Generic, cast
+from asyncio import CancelledError
 from weakref import ref
 from contextlib import contextmanager
 from collections import deque
@@ -220,7 +221,7 @@ class ConnectionPool(Generic[CT], BasePool):
             conn = self._getconn_unchecked(deadline - monotonic())
             try:
                 self._check_connection(conn)
-            except Exception:
+            except (Exception, CancelledError):
                 self._putconn(conn, from_getconn=True)
             else:
                 logger.info("connection given by %r", self.name)
@@ -258,7 +259,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if not conn:
             try:
                 conn = pos.wait(timeout=timeout)
-            except Exception:
+            except BaseException:
                 self._stats[self._REQUESTS_ERRORS] += 1
                 raise
             finally:
@@ -294,7 +295,7 @@ class ConnectionPool(Generic[CT], BasePool):
             return
         try:
             self._check(conn)
-        except Exception as e:
+        except BaseException as e:
             logger.info("connection failed check: %s", e)
             raise
 
@@ -526,7 +527,7 @@ class ConnectionPool(Generic[CT], BasePool):
             # Check for broken connections
             try:
                 self.check_connection(conn)
-            except Exception:
+            except (Exception, CancelledError):
                 self._stats[self._CONNECTIONS_LOST] += 1
                 logger.warning("discarding broken connection: %s", conn)
                 self.run_task(AddConnection(self))
@@ -587,7 +588,7 @@ class ConnectionPool(Generic[CT], BasePool):
             # Run the task. Make sure don't die in the attempt.
             try:
                 task.run()
-            except Exception as ex:
+            except (Exception, CancelledError) as ex:
                 logger.warning(
                     "task run %s failed: %s: %s", task, ex.__class__.__name__, ex
                 )
@@ -603,7 +604,7 @@ class ConnectionPool(Generic[CT], BasePool):
         t0 = monotonic()
         try:
             conn = self.connection_class.connect(conninfo, **kwargs)
-        except Exception:
+        except (Exception, CancelledError):
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
         else:
@@ -657,7 +658,7 @@ class ConnectionPool(Generic[CT], BasePool):
 
         try:
             conn = self._connect()
-        except Exception as ex:
+        except (Exception, CancelledError) as ex:
             logger.warning("error connecting in %r: %s", self.name, ex)
             if attempt.time_to_give_up(now):
                 logger.warning(
@@ -803,7 +804,7 @@ class ConnectionPool(Generic[CT], BasePool):
                     raise e.ProgrammingError(
                         f"connection left in status {sname} by reset function {self._reset}: discarded"
                     )
-            except Exception as ex:
+            except (Exception, CancelledError) as ex:
                 logger.warning("error resetting connection: %s", ex)
                 self._close_connection(conn)
 
index b3242af7eca467592129411e8ca9bba278d61f99..1c26ff81bb2940afc44c059d8f8df019238b11f3 100644 (file)
@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
 from typing import Any, Generic, cast
+from asyncio import CancelledError
 from weakref import ref
 from contextlib import asynccontextmanager
 from collections import deque
@@ -252,7 +253,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             conn = await self._getconn_unchecked(deadline - monotonic())
             try:
                 await self._check_connection(conn)
-            except Exception:
+            except (Exception, CancelledError):
                 await self._putconn(conn, from_getconn=True)
             else:
                 logger.info("connection given by %r", self.name)
@@ -290,7 +291,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if not conn:
             try:
                 conn = await pos.wait(timeout=timeout)
-            except Exception:
+            except BaseException:
                 self._stats[self._REQUESTS_ERRORS] += 1
                 raise
             finally:
@@ -327,7 +328,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             return
         try:
             await self._check(conn)
-        except Exception as e:
+        except BaseException as e:
             logger.info("connection failed check: %s", e)
             raise
 
@@ -566,7 +567,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             # Check for broken connections
             try:
                 await self.check_connection(conn)
-            except Exception:
+            except (Exception, CancelledError):
                 self._stats[self._CONNECTIONS_LOST] += 1
                 logger.warning("discarding broken connection: %s", conn)
                 self.run_task(AddConnection(self))
@@ -637,7 +638,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             # Run the task. Make sure don't die in the attempt.
             try:
                 await task.run()
-            except Exception as ex:
+            except (Exception, CancelledError) as ex:
                 logger.warning(
                     "task run %s failed: %s: %s", task, ex.__class__.__name__, ex
                 )
@@ -653,7 +654,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         t0 = monotonic()
         try:
             conn = await self.connection_class.connect(conninfo, **kwargs)
-        except Exception:
+        except (Exception, CancelledError):
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
         else:
@@ -708,7 +709,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         try:
             conn = await self._connect()
-        except Exception as ex:
+        except (Exception, CancelledError) as ex:
             logger.warning("error connecting in %r: %s", self.name, ex)
             if attempt.time_to_give_up(now):
                 logger.warning(
@@ -860,7 +861,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                         f"connection left in status {sname} by reset function"
                         f" {self._reset}: discarded"
                     )
-            except Exception as ex:
+            except (Exception, CancelledError) as ex:
                 logger.warning("error resetting connection: %s", ex)
                 await self._close_connection(conn)
 
index 6d1726a43e708cfc0bdff2c2a72d5b534c288cf2..19bb2fe7a9d2ae3ffac4ce6fb190f5c8e8da0b89 100644 (file)
@@ -6,6 +6,7 @@ from __future__ import annotations
 import logging
 from time import time
 from typing import Any
+from asyncio import CancelledError
 
 import pytest
 
@@ -695,6 +696,29 @@ def test_cancellation_in_queue(pool_cls, dsn):
             assert cur.fetchone() == (1,)
 
 
+@skip_sync
+def test_cancel_on_check(pool_cls, dsn):
+    do_cancel = True
+
+    def check(conn):
+        nonlocal do_cancel
+        if do_cancel:
+            do_cancel = False
+            raise CancelledError()
+
+        pool_cls.check_connection(conn)
+
+    with pool_cls(dsn, min_size=min_size(pool_cls, 1), check=check, timeout=1.0) as p:
+        try:
+            with p.connection() as conn:
+                conn.execute("select 1")
+        except CancelledError:
+            pass
+
+        with p.connection() as conn:
+            conn.execute("select 1")
+
+
 def min_size(pool_cls, num=1):
     """Return the minimum min_size supported by the pool class."""
     if pool_cls is pool.ConnectionPool:
index c61cb0ffbf355fb38c6eada4c9c394e86f729247..eac43973e64f8fde9be804ef869df3a791a45bb9 100644 (file)
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from time import time
 from typing import Any
+from asyncio import CancelledError
 
 import pytest
 
@@ -706,6 +707,31 @@ async def test_cancellation_in_queue(pool_cls, dsn):
             assert await cur.fetchone() == (1,)
 
 
+@skip_sync
+async def test_cancel_on_check(pool_cls, dsn):
+    do_cancel = True
+
+    async def check(conn):
+        nonlocal do_cancel
+        if do_cancel:
+            do_cancel = False
+            raise CancelledError()
+
+        await pool_cls.check_connection(conn)
+
+    async with pool_cls(
+        dsn, min_size=min_size(pool_cls, 1), check=check, timeout=1.0
+    ) as p:
+        try:
+            async with p.connection() as conn:
+                await conn.execute("select 1")
+        except CancelledError:
+            pass
+
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+
+
 def min_size(pool_cls, num=1):
     """Return the minimum min_size supported by the pool class."""
     if pool_cls is pool.AsyncConnectionPool: