]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): trap CancelledError more consistently in the pool codebase
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 21 Nov 2025 10:36:39 +0000 (11:36 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 21 Nov 2025 16:03:01 +0000 (17:03 +0100)
Include also places that were left out such as the rollback and the task
scheduling.

Note that we are relaxing the exception handler we had set up to fix the
problem with cancelled clients on wait (#509): we only had to trap
CancelledError additionally but we started managing the whole
BaseException. I don't think that trapping KeyboardException or
SystemExit without re-raising is a good idea (I think that, for
robustness, we should, but then things become very verbose and not
necessarily correct).

psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
psycopg_pool/psycopg_pool/sched.py
psycopg_pool/psycopg_pool/sched_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py

index 821a2734bdc1b5dae88ca438c7b7a15d97b03dc4..a5f062fdaf384bde451e2d3cf6f8a47c58f9357e 100644 (file)
@@ -15,7 +15,6 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
 from typing import Any, Generic, cast
-from asyncio import CancelledError
 from weakref import ref
 from contextlib import contextmanager
 from collections import deque
@@ -33,6 +32,8 @@ from ._compat import PSYCOPG_VERSION, PoolConnection, Self
 from ._acompat import Condition, Event, Lock, Queue, Worker, current_thread_name
 from ._acompat import gather, sleep, spawn
 
+CLIENT_EXCEPTIONS = Exception
+
 logger = logging.getLogger("psycopg.pool")
 
 
@@ -221,7 +222,7 @@ class ConnectionPool(Generic[CT], BasePool):
             conn = self._getconn_unchecked(deadline - monotonic())
             try:
                 self._check_connection(conn)
-            except (Exception, CancelledError):
+            except CLIENT_EXCEPTIONS:
                 self._putconn(conn, from_getconn=True)
             else:
                 logger.info("connection given by %r", self.name)
@@ -259,7 +260,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if not conn:
             try:
                 conn = pos.wait(timeout=timeout)
-            except BaseException:
+            except CLIENT_EXCEPTIONS:
                 self._stats[self._REQUESTS_ERRORS] += 1
                 raise
             finally:
@@ -295,7 +296,7 @@ class ConnectionPool(Generic[CT], BasePool):
             return
         try:
             self._check(conn)
-        except BaseException as e:
+        except CLIENT_EXCEPTIONS as e:
             logger.info("connection failed check: %s", e)
             raise
 
@@ -527,7 +528,7 @@ class ConnectionPool(Generic[CT], BasePool):
             # Check for broken connections
             try:
                 self.check_connection(conn)
-            except (Exception, CancelledError):
+            except CLIENT_EXCEPTIONS:
                 self._stats[self._CONNECTIONS_LOST] += 1
                 logger.warning("discarding broken connection: %s", conn)
                 self.run_task(AddConnection(self))
@@ -588,7 +589,7 @@ class ConnectionPool(Generic[CT], BasePool):
             # Run the task. Make sure don't die in the attempt.
             try:
                 task.run()
-            except (Exception, CancelledError) as ex:
+            except CLIENT_EXCEPTIONS as ex:
                 logger.warning(
                     "task run %s failed: %s: %s", task, ex.__class__.__name__, ex
                 )
@@ -604,7 +605,7 @@ class ConnectionPool(Generic[CT], BasePool):
         t0 = monotonic()
         try:
             conn = self.connection_class.connect(conninfo, **kwargs)
-        except (Exception, CancelledError):
+        except CLIENT_EXCEPTIONS:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
         else:
@@ -658,7 +659,7 @@ class ConnectionPool(Generic[CT], BasePool):
 
         try:
             conn = self._connect()
-        except (Exception, CancelledError) as ex:
+        except CLIENT_EXCEPTIONS as ex:
             logger.warning("error connecting in %r: %s", self.name, ex)
             if attempt.time_to_give_up(now):
                 logger.warning(
@@ -783,7 +784,7 @@ class ConnectionPool(Generic[CT], BasePool):
             logger.warning("rolling back returned connection: %s", conn)
             try:
                 conn.rollback()
-            except Exception as ex:
+            except CLIENT_EXCEPTIONS as ex:
                 logger.warning(
                     "rollback failed: %s: %s. Discarding connection %s",
                     ex.__class__.__name__,
@@ -804,7 +805,7 @@ class ConnectionPool(Generic[CT], BasePool):
                     raise e.ProgrammingError(
                         f"connection left in status {sname} by reset function {self._reset}: discarded"
                     )
-            except (Exception, CancelledError) as ex:
+            except CLIENT_EXCEPTIONS as ex:
                 logger.warning("error resetting connection: %s", ex)
                 self._close_connection(conn)
 
@@ -870,7 +871,7 @@ class WaitingClient(Generic[CT]):
                         self.error = PoolTimeout(
                             f"couldn't get a connection after {timeout:.2f} sec"
                         )
-                except BaseException as ex:
+                except CLIENT_EXCEPTIONS as ex:
                     self.error = ex
 
         if self.conn:
index 1c26ff81bb2940afc44c059d8f8df019238b11f3..6ea1b3c8850533635d21bfb52be4136250d55d74 100644 (file)
@@ -12,7 +12,6 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
 from typing import Any, Generic, cast
-from asyncio import CancelledError
 from weakref import ref
 from contextlib import asynccontextmanager
 from collections import deque
@@ -34,6 +33,13 @@ from .sched_async import AsyncScheduler
 if True:  # ASYNC
     import asyncio
 
+    # The exceptions that we need to capture in order to keep the pool
+    # consistent and avoid losing connections on errors in callers code.
+    CLIENT_EXCEPTIONS = (Exception, asyncio.CancelledError)
+else:
+    CLIENT_EXCEPTIONS = Exception
+
+
 logger = logging.getLogger("psycopg.pool")
 
 
@@ -253,7 +259,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             conn = await self._getconn_unchecked(deadline - monotonic())
             try:
                 await self._check_connection(conn)
-            except (Exception, CancelledError):
+            except CLIENT_EXCEPTIONS:
                 await self._putconn(conn, from_getconn=True)
             else:
                 logger.info("connection given by %r", self.name)
@@ -291,7 +297,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if not conn:
             try:
                 conn = await pos.wait(timeout=timeout)
-            except BaseException:
+            except CLIENT_EXCEPTIONS:
                 self._stats[self._REQUESTS_ERRORS] += 1
                 raise
             finally:
@@ -328,7 +334,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             return
         try:
             await self._check(conn)
-        except BaseException as e:
+        except CLIENT_EXCEPTIONS as e:
             logger.info("connection failed check: %s", e)
             raise
 
@@ -567,7 +573,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             # Check for broken connections
             try:
                 await self.check_connection(conn)
-            except (Exception, CancelledError):
+            except CLIENT_EXCEPTIONS:
                 self._stats[self._CONNECTIONS_LOST] += 1
                 logger.warning("discarding broken connection: %s", conn)
                 self.run_task(AddConnection(self))
@@ -638,7 +644,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             # Run the task. Make sure don't die in the attempt.
             try:
                 await task.run()
-            except (Exception, CancelledError) as ex:
+            except CLIENT_EXCEPTIONS as ex:
                 logger.warning(
                     "task run %s failed: %s: %s", task, ex.__class__.__name__, ex
                 )
@@ -654,7 +660,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         t0 = monotonic()
         try:
             conn = await self.connection_class.connect(conninfo, **kwargs)
-        except (Exception, CancelledError):
+        except CLIENT_EXCEPTIONS:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
         else:
@@ -709,7 +715,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         try:
             conn = await self._connect()
-        except (Exception, CancelledError) as ex:
+        except CLIENT_EXCEPTIONS as ex:
             logger.warning("error connecting in %r: %s", self.name, ex)
             if attempt.time_to_give_up(now):
                 logger.warning(
@@ -838,7 +844,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             logger.warning("rolling back returned connection: %s", conn)
             try:
                 await conn.rollback()
-            except Exception as ex:
+            except CLIENT_EXCEPTIONS as ex:
                 logger.warning(
                     "rollback failed: %s: %s. Discarding connection %s",
                     ex.__class__.__name__,
@@ -861,7 +867,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                         f"connection left in status {sname} by reset function"
                         f" {self._reset}: discarded"
                     )
-            except (Exception, CancelledError) as ex:
+            except CLIENT_EXCEPTIONS as ex:
                 logger.warning("error resetting connection: %s", ex)
                 await self._close_connection(conn)
 
@@ -928,7 +934,7 @@ class WaitingClient(Generic[ACT]):
                         self.error = PoolTimeout(
                             f"couldn't get a connection after {timeout:.2f} sec"
                         )
-                except BaseException as ex:
+                except CLIENT_EXCEPTIONS as ex:
                     self.error = ex
 
         if self.conn:
index 6f2ece88b6841b359cfd504f7a8173118eab92ef..175843756415d8a61564af0c5e1be17fcc1476e6 100644 (file)
@@ -28,6 +28,8 @@ from ._acompat import Event, Lock
 
 logger = logging.getLogger(__name__)
 
+CLIENT_EXCEPTIONS = Exception
+
 
 class Scheduler:
 
@@ -82,7 +84,7 @@ class Scheduler:
                     break
                 try:
                     task.action()
-                except Exception as e:
+                except CLIENT_EXCEPTIONS as e:
                     logger.warning(
                         "scheduled task run %s failed: %s: %s",
                         task.action,
index 11298e50b5dfd2d147c8e384c95253a462e59744..6046d25d9be8773732d0ccb6e85bb85f678daeee 100644 (file)
@@ -25,6 +25,15 @@ from ._acompat import AEvent, ALock
 
 logger = logging.getLogger(__name__)
 
+if True:  # ASYNC
+    from asyncio import CancelledError
+
+    # The exceptions that we need to capture in order to keep the pool
+    # consistent and avoid losing connections on errors in callers code.
+    CLIENT_EXCEPTIONS = (Exception, CancelledError)
+else:
+    CLIENT_EXCEPTIONS = Exception
+
 
 class AsyncScheduler:
     def __init__(self) -> None:
@@ -78,7 +87,7 @@ class AsyncScheduler:
                     break
                 try:
                     await task.action()
-                except Exception as e:
+                except CLIENT_EXCEPTIONS as e:
                     logger.warning(
                         "scheduled task run %s failed: %s: %s",
                         task.action,
index 19bb2fe7a9d2ae3ffac4ce6fb190f5c8e8da0b89..259e7d3ac6dc76ad7e23d0696b2aa5afe1db5f96 100644 (file)
@@ -719,6 +719,33 @@ def test_cancel_on_check(pool_cls, dsn):
             conn.execute("select 1")
 
 
+@skip_sync
+def test_cancel_on_rollback(pool_cls, dsn, monkeypatch):
+    do_cancel = False
+
+    with pool_cls(dsn, min_size=min_size(pool_cls, 1), timeout=1.0) as p:
+        with p.connection() as conn:
+
+            def rollback(self):
+                if do_cancel:
+                    raise CancelledError()
+                else:
+                    type(self).rollback(self)
+
+            monkeypatch.setattr(type(conn), "rollback", rollback)
+            conn.execute("select 1")
+
+        do_cancel = True
+        with pytest.raises((psycopg.errors.SyntaxError, CancelledError)):
+            with p.connection() as conn:
+                conn.execute("selexx 2")
+
+        do_cancel = False
+        with p.connection() as conn:
+            cur = conn.execute("select 3")
+            assert cur.fetchone() == (3,)
+
+
 def min_size(pool_cls, num=1):
     """Return the minimum min_size supported by the pool class."""
     if pool_cls is pool.ConnectionPool:
index eac43973e64f8fde9be804ef869df3a791a45bb9..3dc01a238010e30e1bad1c69ffb232224062ebd8 100644 (file)
@@ -732,6 +732,33 @@ async def test_cancel_on_check(pool_cls, dsn):
             await conn.execute("select 1")
 
 
+@skip_sync
+async def test_cancel_on_rollback(pool_cls, dsn, monkeypatch):
+    do_cancel = False
+
+    async with pool_cls(dsn, min_size=min_size(pool_cls, 1), timeout=1.0) as p:
+        async with p.connection() as conn:
+
+            async def rollback(self):
+                if do_cancel:
+                    raise CancelledError()
+                else:
+                    await type(self).rollback(self)
+
+            monkeypatch.setattr(type(conn), "rollback", rollback)
+            await conn.execute("select 1")
+
+        do_cancel = True
+        with pytest.raises((psycopg.errors.SyntaxError, CancelledError)):
+            async with p.connection() as conn:
+                await conn.execute("selexx 2")
+
+        do_cancel = False
+        async with p.connection() as conn:
+            cur = await conn.execute("select 3")
+            assert (await cur.fetchone()) == (3,)
+
+
 def min_size(pool_cls, num=1):
     """Return the minimum min_size supported by the pool class."""
     if pool_cls is pool.AsyncConnectionPool: