]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
skip ORM loading setups for non-toplevel DML
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 May 2023 14:06:10 +0000 (10:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 May 2023 14:06:10 +0000 (10:06 -0400)
Fixed regression where use of :func:`_dml.update` or :func:`_dml_delete`
within a :class:`_sql.CTE` construct, then used in a :func:`_sql.select`,
would raise a :class:`.CompileError` as a result of ORM related rules for
performing ORM-level update/delete statements.

Fixes: #9767
Change-Id: I4eae9af86752b2e5fd64f7998f8a68754c349e4c

doc/build/changelog/unreleased_20/9767.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
test/orm/test_core_compilation.py

diff --git a/doc/build/changelog/unreleased_20/9767.rst b/doc/build/changelog/unreleased_20/9767.rst
new file mode 100644 (file)
index 0000000..857d349
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 9767
+
+    Fixed regression where use of :func:`_dml.update` or :func:`_dml_delete`
+    within a :class:`_sql.CTE` construct, then used in a :func:`_sql.select`,
+    would raise a :class:`.CompileError` as a result of ORM related rules for
+    performing ORM-level update/delete statements.
index 257d71db4020f13725cf9859f791be510d2f851b..b75285ebdea36887b69c3f43bc9a91e8ced5342c 100644 (file)
@@ -1397,9 +1397,11 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             "dml_strategy", "unspecified"
         )
 
-        if dml_strategy == "bulk":
+        toplevel = not compiler.stack
+
+        if toplevel and dml_strategy == "bulk":
             self._setup_for_bulk_update(statement, compiler)
-        elif dml_strategy in ("orm", "unspecified"):
+        elif not toplevel or dml_strategy in ("orm", "unspecified"):
             self._setup_for_orm_update(statement, compiler)
 
         return self
@@ -1407,6 +1409,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
     def _setup_for_orm_update(self, statement, compiler, **kw):
         orm_level_statement = statement
 
+        toplevel = not compiler.stack
+
         ext_info = statement.table._annotations["parententity"]
 
         self.mapper = mapper = ext_info.mapper
@@ -1416,8 +1420,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         self._init_global_attributes(
             statement,
             compiler,
-            toplevel=True,
-            process_criteria_for_toplevel=True,
+            toplevel=toplevel,
+            process_criteria_for_toplevel=toplevel,
         )
 
         if statement._values:
@@ -1451,9 +1455,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
 
         use_supplemental_cols = False
 
-        synchronize_session = compiler._annotations.get(
-            "synchronize_session", None
-        )
+        if not toplevel:
+            synchronize_session = None
+        else:
+            synchronize_session = compiler._annotations.get(
+                "synchronize_session", None
+            )
         can_use_returning = compiler._annotations.get(
             "can_use_returning", None
         )
@@ -1486,13 +1493,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
                 *(list(mapper.local_table.primary_key))
             )
 
-        new_stmt = self._setup_orm_returning(
-            compiler,
-            orm_level_statement,
-            new_stmt,
-            dml_mapper=mapper,
-            use_supplemental_cols=use_supplemental_cols,
-        )
+        if toplevel:
+            new_stmt = self._setup_orm_returning(
+                compiler,
+                orm_level_statement,
+                new_stmt,
+                dml_mapper=mapper,
+                use_supplemental_cols=use_supplemental_cols,
+            )
 
         self.statement = new_stmt
 
@@ -1814,6 +1822,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
     def create_for_statement(cls, statement, compiler, **kw):
         self = cls.__new__(cls)
 
+        toplevel = not compiler.stack
+
         orm_level_statement = statement
 
         ext_info = statement.table._annotations["parententity"]
@@ -1822,8 +1832,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
         self._init_global_attributes(
             statement,
             compiler,
-            toplevel=True,
-            process_criteria_for_toplevel=True,
+            toplevel=toplevel,
+            process_criteria_for_toplevel=toplevel,
         )
 
         new_stmt = statement._clone()
@@ -1841,9 +1851,12 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
 
         use_supplemental_cols = False
 
-        synchronize_session = compiler._annotations.get(
-            "synchronize_session", None
-        )
+        if not toplevel:
+            synchronize_session = None
+        else:
+            synchronize_session = compiler._annotations.get(
+                "synchronize_session", None
+            )
         can_use_returning = compiler._annotations.get(
             "can_use_returning", None
         )
@@ -1870,13 +1883,14 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
 
             new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
 
-        new_stmt = self._setup_orm_returning(
-            compiler,
-            orm_level_statement,
-            new_stmt,
-            dml_mapper=mapper,
-            use_supplemental_cols=use_supplemental_cols,
-        )
+        if toplevel:
+            new_stmt = self._setup_orm_returning(
+                compiler,
+                orm_level_statement,
+                new_stmt,
+                dml_mapper=mapper,
+                use_supplemental_cols=use_supplemental_cols,
+            )
 
         self.statement = new_stmt
 
index 6736d55895b9e2043a9ac454cdd5dd71b4325f9b..8b28de591d3d9cce9082a7fb546d7bf192a45328 100644 (file)
@@ -360,6 +360,40 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
         )
 
 
+class DMLTest(QueryTest, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    @testing.variation("stmt_type", ["update", "delete"])
+    def test_dml_ctes(self, stmt_type: testing.Variation):
+        User = self.classes.User
+
+        if stmt_type.update:
+            fn = update
+        elif stmt_type.delete:
+            fn = delete
+        else:
+            stmt_type.fail()
+
+        inner_cte = fn(User).returning(User.id).cte("uid")
+
+        stmt = select(inner_cte)
+
+        if stmt_type.update:
+            self.assert_compile(
+                stmt,
+                "WITH uid AS (UPDATE users SET id=:id, name=:name "
+                "RETURNING users.id) SELECT uid.id FROM uid",
+            )
+        elif stmt_type.delete:
+            self.assert_compile(
+                stmt,
+                "WITH uid AS (DELETE FROM users "
+                "RETURNING users.id) SELECT uid.id FROM uid",
+            )
+        else:
+            stmt_type.fail()
+
+
 class ColumnsClauseFromsTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"