From: Mike Bayer Date: Fri, 26 Jun 2026 14:09:38 +0000 (-0400) Subject: wrap before/after_cursor_execute event hooks in error handling X-Git-Tag: rel_2_1_0b3~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=295013ba17895cb925bef042b9ad70c80f1d19b0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git wrap before/after_cursor_execute event hooks in error handling Expanded try/except error handling in _exec_single_context(), _exec_insertmany_context(), and _cursor_execute() to encompass the before_cursor_execute and after_cursor_execute event hooks. This ensures that exceptions raised within these hooks, including BaseException subclasses such as asyncio.CancelledError, are properly handled via _handle_dbapi_exception(), providing correct connection invalidation and pool notification. Also added a guard in _handle_dbapi_exception to avoid double-wrapping exceptions that are already StatementError instances, which could occur when _cursor_execute's error handling propagates up through _execute_context. As part of this change, DBAPI errors raised from within these event hooks will now be wrapped as SQLAlchemy exceptions. Fixes: #13381 Change-Id: I6df29406f1eed9c318d02b00b999408c1e83535d --- diff --git a/doc/build/changelog/unreleased_21/13381.rst b/doc/build/changelog/unreleased_21/13381.rst new file mode 100644 index 0000000000..d1f5fc16cd --- /dev/null +++ b/doc/build/changelog/unreleased_21/13381.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, engine + :tickets: 13381 + + Expanded try/except error handling to encompass the + :meth:`_events.ConnectionEvents.before_cursor_execute` and + :meth:`_events.ConnectionEvents.after_cursor_execute` event hooks, so that + exceptions raised within these hooks, including ``BaseException`` + subclasses such as ``asyncio.CancelledError``, are properly handled via the + error handling path used for DBAPI errors. This ensures proper connection + invalidation and pool notification when exit-type exceptions are raised in + event hooks. As part of this change, DBAPI errors raised from within these + event hooks will now be wrapped as SQLAlchemy exceptions. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ba69b471d2..847411fa91 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1860,40 +1860,40 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): else: effective_parameters = parameters - if self._has_events or self.engine._has_events: - for fn in self.dispatch.before_cursor_execute: - str_statement, effective_parameters = fn( - self, - cursor, - str_statement, - effective_parameters, - context, - context.executemany, - ) - - if self._echo: - self._log_info(str_statement) + evt_handled: bool = False + try: + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + str_statement, effective_parameters = fn( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) - stats = context._get_cache_stats() + if self._echo: + self._log_info(str_statement) - if not self.engine.hide_parameters: - self._log_info( - "[%s] %r", - stats, - sql_util._repr_params( - effective_parameters, - batches=10, - ismulti=context.executemany, - ), - ) - else: - self._log_info( - "[%s] [SQL parameters hidden due to hide_parameters=True]", - stats, - ) + stats = context._get_cache_stats() - evt_handled: bool = False - try: + if not self.engine.hide_parameters: + self._log_info( + "[%s] %r", + stats, + sql_util._repr_params( + effective_parameters, + batches=10, + ismulti=context.executemany, + ), + ) + else: + self._log_info( + "[%s] [SQL parameters hidden due to " + "hide_parameters=True]", + stats, + ) if context.execute_style is ExecuteStyle.EXECUTEMANY: effective_parameters = cast( "_CoreMultiExecuteParams", effective_parameters @@ -2034,54 +2034,54 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): sub_stmt = imv_batch.replaced_statement sub_params = imv_batch.replaced_parameters - if engine_events: - for fn in self.dispatch.before_cursor_execute: - sub_stmt, sub_params = fn( - self, - cursor, - sub_stmt, - sub_params, - context, - True, - ) + try: + if engine_events: + for fn in self.dispatch.before_cursor_execute: + sub_stmt, sub_params = fn( + self, + cursor, + sub_stmt, + sub_params, + context, + True, + ) - if self._echo: - self._log_info(sql_util._long_statement(sub_stmt)) - - imv_stats = f""" {imv_batch.batchnum}/{ - imv_batch.total_batches - } ({ - 'ordered' - if imv_batch.rows_sorted else 'unordered' - }{ - '; batch not supported' - if imv_batch.is_downgraded - else '' - })""" - - if imv_batch.batchnum == 1: - stats += imv_stats - else: - stats = f"insertmanyvalues{imv_stats}" + if self._echo: + self._log_info(sql_util._long_statement(sub_stmt)) + + imv_stats = f""" {imv_batch.batchnum}/{ + imv_batch.total_batches + } ({ + 'ordered' + if imv_batch.rows_sorted else 'unordered' + }{ + '; batch not supported' + if imv_batch.is_downgraded + else '' + })""" + + if imv_batch.batchnum == 1: + stats += imv_stats + else: + stats = f"insertmanyvalues{imv_stats}" - if not self.engine.hide_parameters: - self._log_info( - "[%s] %r", - stats, - sql_util._repr_params( - sub_params, - batches=10, - ismulti=False, - ), - ) - else: - self._log_info( - "[%s] [SQL parameters hidden due to " - "hide_parameters=True]", - stats, - ) + if not self.engine.hide_parameters: + self._log_info( + "[%s] %r", + stats, + sql_util._repr_params( + sub_params, + batches=10, + ismulti=False, + ), + ) + else: + self._log_info( + "[%s] [SQL parameters hidden due to " + "hide_parameters=True]", + stats, + ) - try: for fn in do_execute_dispatch: if fn( cursor, @@ -2097,6 +2097,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): sub_params, context, ) + + if engine_events: + self.dispatch.after_cursor_execute( + self, + cursor, + sub_stmt, + sub_params, + context, + context.executemany, + ) except BaseException as e: self._handle_dbapi_exception( e, @@ -2107,16 +2117,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): is_sub_exec=True, ) - if engine_events: - self.dispatch.after_cursor_execute( - self, - cursor, - sub_stmt, - sub_params, - context, - context.executemany, - ) - if preserve_rowcount: rowcount += imv_batch.current_batch_size @@ -2152,16 +2152,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): terminates at _execute_context(). """ - if self._has_events or self.engine._has_events: - for fn in self.dispatch.before_cursor_execute: - statement, parameters = fn( - self, cursor, statement, parameters, context, False - ) - - if self._echo: - self._log_info(statement) - self._log_info("[raw sql] %r", parameters) try: + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) + + if self._echo: + self._log_info(statement) + self._log_info("[raw sql] %r", parameters) + for fn in ( () if not self.dialect._has_events @@ -2171,16 +2172,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): break else: self.dialect.do_execute(cursor, statement, parameters, context) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) except BaseException as e: self._handle_dbapi_exception( e, statement, parameters, cursor, context ) - if self._has_events or self.engine._has_events: - self.dispatch.after_cursor_execute( - self, cursor, statement, parameters, context, False - ) - def _safe_close_cursor(self, cursor: DBAPICursor) -> None: """Close the given cursor, catching exceptions and turning into log warnings. @@ -2243,7 +2244,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): # non-DBAPI error - if we already got a context, # or there's no string statement, don't wrap it should_wrap = isinstance(e, self.dialect.loaded_dbapi.Error) or ( - statement is not None + not isinstance(e, exc.StatementError) + and statement is not None and context is None and not is_exit_exception ) diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index d3e7b5b796..240fb0080c 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -34,6 +34,7 @@ from sqlalchemy.connectors.asyncio import AsyncAdapt_dbapi_module from sqlalchemy.engine import BindTyping from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Engine +from sqlalchemy.engine.interfaces import ExecuteStyle from sqlalchemy.pool import AsyncAdaptedQueuePool from sqlalchemy.pool import NullPool from sqlalchemy.pool import QueuePool @@ -3363,6 +3364,223 @@ class HandleErrorTest(fixtures.TestBase): conn.close() +class CursorEventErrorTest(fixtures.RemovesEvents, fixtures.TestBase): + """tests for #13381""" + + __sparse_driver_backend__ = True + + @testing.fixture + def imv_table(self, metadata): + t = Table( + "t_imv", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + t.create(testing.db) + return t + + @testing.fixture + def cursor_execute_table(self, metadata): + t = Table( + "t_ce", + metadata, + Column( + "x", + Integer, + normalize_sequence(config, Sequence("t_ce_id_seq")), + primary_key=True, + ), + implicit_returning=False, + ) + metadata.create_all(testing.db) + return t + + @testing.combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + def test_dbapi_error_exec_single(self, event_name, connection): + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + raise connection.dialect.dbapi.OperationalError("error in event") + + self.event_listen(connection, event_name, handler) + + with expect_raises_message( + tsa.exc.OperationalError, + "error in event", + ): + connection.exec_driver_sql("select 1") + + @testing.combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + def test_base_exception_invalidates_exec_single(self, event_name): + with testing.db.connect() as conn: + + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + raise BaseException("exit-like error") + + self.event_listen(conn, event_name, handler) + + with expect_raises_message(BaseException, "exit-like error"): + conn.exec_driver_sql("select 1") + + is_true(conn.invalidated) + + @testing.requires.insertmanyvalues + @testing.combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + def test_dbapi_error_insertmanyvalues( + self, event_name, imv_table, connection + ): + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + if context.execute_style is ExecuteStyle.INSERTMANYVALUES: + raise connection.dialect.dbapi.OperationalError( + "error in event" + ) + + self.event_listen(connection, event_name, handler) + + with expect_raises_message( + tsa.exc.OperationalError, + "error in event", + ): + connection.execute( + imv_table.insert().returning( + imv_table.c.id, + sort_by_parameter_order=True, + ), + [{"data": f"d{i}"} for i in range(10)], + ) + + @testing.requires.insertmanyvalues + @testing.combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + def test_base_exception_invalidates_insertmanyvalues( + self, event_name, imv_table + ): + with testing.db.connect() as conn: + + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + if context.execute_style is ExecuteStyle.INSERTMANYVALUES: + raise BaseException("exit-like error") + + self.event_listen(conn, event_name, handler) + + with expect_raises_message(BaseException, "exit-like error"): + conn.execute( + imv_table.insert().returning( + imv_table.c.id, + sort_by_parameter_order=True, + ), + [{"data": f"d{i}"} for i in range(10)], + ) + + is_true(conn.invalidated) + + # release lingering cursor refs so sqlite file lock clears + gc_collect() + + @testing.requires.sequences + @testing.combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + def test_dbapi_error_cursor_execute( + self, event_name, cursor_execute_table, connection + ): + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + if "t_ce_id_seq" in str(statement): + raise connection.dialect.dbapi.OperationalError( + "error in event" + ) + + self.event_listen(connection, event_name, handler) + + with expect_raises_message( + tsa.exc.OperationalError, + "error in event", + ): + connection.execute(cursor_execute_table.insert()) + + @testing.requires.sequences + @testing.combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + def test_base_exception_invalidates_cursor_execute( + self, event_name, cursor_execute_table + ): + with testing.db.connect() as conn: + + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + if "t_ce_id_seq" in str(statement): + raise BaseException("exit-like error") + + self.event_listen(conn, event_name, handler) + + with expect_raises_message(BaseException, "exit-like error"): + conn.execute(cursor_execute_table.insert()) + + is_true(conn.invalidated) + + gc_collect() + + class OnConnectTest(fixtures.TestBase): __requires__ = ("sqlite",) diff --git a/test/ext/asyncio/test_engine.py b/test/ext/asyncio/test_engine.py index 4082349ef7..91f1302b17 100644 --- a/test/ext/asyncio/test_engine.py +++ b/test/ext/asyncio/test_engine.py @@ -1082,6 +1082,42 @@ class AsyncEventTest(EngineFixture): ) +class AsyncCursorEventCancelledErrorTest(fixtures.TestBase): + """tests for #13381""" + + __requires__ = ("async_dialect",) + __backend__ = True + + @testing.fixture + def async_engine(self): + return engines.testing_engine(asyncio=True) + + @combinations( + "before_cursor_execute", + "after_cursor_execute", + argnames="event_name", + ) + @async_test + async def test_cancelled_error_invalidates(self, async_engine, event_name): + @event.listens_for(async_engine.sync_engine, event_name) + def handler( + conn, + cursor, + statement, + parameters, + context, + executemany, + ): + raise asyncio.CancelledError() + + conn = await async_engine.connect() + with expect_raises(asyncio.CancelledError): + await conn.execute(select(1)) + + is_true(conn.invalidated) + await conn.close() + + class AsyncInspection(EngineFixture): __backend__ = True