]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure whereclause, returning copied as tuples
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Dec 2022 17:04:07 +0000 (12:04 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Dec 2022 19:24:29 +0000 (14:24 -0500)
Fixed issue in the internal SQL traversal for DML statements like
:class:`_dml.Update` and :class:`_dml.Delete` which would cause among other
potential issues, a specific issue using lambda statements with the ORM
update/delete feature.

Fixes: #9033
Change-Id: I76428049cb767ba302fbea89555114bf63ab8687
(cherry picked from commit e68173bf7d296b2948abed06f79c7cbd0ab66b0d)

doc/build/changelog/unreleased_14/9033.rst [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
test/orm/test_update_delete.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/9033.rst b/doc/build/changelog/unreleased_14/9033.rst
new file mode 100644 (file)
index 0000000..d0b0d2f
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9033
+
+    Fixed issue in the internal SQL traversal for DML statements like
+    :class:`_dml.Update` and :class:`_dml.Delete` which would cause among other
+    potential issues, a specific issue using lambda statements with the ORM
+    update/delete feature.
index 07a4d7b2d585b34cb83839a57086793ce0ff385e..ae48740000e584fd5701ec67e2a0796ad0f07dd9 100644 (file)
@@ -928,12 +928,12 @@ class Insert(ValuesBase):
             ("_multi_values", InternalTraversal.dp_dml_multi_values),
             ("select", InternalTraversal.dp_clauseelement),
             ("_post_values_clause", InternalTraversal.dp_clauseelement),
-            ("_returning", InternalTraversal.dp_clauseelement_list),
+            ("_returning", InternalTraversal.dp_clauseelement_tuple),
             ("_hints", InternalTraversal.dp_table_hint_list),
             ("_return_defaults", InternalTraversal.dp_boolean),
             (
                 "_return_defaults_columns",
-                InternalTraversal.dp_clauseelement_list,
+                InternalTraversal.dp_clauseelement_tuple,
             ),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
@@ -1208,16 +1208,16 @@ class Update(DMLWhereBase, ValuesBase):
     _traverse_internals = (
         [
             ("table", InternalTraversal.dp_clauseelement),
-            ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+            ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
             ("_inline", InternalTraversal.dp_boolean),
             ("_ordered_values", InternalTraversal.dp_dml_ordered_values),
             ("_values", InternalTraversal.dp_dml_values),
-            ("_returning", InternalTraversal.dp_clauseelement_list),
+            ("_returning", InternalTraversal.dp_clauseelement_tuple),
             ("_hints", InternalTraversal.dp_table_hint_list),
             ("_return_defaults", InternalTraversal.dp_boolean),
             (
                 "_return_defaults_columns",
-                InternalTraversal.dp_clauseelement_list,
+                InternalTraversal.dp_clauseelement_tuple,
             ),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
@@ -1436,8 +1436,8 @@ class Delete(DMLWhereBase, UpdateBase):
     _traverse_internals = (
         [
             ("table", InternalTraversal.dp_clauseelement),
-            ("_where_criteria", InternalTraversal.dp_clauseelement_list),
-            ("_returning", InternalTraversal.dp_clauseelement_list),
+            ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
+            ("_returning", InternalTraversal.dp_clauseelement_tuple),
             ("_hints", InternalTraversal.dp_table_hint_list),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
index 6be271e460331f19797ebdc6cc2d1d8c2b2f3e40..9eaf1765a3138353be3a855d6504e1006cfcf02a 100644 (file)
@@ -701,7 +701,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([15, 27, 19, 27])),
         )
 
-    def test_update_future_lambda(self):
+    @testing.variation("values_first", [True, False])
+    def test_update_future_lambda(self, values_first):
         User, users = self.classes.User, self.tables.users
 
         sess = Session(testing.db, future=True)
@@ -710,14 +711,22 @@ class UpdateDeleteTest(fixtures.MappedTest):
             sess.execute(select(User).order_by(User.id)).scalars().all()
         )
 
-        sess.execute(
-            lambda_stmt(
+        new_value = 10
+
+        if values_first:
+            stmt = lambda_stmt(lambda: update(User))
+            stmt += lambda s: s.values({"age": User.age - new_value})
+            stmt += lambda s: s.where(User.age > 29).execution_options(
+                synchronize_session="evaluate"
+            )
+        else:
+            stmt = lambda_stmt(
                 lambda: update(User)
                 .where(User.age > 29)
-                .values({"age": User.age - 10})
+                .values({"age": User.age - new_value})
                 .execution_options(synchronize_session="evaluate")
-            ),
-        )
+            )
+        sess.execute(stmt)
 
         eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
         eq_(
@@ -725,14 +734,21 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([25, 37, 29, 27])),
         )
 
-        sess.execute(
-            lambda_stmt(
+        if values_first:
+            stmt = lambda_stmt(lambda: update(User))
+            stmt += lambda s: s.values({"age": User.age - new_value})
+            stmt += lambda s: s.where(User.age > 29).execution_options(
+                synchronize_session="evaluate"
+            )
+        else:
+            stmt = lambda_stmt(
                 lambda: update(User)
                 .where(User.age > 29)
                 .values({User.age: User.age - 10})
                 .execution_options(synchronize_session="evaluate")
             )
-        )
+
+        sess.execute(stmt)
         eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27])
         eq_(
             sess.query(User.age).order_by(User.id).all(),
index 37363273b20213cdddd36a3802ff50d47c6d364d..7a058bfcdae8123631b6d561ff59966cb029b94a 100644 (file)
@@ -2693,7 +2693,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):
 
     """Tests the generative capability of Insert, Update"""
 
-    __dialect__ = "default"
+    __dialect__ = "default_enhanced"
 
     # fixme: consolidate converage from elsewhere here and expand
 
@@ -2935,3 +2935,41 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE construct does not support multiple parameter sets.",
             stmt.compile,
         )
+
+    @testing.variation("stmt_type", ["update", "delete"])
+    def test_whereclause_returning_adapted(self, stmt_type):
+        """test #9033"""
+
+        if stmt_type.update:
+            stmt = (
+                t1.update()
+                .where(t1.c.col1 == 10)
+                .values(col1=15)
+                .returning(t1.c.col1)
+            )
+        elif stmt_type.delete:
+            stmt = t1.delete().where(t1.c.col1 == 10).returning(t1.c.col1)
+        else:
+            stmt_type.fail()
+
+        stmt = visitors.replacement_traverse(stmt, {}, lambda elem: None)
+
+        assert isinstance(stmt._where_criteria, tuple)
+        assert isinstance(stmt._returning, tuple)
+
+        stmt = stmt.where(t1.c.col2 == 5).returning(t1.c.col2)
+
+        if stmt_type.update:
+            self.assert_compile(
+                stmt,
+                "UPDATE table1 SET col1=:col1 WHERE table1.col1 = :col1_1 "
+                "AND table1.col2 = :col2_1 RETURNING table1.col1, table1.col2",
+            )
+        elif stmt_type.delete:
+            self.assert_compile(
+                stmt,
+                "DELETE FROM table1 WHERE table1.col1 = :col1_1 "
+                "AND table1.col2 = :col2_1 RETURNING table1.col1, table1.col2",
+            )
+        else:
+            stmt_type.fail()