]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
wrap before/after_cursor_execute event hooks in error handling
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Jun 2026 14:09:38 +0000 (10:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Jun 2026 14:58:00 +0000 (10:58 -0400)
Expanded try/except error handling in _exec_single_context(),
_exec_insertmany_context(), and _cursor_execute() to encompass the
before_cursor_execute and after_cursor_execute event hooks. This
ensures that exceptions raised within these hooks, including
BaseException subclasses such as asyncio.CancelledError, are
properly handled via _handle_dbapi_exception(), providing correct
connection invalidation and pool notification.

Also added a guard in _handle_dbapi_exception to avoid
double-wrapping exceptions that are already StatementError
instances, which could occur when _cursor_execute's error handling
propagates up through _execute_context.

As part of this change, DBAPI errors raised from within these event
hooks will now be wrapped as SQLAlchemy exceptions.

Fixes: #13381
Change-Id: I6df29406f1eed9c318d02b00b999408c1e83535d

doc/build/changelog/unreleased_21/13381.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
test/engine/test_execute.py
test/ext/asyncio/test_engine.py

diff --git a/doc/build/changelog/unreleased_21/13381.rst b/doc/build/changelog/unreleased_21/13381.rst
new file mode 100644 (file)
index 0000000..d1f5fc1
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 13381
+
+    Expanded try/except error handling to encompass the
+    :meth:`_events.ConnectionEvents.before_cursor_execute` and
+    :meth:`_events.ConnectionEvents.after_cursor_execute` event hooks, so that
+    exceptions raised within these hooks, including ``BaseException``
+    subclasses such as ``asyncio.CancelledError``, are properly handled via the
+    error handling path used for DBAPI errors. This ensures proper connection
+    invalidation and pool notification when exit-type exceptions are raised in
+    event hooks. As part of this change, DBAPI errors raised from within these
+    event hooks will now be wrapped as SQLAlchemy exceptions.
index ba69b471d2b8d863e1c6f4d21abba837c4a1c63b..847411fa91b0dae27f6d3f5b9c6349ece435d32f 100644 (file)
@@ -1860,40 +1860,40 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         else:
             effective_parameters = parameters
 
-        if self._has_events or self.engine._has_events:
-            for fn in self.dispatch.before_cursor_execute:
-                str_statement, effective_parameters = fn(
-                    self,
-                    cursor,
-                    str_statement,
-                    effective_parameters,
-                    context,
-                    context.executemany,
-                )
-
-        if self._echo:
-            self._log_info(str_statement)
+        evt_handled: bool = False
+        try:
+            if self._has_events or self.engine._has_events:
+                for fn in self.dispatch.before_cursor_execute:
+                    str_statement, effective_parameters = fn(
+                        self,
+                        cursor,
+                        str_statement,
+                        effective_parameters,
+                        context,
+                        context.executemany,
+                    )
 
-            stats = context._get_cache_stats()
+            if self._echo:
+                self._log_info(str_statement)
 
-            if not self.engine.hide_parameters:
-                self._log_info(
-                    "[%s] %r",
-                    stats,
-                    sql_util._repr_params(
-                        effective_parameters,
-                        batches=10,
-                        ismulti=context.executemany,
-                    ),
-                )
-            else:
-                self._log_info(
-                    "[%s] [SQL parameters hidden due to hide_parameters=True]",
-                    stats,
-                )
+                stats = context._get_cache_stats()
 
-        evt_handled: bool = False
-        try:
+                if not self.engine.hide_parameters:
+                    self._log_info(
+                        "[%s] %r",
+                        stats,
+                        sql_util._repr_params(
+                            effective_parameters,
+                            batches=10,
+                            ismulti=context.executemany,
+                        ),
+                    )
+                else:
+                    self._log_info(
+                        "[%s] [SQL parameters hidden due to "
+                        "hide_parameters=True]",
+                        stats,
+                    )
             if context.execute_style is ExecuteStyle.EXECUTEMANY:
                 effective_parameters = cast(
                     "_CoreMultiExecuteParams", effective_parameters
@@ -2034,54 +2034,54 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             sub_stmt = imv_batch.replaced_statement
             sub_params = imv_batch.replaced_parameters
 
-            if engine_events:
-                for fn in self.dispatch.before_cursor_execute:
-                    sub_stmt, sub_params = fn(
-                        self,
-                        cursor,
-                        sub_stmt,
-                        sub_params,
-                        context,
-                        True,
-                    )
+            try:
+                if engine_events:
+                    for fn in self.dispatch.before_cursor_execute:
+                        sub_stmt, sub_params = fn(
+                            self,
+                            cursor,
+                            sub_stmt,
+                            sub_params,
+                            context,
+                            True,
+                        )
 
-            if self._echo:
-                self._log_info(sql_util._long_statement(sub_stmt))
-
-                imv_stats = f""" {imv_batch.batchnum}/{
-                            imv_batch.total_batches
-                } ({
-                    'ordered'
-                    if imv_batch.rows_sorted else 'unordered'
-                }{
-                    '; batch not supported'
-                    if imv_batch.is_downgraded
-                    else ''
-                })"""
-
-                if imv_batch.batchnum == 1:
-                    stats += imv_stats
-                else:
-                    stats = f"insertmanyvalues{imv_stats}"
+                if self._echo:
+                    self._log_info(sql_util._long_statement(sub_stmt))
+
+                    imv_stats = f""" {imv_batch.batchnum}/{
+                                imv_batch.total_batches
+                    } ({
+                        'ordered'
+                        if imv_batch.rows_sorted else 'unordered'
+                    }{
+                        '; batch not supported'
+                        if imv_batch.is_downgraded
+                        else ''
+                    })"""
+
+                    if imv_batch.batchnum == 1:
+                        stats += imv_stats
+                    else:
+                        stats = f"insertmanyvalues{imv_stats}"
 
-                if not self.engine.hide_parameters:
-                    self._log_info(
-                        "[%s] %r",
-                        stats,
-                        sql_util._repr_params(
-                            sub_params,
-                            batches=10,
-                            ismulti=False,
-                        ),
-                    )
-                else:
-                    self._log_info(
-                        "[%s] [SQL parameters hidden due to "
-                        "hide_parameters=True]",
-                        stats,
-                    )
+                    if not self.engine.hide_parameters:
+                        self._log_info(
+                            "[%s] %r",
+                            stats,
+                            sql_util._repr_params(
+                                sub_params,
+                                batches=10,
+                                ismulti=False,
+                            ),
+                        )
+                    else:
+                        self._log_info(
+                            "[%s] [SQL parameters hidden due to "
+                            "hide_parameters=True]",
+                            stats,
+                        )
 
-            try:
                 for fn in do_execute_dispatch:
                     if fn(
                         cursor,
@@ -2097,6 +2097,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                         sub_params,
                         context,
                     )
+
+                if engine_events:
+                    self.dispatch.after_cursor_execute(
+                        self,
+                        cursor,
+                        sub_stmt,
+                        sub_params,
+                        context,
+                        context.executemany,
+                    )
             except BaseException as e:
                 self._handle_dbapi_exception(
                     e,
@@ -2107,16 +2117,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                     is_sub_exec=True,
                 )
 
-            if engine_events:
-                self.dispatch.after_cursor_execute(
-                    self,
-                    cursor,
-                    sub_stmt,
-                    sub_params,
-                    context,
-                    context.executemany,
-                )
-
             if preserve_rowcount:
                 rowcount += imv_batch.current_batch_size
 
@@ -2152,16 +2152,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         terminates at _execute_context().
 
         """
-        if self._has_events or self.engine._has_events:
-            for fn in self.dispatch.before_cursor_execute:
-                statement, parameters = fn(
-                    self, cursor, statement, parameters, context, False
-                )
-
-        if self._echo:
-            self._log_info(statement)
-            self._log_info("[raw sql] %r", parameters)
         try:
+            if self._has_events or self.engine._has_events:
+                for fn in self.dispatch.before_cursor_execute:
+                    statement, parameters = fn(
+                        self, cursor, statement, parameters, context, False
+                    )
+
+            if self._echo:
+                self._log_info(statement)
+                self._log_info("[raw sql] %r", parameters)
+
             for fn in (
                 ()
                 if not self.dialect._has_events
@@ -2171,16 +2172,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                     break
             else:
                 self.dialect.do_execute(cursor, statement, parameters, context)
+
+            if self._has_events or self.engine._has_events:
+                self.dispatch.after_cursor_execute(
+                    self, cursor, statement, parameters, context, False
+                )
         except BaseException as e:
             self._handle_dbapi_exception(
                 e, statement, parameters, cursor, context
             )
 
-        if self._has_events or self.engine._has_events:
-            self.dispatch.after_cursor_execute(
-                self, cursor, statement, parameters, context, False
-            )
-
     def _safe_close_cursor(self, cursor: DBAPICursor) -> None:
         """Close the given cursor, catching exceptions
         and turning into log warnings.
@@ -2243,7 +2244,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             # non-DBAPI error - if we already got a context,
             # or there's no string statement, don't wrap it
             should_wrap = isinstance(e, self.dialect.loaded_dbapi.Error) or (
-                statement is not None
+                not isinstance(e, exc.StatementError)
+                and statement is not None
                 and context is None
                 and not is_exit_exception
             )
index d3e7b5b79618cb2029a867e15bda497be763de96..240fb0080cb6471c5f1bb0a41fada8031df724e4 100644 (file)
@@ -34,6 +34,7 @@ from sqlalchemy.connectors.asyncio import AsyncAdapt_dbapi_module
 from sqlalchemy.engine import BindTyping
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.engine.base import Engine
+from sqlalchemy.engine.interfaces import ExecuteStyle
 from sqlalchemy.pool import AsyncAdaptedQueuePool
 from sqlalchemy.pool import NullPool
 from sqlalchemy.pool import QueuePool
@@ -3363,6 +3364,223 @@ class HandleErrorTest(fixtures.TestBase):
         conn.close()
 
 
+class CursorEventErrorTest(fixtures.RemovesEvents, fixtures.TestBase):
+    """tests for #13381"""
+
+    __sparse_driver_backend__ = True
+
+    @testing.fixture
+    def imv_table(self, metadata):
+        t = Table(
+            "t_imv",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        t.create(testing.db)
+        return t
+
+    @testing.fixture
+    def cursor_execute_table(self, metadata):
+        t = Table(
+            "t_ce",
+            metadata,
+            Column(
+                "x",
+                Integer,
+                normalize_sequence(config, Sequence("t_ce_id_seq")),
+                primary_key=True,
+            ),
+            implicit_returning=False,
+        )
+        metadata.create_all(testing.db)
+        return t
+
+    @testing.combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    def test_dbapi_error_exec_single(self, event_name, connection):
+        def handler(
+            conn,
+            cursor,
+            statement,
+            parameters,
+            context,
+            executemany,
+        ):
+            raise connection.dialect.dbapi.OperationalError("error in event")
+
+        self.event_listen(connection, event_name, handler)
+
+        with expect_raises_message(
+            tsa.exc.OperationalError,
+            "error in event",
+        ):
+            connection.exec_driver_sql("select 1")
+
+    @testing.combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    def test_base_exception_invalidates_exec_single(self, event_name):
+        with testing.db.connect() as conn:
+
+            def handler(
+                conn,
+                cursor,
+                statement,
+                parameters,
+                context,
+                executemany,
+            ):
+                raise BaseException("exit-like error")
+
+            self.event_listen(conn, event_name, handler)
+
+            with expect_raises_message(BaseException, "exit-like error"):
+                conn.exec_driver_sql("select 1")
+
+            is_true(conn.invalidated)
+
+    @testing.requires.insertmanyvalues
+    @testing.combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    def test_dbapi_error_insertmanyvalues(
+        self, event_name, imv_table, connection
+    ):
+        def handler(
+            conn,
+            cursor,
+            statement,
+            parameters,
+            context,
+            executemany,
+        ):
+            if context.execute_style is ExecuteStyle.INSERTMANYVALUES:
+                raise connection.dialect.dbapi.OperationalError(
+                    "error in event"
+                )
+
+        self.event_listen(connection, event_name, handler)
+
+        with expect_raises_message(
+            tsa.exc.OperationalError,
+            "error in event",
+        ):
+            connection.execute(
+                imv_table.insert().returning(
+                    imv_table.c.id,
+                    sort_by_parameter_order=True,
+                ),
+                [{"data": f"d{i}"} for i in range(10)],
+            )
+
+    @testing.requires.insertmanyvalues
+    @testing.combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    def test_base_exception_invalidates_insertmanyvalues(
+        self, event_name, imv_table
+    ):
+        with testing.db.connect() as conn:
+
+            def handler(
+                conn,
+                cursor,
+                statement,
+                parameters,
+                context,
+                executemany,
+            ):
+                if context.execute_style is ExecuteStyle.INSERTMANYVALUES:
+                    raise BaseException("exit-like error")
+
+            self.event_listen(conn, event_name, handler)
+
+            with expect_raises_message(BaseException, "exit-like error"):
+                conn.execute(
+                    imv_table.insert().returning(
+                        imv_table.c.id,
+                        sort_by_parameter_order=True,
+                    ),
+                    [{"data": f"d{i}"} for i in range(10)],
+                )
+
+            is_true(conn.invalidated)
+
+        # release lingering cursor refs so sqlite file lock clears
+        gc_collect()
+
+    @testing.requires.sequences
+    @testing.combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    def test_dbapi_error_cursor_execute(
+        self, event_name, cursor_execute_table, connection
+    ):
+        def handler(
+            conn,
+            cursor,
+            statement,
+            parameters,
+            context,
+            executemany,
+        ):
+            if "t_ce_id_seq" in str(statement):
+                raise connection.dialect.dbapi.OperationalError(
+                    "error in event"
+                )
+
+        self.event_listen(connection, event_name, handler)
+
+        with expect_raises_message(
+            tsa.exc.OperationalError,
+            "error in event",
+        ):
+            connection.execute(cursor_execute_table.insert())
+
+    @testing.requires.sequences
+    @testing.combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    def test_base_exception_invalidates_cursor_execute(
+        self, event_name, cursor_execute_table
+    ):
+        with testing.db.connect() as conn:
+
+            def handler(
+                conn,
+                cursor,
+                statement,
+                parameters,
+                context,
+                executemany,
+            ):
+                if "t_ce_id_seq" in str(statement):
+                    raise BaseException("exit-like error")
+
+            self.event_listen(conn, event_name, handler)
+
+            with expect_raises_message(BaseException, "exit-like error"):
+                conn.execute(cursor_execute_table.insert())
+
+            is_true(conn.invalidated)
+
+        gc_collect()
+
+
 class OnConnectTest(fixtures.TestBase):
     __requires__ = ("sqlite",)
 
index 4082349ef785a4c6bb2a02486ea554d3cba4e4d1..91f1302b17b85669e41e02fcb1ae60d8e83b848f 100644 (file)
@@ -1082,6 +1082,42 @@ class AsyncEventTest(EngineFixture):
                 )
 
 
+class AsyncCursorEventCancelledErrorTest(fixtures.TestBase):
+    """tests for #13381"""
+
+    __requires__ = ("async_dialect",)
+    __backend__ = True
+
+    @testing.fixture
+    def async_engine(self):
+        return engines.testing_engine(asyncio=True)
+
+    @combinations(
+        "before_cursor_execute",
+        "after_cursor_execute",
+        argnames="event_name",
+    )
+    @async_test
+    async def test_cancelled_error_invalidates(self, async_engine, event_name):
+        @event.listens_for(async_engine.sync_engine, event_name)
+        def handler(
+            conn,
+            cursor,
+            statement,
+            parameters,
+            context,
+            executemany,
+        ):
+            raise asyncio.CancelledError()
+
+        conn = await async_engine.connect()
+        with expect_raises(asyncio.CancelledError):
+            await conn.execute(select(1))
+
+        is_true(conn.invalidated)
+        await conn.close()
+
+
 class AsyncInspection(EngineFixture):
     __backend__ = True