From: Mike Bayer Date: Mon, 7 Aug 2023 14:47:11 +0000 (-0400) Subject: implement RETURNING * for ORM DML X-Git-Tag: rel_2_0_20~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fa081f36d2a3e859b2f71ed4c8dfa31ce17fecfd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement RETURNING * for ORM DML Implemented the "RETURNING '*'" use case for ORM enabled DML statements. This will render in as many cases as possible and return the unfiltered result set, however is not supported for multi-parameter "ORM bulk INSERT" statements that have specific column rendering requirements. Fixes: #10192 Change-Id: I04297d08eacb9ad1d5fd6d9dd21afefb8e9dc0b1 --- diff --git a/doc/build/changelog/unreleased_20/10192.rst b/doc/build/changelog/unreleased_20/10192.rst new file mode 100644 index 0000000000..1d59861698 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10192.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: usecase, orm + :tickets: 10192 + + Implemented the "RETURNING '*'" use case for ORM enabled DML statements. + This will render in as many cases as possible and return the unfiltered + result set, however is not supported for multi-parameter "ORM bulk INSERT" + statements that have specific column rendering requirements. + diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 063a2cbe48..d38dfa9ce1 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -578,7 +578,10 @@ class ORMDMLState(AbstractORMCompileState): execution_context = result.context compile_state = execution_context.compiled.compile_state - if compile_state.from_statement_ctx: + if ( + compile_state.from_statement_ctx + and not compile_state.from_statement_ctx.compile_options._is_star + ): load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) @@ -1374,6 +1377,16 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): use_supplemental_cols=True, ) + if ( + self.from_statement_ctx is not None + and self.from_statement_ctx.compile_options._is_star + ): + raise sa_exc.CompileError( + "Can't use RETURNING * with bulk ORM INSERT. " + "Please use a different INSERT form, such as INSERT..VALUES " + "or INSERT with a Core Connection" + ) + self.statement = statement diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index e961e59347..63c4e86c63 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -838,6 +838,13 @@ class ORMFromStatementCompileState(ORMCompileState): "plugin_subject", None ) adapter = DMLReturningColFilter(target_mapper, dml_mapper) + + if self.compile_options._is_star and (len(self._entities) != 1): + raise sa_exc.CompileError( + "Can't generate ORM query that includes multiple expressions " + "at the same time as '*'; query for '*' alone if present" + ) + for entity in self._entities: entity.setup_dml_returning_compile_state(self, adapter) @@ -2969,7 +2976,6 @@ class _ColumnEntity(_QueryEntity): column = compile_state.compound_eager_adapter.columns[column] getter = result._getter(column) - ret = getter, self._label_name, self._extra_entities self._row_processor = ret @@ -3031,6 +3037,13 @@ class _RawColumnEntity(_ColumnEntity): def corresponds_to(self, entity): return False + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + return self.setup_compile_state(compile_state) + def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 2888aeaf9e..d9c91a5707 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -192,6 +192,85 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): ), ) + @testing.requires.insert_returning + @testing.requires.returning_star + @testing.variation( + "insert_type", + ["bulk", ("values", testing.requires.multivalues_inserts), "single"], + ) + def test_insert_returning_star(self, decl_base, insert_type): + """test #10192""" + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + + name: Mapped[str] = mapped_column() + other_thing: Mapped[Optional[str]] + server_thing: Mapped[str] = mapped_column(server_default="thing") + + decl_base.metadata.create_all(testing.db) + insert_stmt = insert(User).returning(literal_column("*")) + + s = fixture_session() + + if insert_type.bulk or insert_type.single: + with expect_raises_message( + exc.CompileError, + r"Can't use RETURNING \* with bulk ORM INSERT.", + ): + if insert_type.bulk: + s.execute( + insert_stmt, + [ + {"name": "some name 1"}, + {"name": "some name 2"}, + {"name": "some name 3"}, + ], + ) + else: + s.execute( + insert_stmt, + {"name": "some name 1"}, + ) + return + elif insert_type.values: + with self.sql_execution_asserter() as asserter: + result = s.execute( + insert_stmt.values( + [ + {"name": "some name 1"}, + {"name": "some name 2"}, + {"name": "some name 3"}, + ], + ) + ) + + eq_( + result.all(), + [ + (1, "some name 1", None, "thing"), + (2, "some name 2", None, "thing"), + (3, "some name 3", None, "thing"), + ], + ) + asserter.assert_( + CompiledSQL( + "INSERT INTO users (name) VALUES (:name_m0), " + "(:name_m1), (:name_m2) RETURNING *", + [ + { + "name_m0": "some name 1", + "name_m1": "some name 2", + "name_m2": "some name 3", + } + ], + ), + ) + else: + insert_type.fail() + @testing.requires.insert_returning @testing.skip_if( "oracle", "oracle doesn't like the no-FROM SELECT inside of an INSERT" @@ -587,6 +666,60 @@ class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): CompiledSQL("UPDATE a SET x=:x, y=:y", [{"x": 5, "y": 9}]), ) + @testing.variation("multi_row", ["multirow", "singlerow", "listwsingle"]) + @testing.requires.update_returning + @testing.requires.returning_star + def test_bulk_update_returning_star(self, decl_base, multi_row): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + s = fixture_session() + + s.add_all( + [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)], + ) + s.commit() + + stmt = update(A).returning(literal_column("*")) + + if multi_row.multirow: + data = [ + {"x": 3, "y": 8}, + {"x": 5, "y": 9}, + {"x": 12, "y": 15}, + ] + + stmt = stmt.execution_options(synchronize_session=None) + elif multi_row.listwsingle: + data = [ + {"x": 5, "y": 9}, + ] + + stmt = stmt.execution_options(synchronize_session=None) + elif multi_row.singlerow: + data = {"x": 5, "y": 9} + else: + multi_row.fail() + + if multi_row.multirow or multi_row.listwsingle: + with expect_raises_message( + exc.InvalidRequestError, "No primary key value supplied" + ): + s.execute(stmt, data) + return + else: + result = s.execute(stmt, data) + eq_(result.all(), [(1, 5, 9), (2, 5, 9), (3, 5, 9)]) + def test_bulk_update_w_where_one(self, decl_base): """test use case in #9595""" diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index a524ddd14a..7f76d735d3 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -10,6 +10,7 @@ from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import lambda_stmt +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import select @@ -1094,6 +1095,31 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([25, 37, 29, 27])), ) + @testing.requires.update_returning + @testing.requires.returning_star + def test_update_returning_star(self): + User = self.classes.User + + sess = fixture_session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + stmt = ( + update(User) + .where(User.age > 29) + .values({"age": User.age - 10}) + .returning(literal_column("*")) + ) + + result = sess.execute(stmt) + eq_(result.all(), [(2, "jack", 37), (4, "jane", 27)]) + + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 37, 29, 27])), + ) + @testing.combinations(True, False, argnames="implicit_returning") def test_update_fetch_returning(self, implicit_returning): if implicit_returning: @@ -1255,6 +1281,28 @@ class UpdateDeleteTest(fixtures.MappedTest): # to point to the class, so you can test eq with sets eq_(set(result.all()), expected) + @testing.requires.delete_returning + @testing.requires.returning_star + def test_delete_returning_star(self): + User = self.classes.User + + sess = fixture_session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + in_(john, sess) + in_(jack, sess) + + stmt = delete(User).where(User.age > 29).returning(literal_column("*")) + + result = sess.execute(stmt) + eq_(result.all(), [(2, "jack", 47), (4, "jane", 37)]) + + in_(john, sess) + not_in(jack, sess) + in_(jill, sess) + not_in(jane, sess) + @testing.combinations(True, False, argnames="implicit_returning") def test_delete_fetch_returning(self, implicit_returning): if implicit_returning: diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index a8d38de3e3..06482562b9 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -2834,8 +2834,6 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) def test_update_from_entity(self): - from sqlalchemy.sql import update - User = self.classes.User self.assert_compile( update(User), "UPDATE users SET id=:id, name=:name" @@ -2854,8 +2852,6 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) def test_delete_from_entity(self): - from sqlalchemy.sql import delete - User = self.classes.User self.assert_compile(delete(User), "DELETE FROM users") @@ -2866,8 +2862,6 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) def test_insert_from_entity(self): - from sqlalchemy.sql import insert - User = self.classes.User self.assert_compile( insert(User), "INSERT INTO users (id, name) VALUES (:id, :name)" @@ -2879,6 +2873,27 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): checkparams={"name": "ed"}, ) + def test_update_returning_star(self): + User = self.classes.User + self.assert_compile( + update(User).returning(literal_column("*")), + "UPDATE users SET id=:id, name=:name RETURNING *", + ) + + def test_delete_returning_star(self): + User = self.classes.User + self.assert_compile( + delete(User).returning(literal_column("*")), + "DELETE FROM users RETURNING *", + ) + + def test_insert_returning_star(self): + User = self.classes.User + self.assert_compile( + insert(User).returning(literal_column("*")), + "INSERT INTO users (id, name) VALUES (:id, :name) RETURNING *", + ) + def test_col_prop_builtin_function(self): class Foo: pass diff --git a/test/orm/test_loading.py b/test/orm/test_loading.py index d0b5c9d8f9..8be0f3168d 100644 --- a/test/orm/test_loading.py +++ b/test/orm/test_loading.py @@ -1,9 +1,12 @@ +from sqlalchemy import delete from sqlalchemy import exc +from sqlalchemy import insert from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import update from sqlalchemy.orm import loading from sqlalchemy.orm import relationship from sqlalchemy.testing import is_true @@ -77,7 +80,8 @@ class SelectStarTest(_fixtures.FixtureTest): lambda User, star: (star, text("some text")), argnames="testcase", ) - def test_no_star_orm_combinations(self, exprtype, testcase): + @testing.variation("stmt_type", ["select", "update", "insert", "delete"]) + def test_no_star_orm_combinations(self, exprtype, testcase, stmt_type): """test for #8235""" User = self.classes.User @@ -91,7 +95,17 @@ class SelectStarTest(_fixtures.FixtureTest): assert False args = testing.resolve_lambda(testcase, User=User, star=star) - stmt = select(*args).select_from(User) + + if stmt_type.select: + stmt = select(*args).select_from(User) + elif stmt_type.insert: + stmt = insert(User).returning(*args) + elif stmt_type.update: + stmt = update(User).values({"data": "foo"}).returning(*args) + elif stmt_type.delete: + stmt = delete(User).returning(*args) + else: + stmt_type.fail() s = fixture_session() diff --git a/test/requirements.py b/test/requirements.py index 61cb139338..e0941da1b9 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -461,6 +461,12 @@ class DefaultRequirements(SuiteRequirements): def computed_columns_on_update_returning(self): return self.computed_columns + skip_if("oracle") + @property + def returning_star(self): + """backend supports RETURNING *""" + + return skip_if(["oracle", "mssql"]) + @property def correlated_outer_joins(self): """Target must support an outer join to a subquery which