From: Daniele Varrazzo Date: Fri, 21 Nov 2025 10:36:39 +0000 (+0100) Subject: fix(pool): trap CancelledError more consistently in the pool codebase X-Git-Tag: 3.3.0~17^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cee9e574b7c85dbca7d3e16e460d13f8dda9f10a;p=thirdparty%2Fpsycopg.git fix(pool): trap CancelledError more consistently in the pool codebase Include also places that were left out such as the rollback and the task scheduling. Note that we are relaxing the exception handler we had set up to fix the problem with cancelled clients on wait (#509): we only had to trap CancelledError additionally but we started managing the whole BaseException. I don't think that trapping KeyboardException or SystemExit without re-raising is a good idea (I think that, for robustness, we should, but then things become very verbose and not necessarily correct). --- diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 821a2734b..a5f062fda 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod from time import monotonic from types import TracebackType from typing import Any, Generic, cast -from asyncio import CancelledError from weakref import ref from contextlib import contextmanager from collections import deque @@ -33,6 +32,8 @@ from ._compat import PSYCOPG_VERSION, PoolConnection, Self from ._acompat import Condition, Event, Lock, Queue, Worker, current_thread_name from ._acompat import gather, sleep, spawn +CLIENT_EXCEPTIONS = Exception + logger = logging.getLogger("psycopg.pool") @@ -221,7 +222,7 @@ class ConnectionPool(Generic[CT], BasePool): conn = self._getconn_unchecked(deadline - monotonic()) try: self._check_connection(conn) - except (Exception, CancelledError): + except CLIENT_EXCEPTIONS: self._putconn(conn, from_getconn=True) else: logger.info("connection given by %r", self.name) @@ -259,7 +260,7 @@ class ConnectionPool(Generic[CT], BasePool): if not conn: try: conn = pos.wait(timeout=timeout) - except BaseException: + except CLIENT_EXCEPTIONS: self._stats[self._REQUESTS_ERRORS] += 1 raise finally: @@ -295,7 +296,7 @@ class ConnectionPool(Generic[CT], BasePool): return try: self._check(conn) - except BaseException as e: + except CLIENT_EXCEPTIONS as e: logger.info("connection failed check: %s", e) raise @@ -527,7 +528,7 @@ class ConnectionPool(Generic[CT], BasePool): # Check for broken connections try: self.check_connection(conn) - except (Exception, CancelledError): + except CLIENT_EXCEPTIONS: self._stats[self._CONNECTIONS_LOST] += 1 logger.warning("discarding broken connection: %s", conn) self.run_task(AddConnection(self)) @@ -588,7 +589,7 @@ class ConnectionPool(Generic[CT], BasePool): # Run the task. Make sure don't die in the attempt. try: task.run() - except (Exception, CancelledError) as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning( "task run %s failed: %s: %s", task, ex.__class__.__name__, ex ) @@ -604,7 +605,7 @@ class ConnectionPool(Generic[CT], BasePool): t0 = monotonic() try: conn = self.connection_class.connect(conninfo, **kwargs) - except (Exception, CancelledError): + except CLIENT_EXCEPTIONS: self._stats[self._CONNECTIONS_ERRORS] += 1 raise else: @@ -658,7 +659,7 @@ class ConnectionPool(Generic[CT], BasePool): try: conn = self._connect() - except (Exception, CancelledError) as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning("error connecting in %r: %s", self.name, ex) if attempt.time_to_give_up(now): logger.warning( @@ -783,7 +784,7 @@ class ConnectionPool(Generic[CT], BasePool): logger.warning("rolling back returned connection: %s", conn) try: conn.rollback() - except Exception as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning( "rollback failed: %s: %s. Discarding connection %s", ex.__class__.__name__, @@ -804,7 +805,7 @@ class ConnectionPool(Generic[CT], BasePool): raise e.ProgrammingError( f"connection left in status {sname} by reset function {self._reset}: discarded" ) - except (Exception, CancelledError) as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning("error resetting connection: %s", ex) self._close_connection(conn) @@ -870,7 +871,7 @@ class WaitingClient(Generic[CT]): self.error = PoolTimeout( f"couldn't get a connection after {timeout:.2f} sec" ) - except BaseException as ex: + except CLIENT_EXCEPTIONS as ex: self.error = ex if self.conn: diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 1c26ff81b..6ea1b3c88 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -12,7 +12,6 @@ from abc import ABC, abstractmethod from time import monotonic from types import TracebackType from typing import Any, Generic, cast -from asyncio import CancelledError from weakref import ref from contextlib import asynccontextmanager from collections import deque @@ -34,6 +33,13 @@ from .sched_async import AsyncScheduler if True: # ASYNC import asyncio + # The exceptions that we need to capture in order to keep the pool + # consistent and avoid losing connections on errors in callers code. + CLIENT_EXCEPTIONS = (Exception, asyncio.CancelledError) +else: + CLIENT_EXCEPTIONS = Exception + + logger = logging.getLogger("psycopg.pool") @@ -253,7 +259,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): conn = await self._getconn_unchecked(deadline - monotonic()) try: await self._check_connection(conn) - except (Exception, CancelledError): + except CLIENT_EXCEPTIONS: await self._putconn(conn, from_getconn=True) else: logger.info("connection given by %r", self.name) @@ -291,7 +297,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): if not conn: try: conn = await pos.wait(timeout=timeout) - except BaseException: + except CLIENT_EXCEPTIONS: self._stats[self._REQUESTS_ERRORS] += 1 raise finally: @@ -328,7 +334,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): return try: await self._check(conn) - except BaseException as e: + except CLIENT_EXCEPTIONS as e: logger.info("connection failed check: %s", e) raise @@ -567,7 +573,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Check for broken connections try: await self.check_connection(conn) - except (Exception, CancelledError): + except CLIENT_EXCEPTIONS: self._stats[self._CONNECTIONS_LOST] += 1 logger.warning("discarding broken connection: %s", conn) self.run_task(AddConnection(self)) @@ -638,7 +644,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Run the task. Make sure don't die in the attempt. try: await task.run() - except (Exception, CancelledError) as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning( "task run %s failed: %s: %s", task, ex.__class__.__name__, ex ) @@ -654,7 +660,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): t0 = monotonic() try: conn = await self.connection_class.connect(conninfo, **kwargs) - except (Exception, CancelledError): + except CLIENT_EXCEPTIONS: self._stats[self._CONNECTIONS_ERRORS] += 1 raise else: @@ -709,7 +715,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): try: conn = await self._connect() - except (Exception, CancelledError) as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning("error connecting in %r: %s", self.name, ex) if attempt.time_to_give_up(now): logger.warning( @@ -838,7 +844,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): logger.warning("rolling back returned connection: %s", conn) try: await conn.rollback() - except Exception as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning( "rollback failed: %s: %s. Discarding connection %s", ex.__class__.__name__, @@ -861,7 +867,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): f"connection left in status {sname} by reset function" f" {self._reset}: discarded" ) - except (Exception, CancelledError) as ex: + except CLIENT_EXCEPTIONS as ex: logger.warning("error resetting connection: %s", ex) await self._close_connection(conn) @@ -928,7 +934,7 @@ class WaitingClient(Generic[ACT]): self.error = PoolTimeout( f"couldn't get a connection after {timeout:.2f} sec" ) - except BaseException as ex: + except CLIENT_EXCEPTIONS as ex: self.error = ex if self.conn: diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py index 6f2ece88b..175843756 100644 --- a/psycopg_pool/psycopg_pool/sched.py +++ b/psycopg_pool/psycopg_pool/sched.py @@ -28,6 +28,8 @@ from ._acompat import Event, Lock logger = logging.getLogger(__name__) +CLIENT_EXCEPTIONS = Exception + class Scheduler: @@ -82,7 +84,7 @@ class Scheduler: break try: task.action() - except Exception as e: + except CLIENT_EXCEPTIONS as e: logger.warning( "scheduled task run %s failed: %s: %s", task.action, diff --git a/psycopg_pool/psycopg_pool/sched_async.py b/psycopg_pool/psycopg_pool/sched_async.py index 11298e50b..6046d25d9 100644 --- a/psycopg_pool/psycopg_pool/sched_async.py +++ b/psycopg_pool/psycopg_pool/sched_async.py @@ -25,6 +25,15 @@ from ._acompat import AEvent, ALock logger = logging.getLogger(__name__) +if True: # ASYNC + from asyncio import CancelledError + + # The exceptions that we need to capture in order to keep the pool + # consistent and avoid losing connections on errors in callers code. + CLIENT_EXCEPTIONS = (Exception, CancelledError) +else: + CLIENT_EXCEPTIONS = Exception + class AsyncScheduler: def __init__(self) -> None: @@ -78,7 +87,7 @@ class AsyncScheduler: break try: await task.action() - except Exception as e: + except CLIENT_EXCEPTIONS as e: logger.warning( "scheduled task run %s failed: %s: %s", task.action, diff --git a/tests/pool/test_pool_common.py b/tests/pool/test_pool_common.py index 19bb2fe7a..259e7d3ac 100644 --- a/tests/pool/test_pool_common.py +++ b/tests/pool/test_pool_common.py @@ -719,6 +719,33 @@ def test_cancel_on_check(pool_cls, dsn): conn.execute("select 1") +@skip_sync +def test_cancel_on_rollback(pool_cls, dsn, monkeypatch): + do_cancel = False + + with pool_cls(dsn, min_size=min_size(pool_cls, 1), timeout=1.0) as p: + with p.connection() as conn: + + def rollback(self): + if do_cancel: + raise CancelledError() + else: + type(self).rollback(self) + + monkeypatch.setattr(type(conn), "rollback", rollback) + conn.execute("select 1") + + do_cancel = True + with pytest.raises((psycopg.errors.SyntaxError, CancelledError)): + with p.connection() as conn: + conn.execute("selexx 2") + + do_cancel = False + with p.connection() as conn: + cur = conn.execute("select 3") + assert cur.fetchone() == (3,) + + def min_size(pool_cls, num=1): """Return the minimum min_size supported by the pool class.""" if pool_cls is pool.ConnectionPool: diff --git a/tests/pool/test_pool_common_async.py b/tests/pool/test_pool_common_async.py index eac43973e..3dc01a238 100644 --- a/tests/pool/test_pool_common_async.py +++ b/tests/pool/test_pool_common_async.py @@ -732,6 +732,33 @@ async def test_cancel_on_check(pool_cls, dsn): await conn.execute("select 1") +@skip_sync +async def test_cancel_on_rollback(pool_cls, dsn, monkeypatch): + do_cancel = False + + async with pool_cls(dsn, min_size=min_size(pool_cls, 1), timeout=1.0) as p: + async with p.connection() as conn: + + async def rollback(self): + if do_cancel: + raise CancelledError() + else: + await type(self).rollback(self) + + monkeypatch.setattr(type(conn), "rollback", rollback) + await conn.execute("select 1") + + do_cancel = True + with pytest.raises((psycopg.errors.SyntaxError, CancelledError)): + async with p.connection() as conn: + await conn.execute("selexx 2") + + do_cancel = False + async with p.connection() as conn: + cur = await conn.execute("select 3") + assert (await cur.fetchone()) == (3,) + + def min_size(pool_cls, num=1): """Return the minimum min_size supported by the pool class.""" if pool_cls is pool.AsyncConnectionPool: