From: Mike Bayer Date: Wed, 25 Sep 2024 18:19:02 +0000 (-0400) Subject: honor prefetch_cols and postfetch_cols in ORM update w/ WHERE criteria X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=bd1c17f11318d0b581f59c8c6521979246abc9b8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git honor prefetch_cols and postfetch_cols in ORM update w/ WHERE criteria Continuing from :ticket:`11912`, columns marked with :paramref:`.mapped_column.onupdate`, :paramref:`.mapped_column.server_onupdate`, or :class:`.Computed` are now refreshed in ORM instances when running an ORM enabled UPDATE with WHERE criteria, even if the statement does not use RETURNING or populate_existing. this moves the test we added in #11912 to be in test_update_delete_where, since this behavior is not related to bulk statements. For bulk statements, we're building onto the "many rows fast" use case and we as yet intentionally don't do any "bookkeeping", which means none of the expiration or any of that. would need to rethink "bulk update" a bit to get onupdates to refresh. Fixes: #11917 Change-Id: I9601be7afed523b356ce47a6daf98cc6584f4ad3 --- diff --git a/doc/build/changelog/unreleased_20/11917.rst b/doc/build/changelog/unreleased_20/11917.rst new file mode 100644 index 0000000000..951b191605 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11917.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 11917 + + Continuing from :ticket:`11912`, columns marked with + :paramref:`.mapped_column.onupdate`, + :paramref:`.mapped_column.server_onupdate`, or :class:`.Computed` are now + refreshed in ORM instances when running an ORM enabled UPDATE with WHERE + criteria, even if the statement does not use RETURNING or + populate_existing. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 5e565f717f..a9408f1cce 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -1763,7 +1763,10 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): session, update_options, statement, + result.context.compiled_parameters[0], [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + result.prefetch_cols(), + result.postfetch_cols(), ) @classmethod @@ -1808,6 +1811,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): session, update_options, statement, + result.context.compiled_parameters[0], [ ( obj, @@ -1816,16 +1820,26 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): ) for obj in objs ], + result.prefetch_cols(), + result.postfetch_cols(), ) @classmethod def _apply_update_set_values_to_objects( - cls, session, update_options, statement, matched_objects + cls, + session, + update_options, + statement, + effective_params, + matched_objects, + prefetch_cols, + postfetch_cols, ): """apply values to objects derived from an update statement, e.g. UPDATE..SET """ + mapper = update_options._subject_mapper target_cls = mapper.class_ evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) @@ -1848,7 +1862,35 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): attrib = {k for k, v in resolved_keys_as_propnames} states = set() + + to_prefetch = { + c + for c in prefetch_cols + if c.key in effective_params + and c in mapper._columntoproperty + and c.key not in evaluated_keys + } + to_expire = { + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + }.difference(evaluated_keys) + + prefetch_transfer = [ + (mapper._columntoproperty[c].key, c.key) for c in to_prefetch + ] + for obj, state, dict_ in matched_objects: + + dict_.update( + { + col_to_prop: effective_params[c_key] + for col_to_prop, c_key in prefetch_transfer + } + ) + + state._expire_attributes(state.dict, to_expire) + to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 3943a9ab6c..992a18947b 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -8,8 +8,10 @@ from typing import Set import uuid from sqlalchemy import bindparam +from sqlalchemy import Computed from sqlalchemy import event from sqlalchemy import exc +from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -602,78 +604,102 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): __backend__ = True - @testing.variation("populate_existing", [True, False]) - @testing.requires.update_returning - def test_update_populate_existing(self, decl_base, populate_existing): - """test #11912""" + @testing.variation( + "use_onupdate", + [ + "none", + "server", + "callable", + "clientsql", + ("computed", testing.requires.computed_columns), + ], + ) + def test_bulk_update_onupdates( + self, + decl_base, + use_onupdate, + ): + """assert that for now, bulk ORM update by primary key does not + expire or refresh onupdates.""" class Employee(ComparableEntity, decl_base): __tablename__ = "employee" uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) - user_name: Mapped[str] = mapped_column(nullable=False) - some_server_value: Mapped[str] + user_name: Mapped[str] = mapped_column(String(200), nullable=False) + + if use_onupdate.server: + some_server_value: Mapped[str] = mapped_column( + server_onupdate=FetchedValue() + ) + elif use_onupdate.callable: + some_server_value: Mapped[str] = mapped_column( + onupdate=lambda: "value 2" + ) + elif use_onupdate.clientsql: + some_server_value: Mapped[str] = mapped_column( + onupdate=literal("value 2") + ) + elif use_onupdate.computed: + some_server_value: Mapped[str] = mapped_column( + String(255), + Computed(user_name + " computed value"), + nullable=True, + ) + else: + some_server_value: Mapped[str] decl_base.metadata.create_all(testing.db) s = fixture_session() uuid1 = uuid.uuid4() - e1 = Employee( - uuid=uuid1, user_name="e1 old name", some_server_value="value 1" - ) + + if use_onupdate.computed: + server_old_value, server_new_value = ( + "e1 old name computed value", + "e1 new name computed value", + ) + e1 = Employee(uuid=uuid1, user_name="e1 old name") + else: + server_old_value, server_new_value = ("value 1", "value 2") + e1 = Employee( + uuid=uuid1, + user_name="e1 old name", + some_server_value="value 1", + ) s.add(e1) s.flush() - stmt = ( - update(Employee) - .values(user_name="e1 new name") - .where(Employee.uuid == uuid1) - .returning(Employee) - ) + # for computed col, make sure e1.some_server_value is loaded. + # this will already be the case for all RETURNING backends, so this + # suits just MySQL. + if use_onupdate.computed: + e1.some_server_value + + stmt = update(Employee) + # perform out of band UPDATE on server value to simulate # a computed col - s.connection().execute( - update(Employee.__table__).values(some_server_value="value 2") - ) - if populate_existing: - rows = s.scalars( - stmt, execution_options={"populate_existing": True} + if use_onupdate.none or use_onupdate.server: + s.connection().execute( + update(Employee.__table__).values(some_server_value="value 2") ) - # SPECIAL: before we actually receive the returning rows, - # the existing objects have not been updated yet - eq_(e1.some_server_value, "value 1") - eq_( - set(rows), - { - Employee( - uuid=uuid1, - user_name="e1 new name", - some_server_value="value 2", - ), - }, - ) + execution_options = {} - # now they are updated - eq_(e1.some_server_value, "value 2") - else: - # no populate existing - rows = s.scalars(stmt) - eq_(e1.some_server_value, "value 1") - eq_( - set(rows), - { - Employee( - uuid=uuid1, - user_name="e1 new name", - some_server_value="value 1", - ), - }, - ) - eq_(e1.some_server_value, "value 1") + s.execute( + stmt, + execution_options=execution_options, + params=[{"uuid": uuid1, "user_name": "e1 new name"}], + ) + + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + # do a full expire, now the new value is definitely there s.commit() s.expire_all() - eq_(e1.some_server_value, "value 2") + eq_(e1.some_server_value, server_new_value) @testing.variation( "returning_executemany", @@ -2393,18 +2419,24 @@ class EagerLoadTest( class A(Base): __tablename__ = "a" - id: Mapped[int] = mapped_column(Integer, primary_key=True) + id: Mapped[int] = mapped_column( + Integer, Identity(), primary_key=True + ) cs = relationship("C") class B(Base): __tablename__ = "b" - id: Mapped[int] = mapped_column(Integer, primary_key=True) + id: Mapped[int] = mapped_column( + Integer, Identity(), primary_key=True + ) a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) a = relationship("A") class C(Base): __tablename__ = "c" - id: Mapped[int] = mapped_column(Integer, primary_key=True) + id: Mapped[int] = mapped_column( + Integer, Identity(), primary_key=True + ) a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) @classmethod diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index 3f7b08b470..8d9feaf63c 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -1,15 +1,22 @@ +from __future__ import annotations + +import uuid + from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column +from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import exc +from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import lambda_stmt +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ @@ -25,6 +32,8 @@ from sqlalchemy.orm import Bundle from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import immediateload from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session @@ -44,6 +53,7 @@ from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -3296,6 +3306,219 @@ class LoadFromReturningTest(fixtures.MappedTest): # TODO: state of above objects should be "deleted" +class OnUpdatePopulationTest(fixtures.TestBase): + __backend__ = True + + @testing.variation("populate_existing", [True, False]) + @testing.variation( + "use_onupdate", + [ + "none", + "server", + "callable", + "clientsql", + ("computed", testing.requires.computed_columns), + ], + ) + @testing.variation( + "use_returning", + [ + ("returning", testing.requires.update_returning), + ("defaults", testing.requires.update_returning), + "none", + ], + ) + @testing.variation("synchronize", ["auto", "fetch", "evaluate"]) + def test_update_populate_existing( + self, + decl_base, + populate_existing, + use_onupdate, + use_returning, + synchronize, + ): + """test #11912 and #11917""" + + class Employee(ComparableEntity, decl_base): + __tablename__ = "employee" + + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + user_name: Mapped[str] = mapped_column(String(200), nullable=False) + + if use_onupdate.server: + some_server_value: Mapped[str] = mapped_column( + server_onupdate=FetchedValue() + ) + elif use_onupdate.callable: + some_server_value: Mapped[str] = mapped_column( + onupdate=lambda: "value 2" + ) + elif use_onupdate.clientsql: + some_server_value: Mapped[str] = mapped_column( + onupdate=literal("value 2") + ) + elif use_onupdate.computed: + some_server_value: Mapped[str] = mapped_column( + String(255), + Computed(user_name + " computed value"), + nullable=True, + ) + else: + some_server_value: Mapped[str] + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + uuid1 = uuid.uuid4() + + if use_onupdate.computed: + server_old_value, server_new_value = ( + "e1 old name computed value", + "e1 new name computed value", + ) + e1 = Employee(uuid=uuid1, user_name="e1 old name") + else: + server_old_value, server_new_value = ("value 1", "value 2") + e1 = Employee( + uuid=uuid1, + user_name="e1 old name", + some_server_value="value 1", + ) + s.add(e1) + s.flush() + + stmt = ( + update(Employee) + .values(user_name="e1 new name") + .where(Employee.uuid == uuid1) + ) + + if use_returning.returning: + stmt = stmt.returning(Employee) + elif use_returning.defaults: + # NOTE: the return_defaults case here has not been analyzed for + # #11912 or #11917. future enhancements may change its behavior + stmt = stmt.return_defaults() + + # perform out of band UPDATE on server value to simulate + # a computed col + if use_onupdate.none or use_onupdate.server: + s.connection().execute( + update(Employee.__table__).values(some_server_value="value 2") + ) + + execution_options = {} + + if populate_existing: + execution_options["populate_existing"] = True + + if synchronize.evaluate: + execution_options["synchronize_session"] = "evaluate" + if synchronize.fetch: + execution_options["synchronize_session"] = "fetch" + + if use_returning.returning: + rows = s.scalars(stmt, execution_options=execution_options) + else: + s.execute(stmt, execution_options=execution_options) + + if ( + use_onupdate.clientsql + or use_onupdate.server + or use_onupdate.computed + ): + if not use_returning.defaults: + # if server-side onupdate was generated, the col should have + # been expired + assert "some_server_value" not in e1.__dict__ + + # and refreshes when called. this is even if we have RETURNING + # rows we didn't fetch yet. + eq_(e1.some_server_value, server_new_value) + else: + # using return defaults here is not expiring. have not + # researched why, it may be because the explicit + # return_defaults interferes with the ORMs call + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + elif use_onupdate.callable: + if not use_returning.defaults or not synchronize.fetch: + # for python-side onupdate, col is populated with local value + assert "some_server_value" in e1.__dict__ + + # and is refreshed + eq_(e1.some_server_value, server_new_value) + else: + assert "some_server_value" in e1.__dict__ + + # and is not refreshed + eq_(e1.some_server_value, server_old_value) + + else: + # no onupdate, then the value was not touched yet, + # even if we used RETURNING with populate_existing, because + # we did not fetch the rows yet + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + # now see if we can fetch rows + if use_returning.returning: + + if populate_existing or not use_onupdate.none: + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value=server_new_value, + ), + }, + ) + + else: + # if no populate existing and no server default, that column + # is not touched at all + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value=server_old_value, + ), + }, + ) + + if use_returning.defaults: + # as mentioned above, the return_defaults() case here remains + # unanalyzed. + if synchronize.fetch or ( + use_onupdate.clientsql + or use_onupdate.server + or use_onupdate.computed + or use_onupdate.none + ): + eq_(e1.some_server_value, server_old_value) + else: + eq_(e1.some_server_value, server_new_value) + + elif ( + populate_existing and use_returning.returning + ) or not use_onupdate.none: + eq_(e1.some_server_value, server_new_value) + else: + # no onupdate specified, and no populate existing with returning, + # the attribute is not refreshed + eq_(e1.some_server_value, server_old_value) + + # do a full expire, now the new value is definitely there + s.commit() + s.expire_all() + eq_(e1.some_server_value, server_new_value) + + class PGIssue11849Test(fixtures.DeclarativeMappedTest): __backend__ = True __only_on__ = ("postgresql",)