From: Mike Bayer Date: Tue, 19 Dec 2023 14:00:03 +0000 (-0500) Subject: ensure Bundle / DML RETURNING has test support, full impl X-Git-Tag: rel_2_0_24~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=df6f3a232393a647052bf6b52d73e4529f7d69e9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure Bundle / DML RETURNING has test support, full impl Ensured the use case of :class:`.Bundle` objects used in the ``returning()`` portion of ORM-enabled INSERT, UPDATE and DELETE statements is tested and works fully. This was never explicitly implemented or tested previously and did not work correctly in the 1.4 series; in the 2.0 series, ORM UPDATE/DELETE with WHERE criteria was missing an implementation method preventing :class:`.Bundle` objects from working. Fixes: #10776 Change-Id: I32298e65ac590a12b47dd6ba00b7d56038b8a450 (cherry picked from commit 6e089c3dbf7e7348da84dfc62cc1c6100a257fd4) --- diff --git a/doc/build/changelog/unreleased_20/10776.rst b/doc/build/changelog/unreleased_20/10776.rst new file mode 100644 index 0000000000..4a6889fdb7 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10776.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 10776 + + Ensured the use case of :class:`.Bundle` objects used in the + ``returning()`` portion of ORM-enabled INSERT, UPDATE and DELETE statements + is tested and works fully. This was never explicitly implemented or + tested previously and did not work correctly in the 1.4 series; in the 2.0 + series, ORM UPDATE/DELETE with WHERE criteria was missing an implementation + method preventing :class:`.Bundle` objects from working. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 2f5e4ce8b7..3e73d80e71 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -2865,6 +2865,13 @@ class _BundleEntity(_QueryEntity): for ent in self._entities: ent.setup_compile_state(compile_state) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + return self.setup_compile_state(compile_state) + def row_processor(self, context, result): procs, labels, extra = zip( *[ent.row_processor(context, result) for ent in self._entities] diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 7af47de818..1e5c17c9de 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -23,6 +23,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import aliased +from sqlalchemy.orm import Bundle from sqlalchemy.orm import column_property from sqlalchemy.orm import load_only from sqlalchemy.orm import Mapped @@ -381,6 +382,68 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): eq_(result.all(), [User(id=1, name="John", age=30)]) + @testing.requires.insert_returning + @testing.variation( + "insert_type", + ["bulk", ("values", testing.requires.multivalues_inserts), "single"], + ) + def test_insert_returning_bundle(self, decl_base, insert_type): + """test #10776""" + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + + name: Mapped[str] = mapped_column() + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + insert_stmt = insert(User).returning( + User.name, Bundle("mybundle", User.id, User.x, User.y) + ) + + s = fixture_session() + + if insert_type.bulk: + result = s.execute( + insert_stmt, + [ + {"name": "some name 1", "x": 1, "y": 2}, + {"name": "some name 2", "x": 2, "y": 3}, + {"name": "some name 3", "x": 3, "y": 4}, + ], + ) + elif insert_type.values: + result = s.execute( + insert_stmt.values( + [ + {"name": "some name 1", "x": 1, "y": 2}, + {"name": "some name 2", "x": 2, "y": 3}, + {"name": "some name 3", "x": 3, "y": 4}, + ], + ) + ) + elif insert_type.single: + result = s.execute( + insert_stmt, {"name": "some name 1", "x": 1, "y": 2} + ) + else: + insert_type.fail() + + if insert_type.single: + eq_(result.all(), [("some name 1", (1, 1, 2))]) + else: + eq_( + result.all(), + [ + ("some name 1", (1, 1, 2)), + ("some name 2", (2, 2, 3)), + ("some name 3", (3, 3, 4)), + ], + ) + @testing.variation( "use_returning", [(True, testing.requires.insert_returning), False] ) @@ -794,6 +857,34 @@ class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): result = s.execute(stmt, data) eq_(result.all(), [(1, 5, 9), (2, 5, 9), (3, 5, 9)]) + @testing.requires.update_returning + def test_bulk_update_returning_bundle(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + + x: Mapped[int] + y: Mapped[int] + + decl_base.metadata.create_all(testing.db) + + s = fixture_session() + + s.add_all( + [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)], + ) + s.commit() + + stmt = update(A).returning(Bundle("mybundle", A.id, A.x), A.y) + + data = {"x": 5, "y": 9} + + result = s.execute(stmt, data) + eq_(result.all(), [((1, 5), 9), ((2, 5), 9), ((3, 5), 9)]) + def test_bulk_update_w_where_one(self, decl_base): """test use case in #9595""" diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index 03468972d5..cbf27d018b 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -21,6 +21,7 @@ from sqlalchemy import update from sqlalchemy import values from sqlalchemy.orm import aliased from sqlalchemy.orm import backref +from sqlalchemy.orm import Bundle from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import immediateload from sqlalchemy.orm import joinedload @@ -1351,6 +1352,45 @@ class UpdateDeleteTest(fixtures.MappedTest): # to point to the class, so you can test eq with sets eq_(set(result.all()), expected) + @testing.requires.update_returning + @testing.variation("crud_type", ["update", "delete"]) + @testing.combinations( + "auto", + "evaluate", + "fetch", + False, + argnames="synchronize_session", + ) + def test_crud_returning_bundle(self, crud_type, synchronize_session): + """test #10776""" + User = self.classes.User + + sess = fixture_session() + + if crud_type.update: + stmt = ( + update(User) + .filter(User.age > 29) + .values({"age": User.age - 10}) + .execution_options(synchronize_session=synchronize_session) + .returning(Bundle("mybundle", User.id, User.age), User.name) + ) + expected = {((4, 27), "jane"), ((2, 37), "jack")} + elif crud_type.delete: + stmt = ( + delete(User) + .filter(User.age > 29) + .execution_options(synchronize_session=synchronize_session) + .returning(Bundle("mybundle", User.id, User.age), User.name) + ) + expected = {((2, 47), "jack"), ((4, 37), "jane")} + else: + crud_type.fail() + + result = sess.execute(stmt) + + eq_(set(result.all()), expected) + @testing.requires.delete_returning @testing.requires.returning_star def test_delete_returning_star(self):