From: Mike Bayer Date: Sun, 22 Sep 2024 15:34:48 +0000 (-0400) Subject: propagate populate_existing for ORM bulk update X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=64c1299180c2d944142d54bea741355d474bcbde;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git propagate populate_existing for ORM bulk update Similar to #9742 Fixed bug in ORM bulk update/delete where using RETURNING with bulk update/delete in combination with populate existing would fail to accommodate the populate_existing option. Fixes: #11912 Change-Id: Ib9ef659512a1d1ae438eab67332a691941c06f43 --- diff --git a/doc/build/changelog/unreleased_20/11912.rst b/doc/build/changelog/unreleased_20/11912.rst new file mode 100644 index 0000000000..c0814b6cba --- /dev/null +++ b/doc/build/changelog/unreleased_20/11912.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 11912 + + Fixed bug in ORM bulk update/delete where using RETURNING with bulk + update/delete in combination with populate existing would fail to + accommodate the populate_existing option. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 9a14a7ecfc..5e565f717f 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -646,6 +646,7 @@ class BulkUDCompileState(ORMDMLState): _eval_condition = None _matched_rows = None _identity_token = None + _populate_existing: bool = False @classmethod def can_use_returning( @@ -678,6 +679,7 @@ class BulkUDCompileState(ORMDMLState): { "synchronize_session", "autoflush", + "populate_existing", "identity_token", "is_delete_using", "is_update_from", @@ -1592,10 +1594,20 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): bind_arguments: _BindArguments, conn: Connection, ) -> _result.Result: + update_options = execution_options.get( "_sa_orm_update_options", cls.default_update_options ) + if update_options._populate_existing: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + load_options += {"_populate_existing": True} + execution_options = execution_options.union( + {"_sa_orm_load_options": load_options} + ) + if update_options._dml_strategy not in ( "orm", "auto", diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 431eb3076f..3943a9ab6c 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -602,6 +602,79 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): __backend__ = True + @testing.variation("populate_existing", [True, False]) + @testing.requires.update_returning + def test_update_populate_existing(self, decl_base, populate_existing): + """test #11912""" + + class Employee(ComparableEntity, decl_base): + __tablename__ = "employee" + + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + user_name: Mapped[str] = mapped_column(nullable=False) + some_server_value: Mapped[str] + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + uuid1 = uuid.uuid4() + e1 = Employee( + uuid=uuid1, user_name="e1 old name", some_server_value="value 1" + ) + s.add(e1) + s.flush() + + stmt = ( + update(Employee) + .values(user_name="e1 new name") + .where(Employee.uuid == uuid1) + .returning(Employee) + ) + # perform out of band UPDATE on server value to simulate + # a computed col + s.connection().execute( + update(Employee.__table__).values(some_server_value="value 2") + ) + if populate_existing: + rows = s.scalars( + stmt, execution_options={"populate_existing": True} + ) + # SPECIAL: before we actually receive the returning rows, + # the existing objects have not been updated yet + eq_(e1.some_server_value, "value 1") + + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value="value 2", + ), + }, + ) + + # now they are updated + eq_(e1.some_server_value, "value 2") + else: + # no populate existing + rows = s.scalars(stmt) + eq_(e1.some_server_value, "value 1") + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value="value 1", + ), + }, + ) + eq_(e1.some_server_value, "value 1") + s.commit() + s.expire_all() + eq_(e1.some_server_value, "value 2") + @testing.variation( "returning_executemany", [