]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure Bundle / DML RETURNING has test support, full impl
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Dec 2023 14:00:03 +0000 (09:00 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Dec 2023 15:19:06 +0000 (10:19 -0500)
Ensured the use case of :class:`.Bundle` objects used in the
``returning()`` portion of ORM-enabled INSERT, UPDATE and DELETE statements
is tested and works fully.   This was never explicitly implemented or
tested previously and did not work correctly in the 1.4 series; in the 2.0
series, ORM UPDATE/DELETE with WHERE criteria was missing an implementation
method preventing :class:`.Bundle` objects from working.

Fixes: #10776
Change-Id: I32298e65ac590a12b47dd6ba00b7d56038b8a450

doc/build/changelog/unreleased_20/10776.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
test/orm/dml/test_bulk_statements.py
test/orm/dml/test_update_delete_where.py

diff --git a/doc/build/changelog/unreleased_20/10776.rst b/doc/build/changelog/unreleased_20/10776.rst
new file mode 100644 (file)
index 0000000..4a6889f
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10776
+
+    Ensured the use case of :class:`.Bundle` objects used in the
+    ``returning()`` portion of ORM-enabled INSERT, UPDATE and DELETE statements
+    is tested and works fully.   This was never explicitly implemented or
+    tested previously and did not work correctly in the 1.4 series; in the 2.0
+    series, ORM UPDATE/DELETE with WHERE criteria was missing an implementation
+    method preventing :class:`.Bundle` objects from working.
index 2f5e4ce8b7ba581db6b10a31a715b46181dd176d..3e73d80e71681bd148244989f2a45e8354f751ff 100644 (file)
@@ -2865,6 +2865,13 @@ class _BundleEntity(_QueryEntity):
         for ent in self._entities:
             ent.setup_compile_state(compile_state)
 
+    def setup_dml_returning_compile_state(
+        self,
+        compile_state: ORMCompileState,
+        adapter: DMLReturningColFilter,
+    ) -> None:
+        return self.setup_compile_state(compile_state)
+
     def row_processor(self, context, result):
         procs, labels, extra = zip(
             *[ent.row_processor(context, result) for ent in self._entities]
index 7af47de818681927e1383e946d1c57049a47e026..1e5c17c9de42cfe26bd42d2f249a990f2f3a2fc2 100644 (file)
@@ -23,6 +23,7 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.orm import aliased
+from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import load_only
 from sqlalchemy.orm import Mapped
@@ -381,6 +382,68 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
 
         eq_(result.all(), [User(id=1, name="John", age=30)])
 
+    @testing.requires.insert_returning
+    @testing.variation(
+        "insert_type",
+        ["bulk", ("values", testing.requires.multivalues_inserts), "single"],
+    )
+    def test_insert_returning_bundle(self, decl_base, insert_type):
+        """test #10776"""
+
+        class User(decl_base):
+            __tablename__ = "users"
+
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+
+            name: Mapped[str] = mapped_column()
+            x: Mapped[int]
+            y: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+        insert_stmt = insert(User).returning(
+            User.name, Bundle("mybundle", User.id, User.x, User.y)
+        )
+
+        s = fixture_session()
+
+        if insert_type.bulk:
+            result = s.execute(
+                insert_stmt,
+                [
+                    {"name": "some name 1", "x": 1, "y": 2},
+                    {"name": "some name 2", "x": 2, "y": 3},
+                    {"name": "some name 3", "x": 3, "y": 4},
+                ],
+            )
+        elif insert_type.values:
+            result = s.execute(
+                insert_stmt.values(
+                    [
+                        {"name": "some name 1", "x": 1, "y": 2},
+                        {"name": "some name 2", "x": 2, "y": 3},
+                        {"name": "some name 3", "x": 3, "y": 4},
+                    ],
+                )
+            )
+        elif insert_type.single:
+            result = s.execute(
+                insert_stmt, {"name": "some name 1", "x": 1, "y": 2}
+            )
+        else:
+            insert_type.fail()
+
+        if insert_type.single:
+            eq_(result.all(), [("some name 1", (1, 1, 2))])
+        else:
+            eq_(
+                result.all(),
+                [
+                    ("some name 1", (1, 1, 2)),
+                    ("some name 2", (2, 2, 3)),
+                    ("some name 3", (3, 3, 4)),
+                ],
+            )
+
     @testing.variation(
         "use_returning", [(True, testing.requires.insert_returning), False]
     )
@@ -794,6 +857,34 @@ class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
             result = s.execute(stmt, data)
             eq_(result.all(), [(1, 5, 9), (2, 5, 9), (3, 5, 9)])
 
+    @testing.requires.update_returning
+    def test_bulk_update_returning_bundle(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+
+            x: Mapped[int]
+            y: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        s = fixture_session()
+
+        s.add_all(
+            [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+        )
+        s.commit()
+
+        stmt = update(A).returning(Bundle("mybundle", A.id, A.x), A.y)
+
+        data = {"x": 5, "y": 9}
+
+        result = s.execute(stmt, data)
+        eq_(result.all(), [((1, 5), 9), ((2, 5), 9), ((3, 5), 9)])
+
     def test_bulk_update_w_where_one(self, decl_base):
         """test use case in #9595"""
 
index 03468972d56976f29ce0d3f043889d45ed634225..cbf27d018b7770a5e45d7d3f0f0b8fde19f738ac 100644 (file)
@@ -21,6 +21,7 @@ from sqlalchemy import update
 from sqlalchemy import values
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import immediateload
 from sqlalchemy.orm import joinedload
@@ -1351,6 +1352,45 @@ class UpdateDeleteTest(fixtures.MappedTest):
         # to point to the class, so you can test eq with sets
         eq_(set(result.all()), expected)
 
+    @testing.requires.update_returning
+    @testing.variation("crud_type", ["update", "delete"])
+    @testing.combinations(
+        "auto",
+        "evaluate",
+        "fetch",
+        False,
+        argnames="synchronize_session",
+    )
+    def test_crud_returning_bundle(self, crud_type, synchronize_session):
+        """test #10776"""
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        if crud_type.update:
+            stmt = (
+                update(User)
+                .filter(User.age > 29)
+                .values({"age": User.age - 10})
+                .execution_options(synchronize_session=synchronize_session)
+                .returning(Bundle("mybundle", User.id, User.age), User.name)
+            )
+            expected = {((4, 27), "jane"), ((2, 37), "jack")}
+        elif crud_type.delete:
+            stmt = (
+                delete(User)
+                .filter(User.age > 29)
+                .execution_options(synchronize_session=synchronize_session)
+                .returning(Bundle("mybundle", User.id, User.age), User.name)
+            )
+            expected = {((2, 47), "jack"), ((4, 37), "jane")}
+        else:
+            crud_type.fail()
+
+        result = sess.execute(stmt)
+
+        eq_(set(result.all()), expected)
+
     @testing.requires.delete_returning
     @testing.requires.returning_star
     def test_delete_returning_star(self):