"dml_strategy", "unspecified"
)
- if dml_strategy == "bulk":
+ toplevel = not compiler.stack
+
+ if toplevel and dml_strategy == "bulk":
self._setup_for_bulk_update(statement, compiler)
- elif dml_strategy in ("orm", "unspecified"):
+ elif not toplevel or dml_strategy in ("orm", "unspecified"):
self._setup_for_orm_update(statement, compiler)
return self
def _setup_for_orm_update(self, statement, compiler, **kw):
orm_level_statement = statement
+ toplevel = not compiler.stack
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
self._init_global_attributes(
statement,
compiler,
- toplevel=True,
- process_criteria_for_toplevel=True,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=toplevel,
)
if statement._values:
use_supplemental_cols = False
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
+ if not toplevel:
+ synchronize_session = None
+ else:
+ synchronize_session = compiler._annotations.get(
+ "synchronize_session", None
+ )
can_use_returning = compiler._annotations.get(
"can_use_returning", None
)
*(list(mapper.local_table.primary_key))
)
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
+ if toplevel:
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ dml_mapper=mapper,
+ use_supplemental_cols=use_supplemental_cols,
+ )
self.statement = new_stmt
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ toplevel = not compiler.stack
+
orm_level_statement = statement
ext_info = statement.table._annotations["parententity"]
self._init_global_attributes(
statement,
compiler,
- toplevel=True,
- process_criteria_for_toplevel=True,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=toplevel,
)
new_stmt = statement._clone()
use_supplemental_cols = False
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
+ if not toplevel:
+ synchronize_session = None
+ else:
+ synchronize_session = compiler._annotations.get(
+ "synchronize_session", None
+ )
can_use_returning = compiler._annotations.get(
"can_use_returning", None
)
new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
+ if toplevel:
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ dml_mapper=mapper,
+ use_supplemental_cols=use_supplemental_cols,
+ )
self.statement = new_stmt
)
+class DMLTest(QueryTest, AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ @testing.variation("stmt_type", ["update", "delete"])
+ def test_dml_ctes(self, stmt_type: testing.Variation):
+ User = self.classes.User
+
+ if stmt_type.update:
+ fn = update
+ elif stmt_type.delete:
+ fn = delete
+ else:
+ stmt_type.fail()
+
+ inner_cte = fn(User).returning(User.id).cte("uid")
+
+ stmt = select(inner_cte)
+
+ if stmt_type.update:
+ self.assert_compile(
+ stmt,
+ "WITH uid AS (UPDATE users SET id=:id, name=:name "
+ "RETURNING users.id) SELECT uid.id FROM uid",
+ )
+ elif stmt_type.delete:
+ self.assert_compile(
+ stmt,
+ "WITH uid AS (DELETE FROM users "
+ "RETURNING users.id) SELECT uid.id FROM uid",
+ )
+ else:
+ stmt_type.fail()
+
+
class ColumnsClauseFromsTest(QueryTest, AssertsCompiledSQL):
__dialect__ = "default"