From 7db1ced9e33d33da89f934107eeabe9ac337ae5b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 11 May 2023 10:06:10 -0400 Subject: [PATCH] skip ORM loading setups for non-toplevel DML Fixed regression where use of :func:`_dml.update` or :func:`_dml_delete` within a :class:`_sql.CTE` construct, then used in a :func:`_sql.select`, would raise a :class:`.CompileError` as a result of ORM related rules for performing ORM-level update/delete statements. Fixes: #9767 Change-Id: I4eae9af86752b2e5fd64f7998f8a68754c349e4c --- doc/build/changelog/unreleased_20/9767.rst | 8 +++ lib/sqlalchemy/orm/bulk_persistence.py | 66 +++++++++++++--------- test/orm/test_core_compilation.py | 34 +++++++++++ 3 files changed, 82 insertions(+), 26 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9767.rst diff --git a/doc/build/changelog/unreleased_20/9767.rst b/doc/build/changelog/unreleased_20/9767.rst new file mode 100644 index 0000000000..857d349879 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9767.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 9767 + + Fixed regression where use of :func:`_dml.update` or :func:`_dml_delete` + within a :class:`_sql.CTE` construct, then used in a :func:`_sql.select`, + would raise a :class:`.CompileError` as a result of ORM related rules for + performing ORM-level update/delete statements. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 257d71db40..b75285ebde 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -1397,9 +1397,11 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): "dml_strategy", "unspecified" ) - if dml_strategy == "bulk": + toplevel = not compiler.stack + + if toplevel and dml_strategy == "bulk": self._setup_for_bulk_update(statement, compiler) - elif dml_strategy in ("orm", "unspecified"): + elif not toplevel or dml_strategy in ("orm", "unspecified"): self._setup_for_orm_update(statement, compiler) return self @@ -1407,6 +1409,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): def _setup_for_orm_update(self, statement, compiler, **kw): orm_level_statement = statement + toplevel = not compiler.stack + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1416,8 +1420,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self._init_global_attributes( statement, compiler, - toplevel=True, - process_criteria_for_toplevel=True, + toplevel=toplevel, + process_criteria_for_toplevel=toplevel, ) if statement._values: @@ -1451,9 +1455,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): use_supplemental_cols = False - synchronize_session = compiler._annotations.get( - "synchronize_session", None - ) + if not toplevel: + synchronize_session = None + else: + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) can_use_returning = compiler._annotations.get( "can_use_returning", None ) @@ -1486,13 +1493,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): *(list(mapper.local_table.primary_key)) ) - new_stmt = self._setup_orm_returning( - compiler, - orm_level_statement, - new_stmt, - dml_mapper=mapper, - use_supplemental_cols=use_supplemental_cols, - ) + if toplevel: + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + dml_mapper=mapper, + use_supplemental_cols=use_supplemental_cols, + ) self.statement = new_stmt @@ -1814,6 +1822,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + toplevel = not compiler.stack + orm_level_statement = statement ext_info = statement.table._annotations["parententity"] @@ -1822,8 +1832,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): self._init_global_attributes( statement, compiler, - toplevel=True, - process_criteria_for_toplevel=True, + toplevel=toplevel, + process_criteria_for_toplevel=toplevel, ) new_stmt = statement._clone() @@ -1841,9 +1851,12 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): use_supplemental_cols = False - synchronize_session = compiler._annotations.get( - "synchronize_session", None - ) + if not toplevel: + synchronize_session = None + else: + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) can_use_returning = compiler._annotations.get( "can_use_returning", None ) @@ -1870,13 +1883,14 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) - new_stmt = self._setup_orm_returning( - compiler, - orm_level_statement, - new_stmt, - dml_mapper=mapper, - use_supplemental_cols=use_supplemental_cols, - ) + if toplevel: + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + dml_mapper=mapper, + use_supplemental_cols=use_supplemental_cols, + ) self.statement = new_stmt diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 6736d55895..8b28de591d 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -360,6 +360,40 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): ) +class DMLTest(QueryTest, AssertsCompiledSQL): + __dialect__ = "default" + + @testing.variation("stmt_type", ["update", "delete"]) + def test_dml_ctes(self, stmt_type: testing.Variation): + User = self.classes.User + + if stmt_type.update: + fn = update + elif stmt_type.delete: + fn = delete + else: + stmt_type.fail() + + inner_cte = fn(User).returning(User.id).cte("uid") + + stmt = select(inner_cte) + + if stmt_type.update: + self.assert_compile( + stmt, + "WITH uid AS (UPDATE users SET id=:id, name=:name " + "RETURNING users.id) SELECT uid.id FROM uid", + ) + elif stmt_type.delete: + self.assert_compile( + stmt, + "WITH uid AS (DELETE FROM users " + "RETURNING users.id) SELECT uid.id FROM uid", + ) + else: + stmt_type.fail() + + class ColumnsClauseFromsTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default" -- 2.47.3