]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restore statement substitution to before_execute()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Aug 2021 15:47:26 +0000 (11:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 21 Aug 2021 02:17:55 +0000 (22:17 -0400)
Fixed issue where the ability of the
:meth:`_engine.ConnectionEvents.before_execute` method to alter the SQL
statement object passed, returning the new object to be invoked, was
inadvertently removed. This behavior has been restored.

The refactor in a1939719a652774a437f69f8d4788b3f08650089 removed this
feature for some reason and there were no tests in place to detect
it.  I don't see any indication this was planned.

Fixes: #6913
Change-Id: Ia77ca08aa91ab9403f19a8eb61e2a0e41aad138a

doc/build/changelog/unreleased_14/6913.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
test/engine/test_deprecations.py
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_14/6913.rst b/doc/build/changelog/unreleased_14/6913.rst
new file mode 100644 (file)
index 0000000..43ce34c
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, engine, regression
+    :tickets: 6913
+
+    Fixed issue where the ability of the
+    :meth:`_engine.ConnectionEvents.before_execute` method to alter the SQL
+    statement object passed, returning the new object to be invoked, was
+    inadvertently removed. This behavior has been restored.
+
index c26d9a0a73f8fb8bd9c4a085ac96351adb825513..a316f904f01c478284a3df1750cd895a4f436864 100644 (file)
@@ -1287,6 +1287,7 @@ class Connection(Connectable):
 
         if self._has_events or self.engine._has_events:
             (
+                default,
                 distilled_params,
                 event_multiparams,
                 event_params,
@@ -1335,6 +1336,7 @@ class Connection(Connectable):
 
         if self._has_events or self.engine._has_events:
             (
+                ddl,
                 distilled_params,
                 event_multiparams,
                 event_params,
@@ -1399,7 +1401,7 @@ class Connection(Connectable):
         else:
             distilled_params = []
 
-        return distilled_params, event_multiparams, event_params
+        return elem, distilled_params, event_multiparams, event_params
 
     def _execute_clauseelement(
         self, elem, multiparams, params, execution_options
@@ -1415,6 +1417,7 @@ class Connection(Connectable):
         has_events = self._has_events or self.engine._has_events
         if has_events:
             (
+                elem,
                 distilled_params,
                 event_multiparams,
                 event_params,
@@ -1492,6 +1495,7 @@ class Connection(Connectable):
 
         if self._has_events or self.engine._has_events:
             (
+                compiled,
                 distilled_params,
                 event_multiparams,
                 event_params,
@@ -1536,6 +1540,7 @@ class Connection(Connectable):
         if not future:
             if self._has_events or self.engine._has_events:
                 (
+                    statement,
                     distilled_params,
                     event_multiparams,
                     event_params,
index 795cc5a4cf1efb88a1384b452cc88f5d3087709d..39e2bf7625e110172866f15829c20a3acbb7d57c 100644 (file)
@@ -1687,6 +1687,17 @@ class EngineEventsTest(fixtures.TestBase):
                 )
             eq_(result.all(), [("15",)])
 
+    @testing.only_on("sqlite")
+    def test_modify_statement_string(self, connection):
+        @event.listens_for(connection, "before_execute", retval=True)
+        def _modify(
+            conn, clauseelement, multiparams, params, execution_options
+        ):
+            return clauseelement.replace("hi", "there"), multiparams, params
+
+        with _string_deprecation_expect():
+            eq_(connection.scalar("select 'hi'"), "there")
+
     def test_retval_flag(self):
         canary = []
 
index 19ba5f03c9e932472e393d7555c70395b0835898..dd4ee32f8c49ee84bcaf31cdb3abd4927c7569a2 100644 (file)
@@ -31,6 +31,7 @@ from sqlalchemy.pool import NullPool
 from sqlalchemy.pool import QueuePool
 from sqlalchemy.sql import column
 from sqlalchemy.sql import literal
+from sqlalchemy.sql.elements import literal_column
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import config
@@ -1771,6 +1772,56 @@ class EngineEventsTest(fixtures.TestBase):
             with e1.connect() as conn:
                 conn.execute(select(literal("1")))
 
+    @testing.only_on("sqlite")
+    def test_dont_modify_statement_driversql(self, connection):
+        m1 = mock.Mock()
+
+        @event.listens_for(connection, "before_execute", retval=True)
+        def _modify(
+            conn, clauseelement, multiparams, params, execution_options
+        ):
+            m1.run_event()
+            return clauseelement.replace("hi", "there"), multiparams, params
+
+        # the event does not take effect for the "driver SQL" option
+        eq_(connection.exec_driver_sql("select 'hi'").scalar(), "hi")
+
+        # event is not called at all
+        eq_(m1.mock_calls, [])
+
+    @testing.combinations((True,), (False,), argnames="future")
+    @testing.only_on("sqlite")
+    def test_modify_statement_internal_driversql(self, connection, future):
+        m1 = mock.Mock()
+
+        @event.listens_for(connection, "before_execute", retval=True)
+        def _modify(
+            conn, clauseelement, multiparams, params, execution_options
+        ):
+            m1.run_event()
+            return clauseelement.replace("hi", "there"), multiparams, params
+
+        eq_(
+            connection._exec_driver_sql(
+                "select 'hi'", [], {}, {}, future=future
+            ).scalar(),
+            "hi" if future else "there",
+        )
+
+        if future:
+            eq_(m1.mock_calls, [])
+        else:
+            eq_(m1.mock_calls, [call.run_event()])
+
+    def test_modify_statement_clauseelement(self, connection):
+        @event.listens_for(connection, "before_execute", retval=True)
+        def _modify(
+            conn, clauseelement, multiparams, params, execution_options
+        ):
+            return select(literal_column("'there'")), multiparams, params
+
+        eq_(connection.scalar(select(literal_column("'hi'"))), "there")
+
     def test_argument_format_execute(self, testing_engine):
         def before_execute(
             conn, clauseelement, multiparams, params, execution_options