]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
propagate populate_existing for ORM bulk update
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Sep 2024 15:34:48 +0000 (11:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Sep 2024 15:34:48 +0000 (11:34 -0400)
Similar to #9742

Fixed bug in ORM bulk update/delete where using RETURNING with bulk
update/delete in combination with populate existing would fail to
accommodate the populate_existing option.

Fixes: #11912
Change-Id: Ib9ef659512a1d1ae438eab67332a691941c06f43

doc/build/changelog/unreleased_20/11912.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
test/orm/dml/test_bulk_statements.py

diff --git a/doc/build/changelog/unreleased_20/11912.rst b/doc/build/changelog/unreleased_20/11912.rst
new file mode 100644 (file)
index 0000000..c0814b6
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11912
+
+    Fixed bug in ORM bulk update/delete where using RETURNING with bulk
+    update/delete in combination with populate existing would fail to
+    accommodate the populate_existing option.
index 9a14a7ecfcf8f641fe5097f14bbbbe4aeabadb3c..5e565f717f5cb0e000ab99eb783930bd4b0a54e7 100644 (file)
@@ -646,6 +646,7 @@ class BulkUDCompileState(ORMDMLState):
         _eval_condition = None
         _matched_rows = None
         _identity_token = None
+        _populate_existing: bool = False
 
     @classmethod
     def can_use_returning(
@@ -678,6 +679,7 @@ class BulkUDCompileState(ORMDMLState):
             {
                 "synchronize_session",
                 "autoflush",
+                "populate_existing",
                 "identity_token",
                 "is_delete_using",
                 "is_update_from",
@@ -1592,10 +1594,20 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         bind_arguments: _BindArguments,
         conn: Connection,
     ) -> _result.Result:
+
         update_options = execution_options.get(
             "_sa_orm_update_options", cls.default_update_options
         )
 
+        if update_options._populate_existing:
+            load_options = execution_options.get(
+                "_sa_orm_load_options", QueryContext.default_load_options
+            )
+            load_options += {"_populate_existing": True}
+            execution_options = execution_options.union(
+                {"_sa_orm_load_options": load_options}
+            )
+
         if update_options._dml_strategy not in (
             "orm",
             "auto",
index 431eb3076fcc74e87b52b217a362372888af9ff4..3943a9ab6cc441749de246e9f32bd98002407382 100644 (file)
@@ -602,6 +602,79 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
 class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
     __backend__ = True
 
+    @testing.variation("populate_existing", [True, False])
+    @testing.requires.update_returning
+    def test_update_populate_existing(self, decl_base, populate_existing):
+        """test #11912"""
+
+        class Employee(ComparableEntity, decl_base):
+            __tablename__ = "employee"
+
+            uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+            user_name: Mapped[str] = mapped_column(nullable=False)
+            some_server_value: Mapped[str]
+
+        decl_base.metadata.create_all(testing.db)
+        s = fixture_session()
+
+        uuid1 = uuid.uuid4()
+        e1 = Employee(
+            uuid=uuid1, user_name="e1 old name", some_server_value="value 1"
+        )
+        s.add(e1)
+        s.flush()
+
+        stmt = (
+            update(Employee)
+            .values(user_name="e1 new name")
+            .where(Employee.uuid == uuid1)
+            .returning(Employee)
+        )
+        # perform out of band UPDATE on server value to simulate
+        # a computed col
+        s.connection().execute(
+            update(Employee.__table__).values(some_server_value="value 2")
+        )
+        if populate_existing:
+            rows = s.scalars(
+                stmt, execution_options={"populate_existing": True}
+            )
+            # SPECIAL: before we actually receive the returning rows,
+            # the existing objects have not been updated yet
+            eq_(e1.some_server_value, "value 1")
+
+            eq_(
+                set(rows),
+                {
+                    Employee(
+                        uuid=uuid1,
+                        user_name="e1 new name",
+                        some_server_value="value 2",
+                    ),
+                },
+            )
+
+            # now they are updated
+            eq_(e1.some_server_value, "value 2")
+        else:
+            # no populate existing
+            rows = s.scalars(stmt)
+            eq_(e1.some_server_value, "value 1")
+            eq_(
+                set(rows),
+                {
+                    Employee(
+                        uuid=uuid1,
+                        user_name="e1 new name",
+                        some_server_value="value 1",
+                    ),
+                },
+            )
+            eq_(e1.some_server_value, "value 1")
+        s.commit()
+        s.expire_all()
+        eq_(e1.some_server_value, "value 2")
+
     @testing.variation(
         "returning_executemany",
         [