From: Daniele Varrazzo Date: Wed, 19 Nov 2025 23:51:43 +0000 (+0100) Subject: fix(pool): manage CancelledError in some exception handling path X-Git-Tag: 3.3.0~17^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e1c6920f3d624764155ddeffa5c1710778a54af6;p=thirdparty%2Fpsycopg.git fix(pool): manage CancelledError in some exception handling path If a CancelledError was raised during check the connection would have been lost. The exception would have bubbled up but likely users are using some framework swallowing it because nobody reporting the "lost connections" issue actually reported the CancelledError. Close #1123 Close #1208 --- diff --git a/docs/news_pool.rst b/docs/news_pool.rst index dd65a2d98..18ac1cdad 100644 --- a/docs/news_pool.rst +++ b/docs/news_pool.rst @@ -19,6 +19,13 @@ psycopg_pool 3.3.0 (unreleased) parameters# update (:ticket:`#851`). +psycopg_pool 3.2.8 (unreleased) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Don't lose connections if a `~asyncio.CancelledError` is raised in a check + (:tickets:`#1123, #1208`) + + Current release --------------- diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index e4a58c8d9..821a2734b 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from time import monotonic from types import TracebackType from typing import Any, Generic, cast +from asyncio import CancelledError from weakref import ref from contextlib import contextmanager from collections import deque @@ -220,7 +221,7 @@ class ConnectionPool(Generic[CT], BasePool): conn = self._getconn_unchecked(deadline - monotonic()) try: self._check_connection(conn) - except Exception: + except (Exception, CancelledError): self._putconn(conn, from_getconn=True) else: logger.info("connection given by %r", self.name) @@ -258,7 +259,7 @@ class ConnectionPool(Generic[CT], BasePool): if not conn: try: conn = pos.wait(timeout=timeout) - except Exception: + except BaseException: self._stats[self._REQUESTS_ERRORS] += 1 raise finally: @@ -294,7 +295,7 @@ class ConnectionPool(Generic[CT], BasePool): return try: self._check(conn) - except Exception as e: + except BaseException as e: logger.info("connection failed check: %s", e) raise @@ -526,7 +527,7 @@ class ConnectionPool(Generic[CT], BasePool): # Check for broken connections try: self.check_connection(conn) - except Exception: + except (Exception, CancelledError): self._stats[self._CONNECTIONS_LOST] += 1 logger.warning("discarding broken connection: %s", conn) self.run_task(AddConnection(self)) @@ -587,7 +588,7 @@ class ConnectionPool(Generic[CT], BasePool): # Run the task. Make sure don't die in the attempt. try: task.run() - except Exception as ex: + except (Exception, CancelledError) as ex: logger.warning( "task run %s failed: %s: %s", task, ex.__class__.__name__, ex ) @@ -603,7 +604,7 @@ class ConnectionPool(Generic[CT], BasePool): t0 = monotonic() try: conn = self.connection_class.connect(conninfo, **kwargs) - except Exception: + except (Exception, CancelledError): self._stats[self._CONNECTIONS_ERRORS] += 1 raise else: @@ -657,7 +658,7 @@ class ConnectionPool(Generic[CT], BasePool): try: conn = self._connect() - except Exception as ex: + except (Exception, CancelledError) as ex: logger.warning("error connecting in %r: %s", self.name, ex) if attempt.time_to_give_up(now): logger.warning( @@ -803,7 +804,7 @@ class ConnectionPool(Generic[CT], BasePool): raise e.ProgrammingError( f"connection left in status {sname} by reset function {self._reset}: discarded" ) - except Exception as ex: + except (Exception, CancelledError) as ex: logger.warning("error resetting connection: %s", ex) self._close_connection(conn) diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index b3242af7e..1c26ff81b 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -12,6 +12,7 @@ from abc import ABC, abstractmethod from time import monotonic from types import TracebackType from typing import Any, Generic, cast +from asyncio import CancelledError from weakref import ref from contextlib import asynccontextmanager from collections import deque @@ -252,7 +253,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): conn = await self._getconn_unchecked(deadline - monotonic()) try: await self._check_connection(conn) - except Exception: + except (Exception, CancelledError): await self._putconn(conn, from_getconn=True) else: logger.info("connection given by %r", self.name) @@ -290,7 +291,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): if not conn: try: conn = await pos.wait(timeout=timeout) - except Exception: + except BaseException: self._stats[self._REQUESTS_ERRORS] += 1 raise finally: @@ -327,7 +328,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): return try: await self._check(conn) - except Exception as e: + except BaseException as e: logger.info("connection failed check: %s", e) raise @@ -566,7 +567,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Check for broken connections try: await self.check_connection(conn) - except Exception: + except (Exception, CancelledError): self._stats[self._CONNECTIONS_LOST] += 1 logger.warning("discarding broken connection: %s", conn) self.run_task(AddConnection(self)) @@ -637,7 +638,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Run the task. Make sure don't die in the attempt. try: await task.run() - except Exception as ex: + except (Exception, CancelledError) as ex: logger.warning( "task run %s failed: %s: %s", task, ex.__class__.__name__, ex ) @@ -653,7 +654,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): t0 = monotonic() try: conn = await self.connection_class.connect(conninfo, **kwargs) - except Exception: + except (Exception, CancelledError): self._stats[self._CONNECTIONS_ERRORS] += 1 raise else: @@ -708,7 +709,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): try: conn = await self._connect() - except Exception as ex: + except (Exception, CancelledError) as ex: logger.warning("error connecting in %r: %s", self.name, ex) if attempt.time_to_give_up(now): logger.warning( @@ -860,7 +861,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): f"connection left in status {sname} by reset function" f" {self._reset}: discarded" ) - except Exception as ex: + except (Exception, CancelledError) as ex: logger.warning("error resetting connection: %s", ex) await self._close_connection(conn) diff --git a/tests/pool/test_pool_common.py b/tests/pool/test_pool_common.py index 6d1726a43..19bb2fe7a 100644 --- a/tests/pool/test_pool_common.py +++ b/tests/pool/test_pool_common.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging from time import time from typing import Any +from asyncio import CancelledError import pytest @@ -695,6 +696,29 @@ def test_cancellation_in_queue(pool_cls, dsn): assert cur.fetchone() == (1,) +@skip_sync +def test_cancel_on_check(pool_cls, dsn): + do_cancel = True + + def check(conn): + nonlocal do_cancel + if do_cancel: + do_cancel = False + raise CancelledError() + + pool_cls.check_connection(conn) + + with pool_cls(dsn, min_size=min_size(pool_cls, 1), check=check, timeout=1.0) as p: + try: + with p.connection() as conn: + conn.execute("select 1") + except CancelledError: + pass + + with p.connection() as conn: + conn.execute("select 1") + + def min_size(pool_cls, num=1): """Return the minimum min_size supported by the pool class.""" if pool_cls is pool.ConnectionPool: diff --git a/tests/pool/test_pool_common_async.py b/tests/pool/test_pool_common_async.py index c61cb0ffb..eac43973e 100644 --- a/tests/pool/test_pool_common_async.py +++ b/tests/pool/test_pool_common_async.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging from time import time from typing import Any +from asyncio import CancelledError import pytest @@ -706,6 +707,31 @@ async def test_cancellation_in_queue(pool_cls, dsn): assert await cur.fetchone() == (1,) +@skip_sync +async def test_cancel_on_check(pool_cls, dsn): + do_cancel = True + + async def check(conn): + nonlocal do_cancel + if do_cancel: + do_cancel = False + raise CancelledError() + + await pool_cls.check_connection(conn) + + async with pool_cls( + dsn, min_size=min_size(pool_cls, 1), check=check, timeout=1.0 + ) as p: + try: + async with p.connection() as conn: + await conn.execute("select 1") + except CancelledError: + pass + + async with p.connection() as conn: + await conn.execute("select 1") + + def min_size(pool_cls, num=1): """Return the minimum min_size supported by the pool class.""" if pool_cls is pool.AsyncConnectionPool: