]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
honor prefetch_cols and postfetch_cols in ORM update w/ WHERE criteria
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 Sep 2024 18:19:02 +0000 (14:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Oct 2024 12:12:53 +0000 (08:12 -0400)
Continuing from :ticket:`11912`, columns marked with
:paramref:`.mapped_column.onupdate`,
:paramref:`.mapped_column.server_onupdate`, or :class:`.Computed` are now
refreshed in ORM instances when running an ORM enabled UPDATE with WHERE
criteria, even if the statement does not use RETURNING or
populate_existing.

this moves the test we added in #11912 to be in
test_update_delete_where, since this behavior is not related to bulk
statements.    For bulk statements, we're building onto the "many rows
fast" use case and we as yet intentionally don't do any "bookkeeping",
which means none of the expiration or any of that. would need to rethink
"bulk update" a bit to get onupdates to refresh.

Fixes: #11917
Change-Id: I9601be7afed523b356ce47a6daf98cc6584f4ad3
(cherry picked from commit bd1c17f11318d0b581f59c8c6521979246abc9b8)

doc/build/changelog/unreleased_20/11917.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
test/orm/dml/test_bulk_statements.py
test/orm/dml/test_update_delete_where.py

diff --git a/doc/build/changelog/unreleased_20/11917.rst b/doc/build/changelog/unreleased_20/11917.rst
new file mode 100644 (file)
index 0000000..951b191
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11917
+
+    Continuing from :ticket:`11912`, columns marked with
+    :paramref:`.mapped_column.onupdate`,
+    :paramref:`.mapped_column.server_onupdate`, or :class:`.Computed` are now
+    refreshed in ORM instances when running an ORM enabled UPDATE with WHERE
+    criteria, even if the statement does not use RETURNING or
+    populate_existing.
index 155de56dbe1e50d4880e9e33a0be3e1bfb8cb82f..01a39049b070d5b4ebaa3eb450d2a5ad7e5ce1a8 100644 (file)
@@ -1761,7 +1761,10 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             session,
             update_options,
             statement,
+            result.context.compiled_parameters[0],
             [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
+            result.prefetch_cols(),
+            result.postfetch_cols(),
         )
 
     @classmethod
@@ -1806,6 +1809,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             session,
             update_options,
             statement,
+            result.context.compiled_parameters[0],
             [
                 (
                     obj,
@@ -1814,16 +1818,26 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
                 )
                 for obj in objs
             ],
+            result.prefetch_cols(),
+            result.postfetch_cols(),
         )
 
     @classmethod
     def _apply_update_set_values_to_objects(
-        cls, session, update_options, statement, matched_objects
+        cls,
+        session,
+        update_options,
+        statement,
+        effective_params,
+        matched_objects,
+        prefetch_cols,
+        postfetch_cols,
     ):
         """apply values to objects derived from an update statement, e.g.
         UPDATE..SET <values>
 
         """
+
         mapper = update_options._subject_mapper
         target_cls = mapper.class_
         evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
@@ -1846,7 +1860,35 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         attrib = {k for k, v in resolved_keys_as_propnames}
 
         states = set()
+
+        to_prefetch = {
+            c
+            for c in prefetch_cols
+            if c.key in effective_params
+            and c in mapper._columntoproperty
+            and c.key not in evaluated_keys
+        }
+        to_expire = {
+            mapper._columntoproperty[c].key
+            for c in postfetch_cols
+            if c in mapper._columntoproperty
+        }.difference(evaluated_keys)
+
+        prefetch_transfer = [
+            (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
+        ]
+
         for obj, state, dict_ in matched_objects:
+
+            dict_.update(
+                {
+                    col_to_prop: effective_params[c_key]
+                    for col_to_prop, c_key in prefetch_transfer
+                }
+            )
+
+            state._expire_attributes(state.dict, to_expire)
+
             to_evaluate = state.unmodified.intersection(evaluated_keys)
 
             for key in to_evaluate:
index 3943a9ab6cc441749de246e9f32bd98002407382..992a18947b793b1583b3ca02d8d0833fca7321c0 100644 (file)
@@ -8,8 +8,10 @@ from typing import Set
 import uuid
 
 from sqlalchemy import bindparam
+from sqlalchemy import Computed
 from sqlalchemy import event
 from sqlalchemy import exc
+from sqlalchemy import FetchedValue
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Identity
@@ -602,78 +604,102 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
 class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
     __backend__ = True
 
-    @testing.variation("populate_existing", [True, False])
-    @testing.requires.update_returning
-    def test_update_populate_existing(self, decl_base, populate_existing):
-        """test #11912"""
+    @testing.variation(
+        "use_onupdate",
+        [
+            "none",
+            "server",
+            "callable",
+            "clientsql",
+            ("computed", testing.requires.computed_columns),
+        ],
+    )
+    def test_bulk_update_onupdates(
+        self,
+        decl_base,
+        use_onupdate,
+    ):
+        """assert that for now, bulk ORM update by primary key does not
+        expire or refresh onupdates."""
 
         class Employee(ComparableEntity, decl_base):
             __tablename__ = "employee"
 
             uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
-            user_name: Mapped[str] = mapped_column(nullable=False)
-            some_server_value: Mapped[str]
+            user_name: Mapped[str] = mapped_column(String(200), nullable=False)
+
+            if use_onupdate.server:
+                some_server_value: Mapped[str] = mapped_column(
+                    server_onupdate=FetchedValue()
+                )
+            elif use_onupdate.callable:
+                some_server_value: Mapped[str] = mapped_column(
+                    onupdate=lambda: "value 2"
+                )
+            elif use_onupdate.clientsql:
+                some_server_value: Mapped[str] = mapped_column(
+                    onupdate=literal("value 2")
+                )
+            elif use_onupdate.computed:
+                some_server_value: Mapped[str] = mapped_column(
+                    String(255),
+                    Computed(user_name + " computed value"),
+                    nullable=True,
+                )
+            else:
+                some_server_value: Mapped[str]
 
         decl_base.metadata.create_all(testing.db)
         s = fixture_session()
 
         uuid1 = uuid.uuid4()
-        e1 = Employee(
-            uuid=uuid1, user_name="e1 old name", some_server_value="value 1"
-        )
+
+        if use_onupdate.computed:
+            server_old_value, server_new_value = (
+                "e1 old name computed value",
+                "e1 new name computed value",
+            )
+            e1 = Employee(uuid=uuid1, user_name="e1 old name")
+        else:
+            server_old_value, server_new_value = ("value 1", "value 2")
+            e1 = Employee(
+                uuid=uuid1,
+                user_name="e1 old name",
+                some_server_value="value 1",
+            )
         s.add(e1)
         s.flush()
 
-        stmt = (
-            update(Employee)
-            .values(user_name="e1 new name")
-            .where(Employee.uuid == uuid1)
-            .returning(Employee)
-        )
+        # for computed col, make sure e1.some_server_value is loaded.
+        # this will already be the case for all RETURNING backends, so this
+        # suits just MySQL.
+        if use_onupdate.computed:
+            e1.some_server_value
+
+        stmt = update(Employee)
+
         # perform out of band UPDATE on server value to simulate
         # a computed col
-        s.connection().execute(
-            update(Employee.__table__).values(some_server_value="value 2")
-        )
-        if populate_existing:
-            rows = s.scalars(
-                stmt, execution_options={"populate_existing": True}
+        if use_onupdate.none or use_onupdate.server:
+            s.connection().execute(
+                update(Employee.__table__).values(some_server_value="value 2")
             )
-            # SPECIAL: before we actually receive the returning rows,
-            # the existing objects have not been updated yet
-            eq_(e1.some_server_value, "value 1")
 
-            eq_(
-                set(rows),
-                {
-                    Employee(
-                        uuid=uuid1,
-                        user_name="e1 new name",
-                        some_server_value="value 2",
-                    ),
-                },
-            )
+        execution_options = {}
 
-            # now they are updated
-            eq_(e1.some_server_value, "value 2")
-        else:
-            # no populate existing
-            rows = s.scalars(stmt)
-            eq_(e1.some_server_value, "value 1")
-            eq_(
-                set(rows),
-                {
-                    Employee(
-                        uuid=uuid1,
-                        user_name="e1 new name",
-                        some_server_value="value 1",
-                    ),
-                },
-            )
-            eq_(e1.some_server_value, "value 1")
+        s.execute(
+            stmt,
+            execution_options=execution_options,
+            params=[{"uuid": uuid1, "user_name": "e1 new name"}],
+        )
+
+        assert "some_server_value" in e1.__dict__
+        eq_(e1.some_server_value, server_old_value)
+
+        # do a full expire, now the new value is definitely there
         s.commit()
         s.expire_all()
-        eq_(e1.some_server_value, "value 2")
+        eq_(e1.some_server_value, server_new_value)
 
     @testing.variation(
         "returning_executemany",
@@ -2393,18 +2419,24 @@ class EagerLoadTest(
 
         class A(Base):
             __tablename__ = "a"
-            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            id: Mapped[int] = mapped_column(
+                Integer, Identity(), primary_key=True
+            )
             cs = relationship("C")
 
         class B(Base):
             __tablename__ = "b"
-            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            id: Mapped[int] = mapped_column(
+                Integer, Identity(), primary_key=True
+            )
             a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
             a = relationship("A")
 
         class C(Base):
             __tablename__ = "c"
-            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            id: Mapped[int] = mapped_column(
+                Integer, Identity(), primary_key=True
+            )
             a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
 
     @classmethod
index 3f7b08b470ce510ed8e8b6970ecfd3cf8d0f2570..8d9feaf63c271a4602cbdbc7b695fb52fe0b7dce 100644 (file)
@@ -1,15 +1,22 @@
+from __future__ import annotations
+
+import uuid
+
 from sqlalchemy import Boolean
 from sqlalchemy import case
 from sqlalchemy import column
+from sqlalchemy import Computed
 from sqlalchemy import delete
 from sqlalchemy import event
 from sqlalchemy import exc
+from sqlalchemy import FetchedValue
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import lambda_stmt
+from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import or_
@@ -25,6 +32,8 @@ from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import immediateload
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
@@ -44,6 +53,7 @@ from sqlalchemy.testing import in_
 from sqlalchemy.testing import not_in
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -3296,6 +3306,219 @@ class LoadFromReturningTest(fixtures.MappedTest):
             # TODO: state of above objects should be "deleted"
 
 
+class OnUpdatePopulationTest(fixtures.TestBase):
+    __backend__ = True
+
+    @testing.variation("populate_existing", [True, False])
+    @testing.variation(
+        "use_onupdate",
+        [
+            "none",
+            "server",
+            "callable",
+            "clientsql",
+            ("computed", testing.requires.computed_columns),
+        ],
+    )
+    @testing.variation(
+        "use_returning",
+        [
+            ("returning", testing.requires.update_returning),
+            ("defaults", testing.requires.update_returning),
+            "none",
+        ],
+    )
+    @testing.variation("synchronize", ["auto", "fetch", "evaluate"])
+    def test_update_populate_existing(
+        self,
+        decl_base,
+        populate_existing,
+        use_onupdate,
+        use_returning,
+        synchronize,
+    ):
+        """test #11912 and #11917"""
+
+        class Employee(ComparableEntity, decl_base):
+            __tablename__ = "employee"
+
+            uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+            user_name: Mapped[str] = mapped_column(String(200), nullable=False)
+
+            if use_onupdate.server:
+                some_server_value: Mapped[str] = mapped_column(
+                    server_onupdate=FetchedValue()
+                )
+            elif use_onupdate.callable:
+                some_server_value: Mapped[str] = mapped_column(
+                    onupdate=lambda: "value 2"
+                )
+            elif use_onupdate.clientsql:
+                some_server_value: Mapped[str] = mapped_column(
+                    onupdate=literal("value 2")
+                )
+            elif use_onupdate.computed:
+                some_server_value: Mapped[str] = mapped_column(
+                    String(255),
+                    Computed(user_name + " computed value"),
+                    nullable=True,
+                )
+            else:
+                some_server_value: Mapped[str]
+
+        decl_base.metadata.create_all(testing.db)
+        s = fixture_session()
+
+        uuid1 = uuid.uuid4()
+
+        if use_onupdate.computed:
+            server_old_value, server_new_value = (
+                "e1 old name computed value",
+                "e1 new name computed value",
+            )
+            e1 = Employee(uuid=uuid1, user_name="e1 old name")
+        else:
+            server_old_value, server_new_value = ("value 1", "value 2")
+            e1 = Employee(
+                uuid=uuid1,
+                user_name="e1 old name",
+                some_server_value="value 1",
+            )
+        s.add(e1)
+        s.flush()
+
+        stmt = (
+            update(Employee)
+            .values(user_name="e1 new name")
+            .where(Employee.uuid == uuid1)
+        )
+
+        if use_returning.returning:
+            stmt = stmt.returning(Employee)
+        elif use_returning.defaults:
+            # NOTE: the return_defaults case here has not been analyzed for
+            # #11912 or #11917.   future enhancements may change its behavior
+            stmt = stmt.return_defaults()
+
+        # perform out of band UPDATE on server value to simulate
+        # a computed col
+        if use_onupdate.none or use_onupdate.server:
+            s.connection().execute(
+                update(Employee.__table__).values(some_server_value="value 2")
+            )
+
+        execution_options = {}
+
+        if populate_existing:
+            execution_options["populate_existing"] = True
+
+        if synchronize.evaluate:
+            execution_options["synchronize_session"] = "evaluate"
+        if synchronize.fetch:
+            execution_options["synchronize_session"] = "fetch"
+
+        if use_returning.returning:
+            rows = s.scalars(stmt, execution_options=execution_options)
+        else:
+            s.execute(stmt, execution_options=execution_options)
+
+        if (
+            use_onupdate.clientsql
+            or use_onupdate.server
+            or use_onupdate.computed
+        ):
+            if not use_returning.defaults:
+                # if server-side onupdate was generated, the col should have
+                # been expired
+                assert "some_server_value" not in e1.__dict__
+
+                # and refreshes when called.  this is even if we have RETURNING
+                # rows we didn't fetch yet.
+                eq_(e1.some_server_value, server_new_value)
+            else:
+                # using return defaults here is not expiring.   have not
+                # researched why, it may be because the explicit
+                # return_defaults interferes with the ORMs call
+                assert "some_server_value" in e1.__dict__
+                eq_(e1.some_server_value, server_old_value)
+
+        elif use_onupdate.callable:
+            if not use_returning.defaults or not synchronize.fetch:
+                # for python-side onupdate, col is populated with local value
+                assert "some_server_value" in e1.__dict__
+
+                # and is refreshed
+                eq_(e1.some_server_value, server_new_value)
+            else:
+                assert "some_server_value" in e1.__dict__
+
+                # and is not refreshed
+                eq_(e1.some_server_value, server_old_value)
+
+        else:
+            # no onupdate, then the value was not touched yet,
+            # even if we used RETURNING with populate_existing, because
+            # we did not fetch the rows yet
+            assert "some_server_value" in e1.__dict__
+            eq_(e1.some_server_value, server_old_value)
+
+        # now see if we can fetch rows
+        if use_returning.returning:
+
+            if populate_existing or not use_onupdate.none:
+                eq_(
+                    set(rows),
+                    {
+                        Employee(
+                            uuid=uuid1,
+                            user_name="e1 new name",
+                            some_server_value=server_new_value,
+                        ),
+                    },
+                )
+
+            else:
+                # if no populate existing and no server default, that column
+                # is not touched at all
+                eq_(
+                    set(rows),
+                    {
+                        Employee(
+                            uuid=uuid1,
+                            user_name="e1 new name",
+                            some_server_value=server_old_value,
+                        ),
+                    },
+                )
+
+        if use_returning.defaults:
+            # as mentioned above, the return_defaults() case here remains
+            # unanalyzed.
+            if synchronize.fetch or (
+                use_onupdate.clientsql
+                or use_onupdate.server
+                or use_onupdate.computed
+                or use_onupdate.none
+            ):
+                eq_(e1.some_server_value, server_old_value)
+            else:
+                eq_(e1.some_server_value, server_new_value)
+
+        elif (
+            populate_existing and use_returning.returning
+        ) or not use_onupdate.none:
+            eq_(e1.some_server_value, server_new_value)
+        else:
+            # no onupdate specified, and no populate existing with returning,
+            # the attribute is not refreshed
+            eq_(e1.some_server_value, server_old_value)
+
+        # do a full expire, now the new value is definitely there
+        s.commit()
+        s.expire_all()
+        eq_(e1.some_server_value, server_new_value)
+
+
 class PGIssue11849Test(fixtures.DeclarativeMappedTest):
     __backend__ = True
     __only_on__ = ("postgresql",)