]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply correct pre-fetch params to post updated rows
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Aug 2025 21:11:50 +0000 (17:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Aug 2025 21:11:50 +0000 (17:11 -0400)
Fixed issue where using the ``post_update`` feature would apply incorrect
"pre-fetched" values to the ORM objects after a multi-row UPDATE process
completed.  These "pre-fetched" values would come from any column that had
an :paramref:`.Column.onupdate` callable or a version id generator used by
:paramref:`.orm.Mapper.version_id_generator`; for a version id generator
that delivered random identifiers like timestamps or UUIDs, this incorrect
data would lead to a DELETE statement against those same rows to fail in
the next step.

Fixes: #12748
Change-Id: Id12c7973f168604533762dfc01afbb9155b693a6

doc/build/changelog/unreleased_20/12748.rst [new file with mode: 0644]
lib/sqlalchemy/orm/persistence.py
test/orm/test_cycles.py
test/orm/test_versioning.py

diff --git a/doc/build/changelog/unreleased_20/12748.rst b/doc/build/changelog/unreleased_20/12748.rst
new file mode 100644 (file)
index 0000000..6891632
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12748
+
+    Fixed issue where using the ``post_update`` feature would apply incorrect
+    "pre-fetched" values to the ORM objects after a multi-row UPDATE process
+    completed.  These "pre-fetched" values would come from any column that had
+    an :paramref:`.Column.onupdate` callable or a version id generator used by
+    :paramref:`.orm.Mapper.version_id_generator`; for a version id generator
+    that delivered random identifiers like timestamps or UUIDs, this incorrect
+    data would lead to a DELETE statement against those same rows to fail in
+    the next step.
+
index 1d6b4abf665af6bf6fc55a40718b8cedfb7012f2..f720f90951a1ea9dd4fced660bdf0528f24ff6a6 100644 (file)
@@ -1379,7 +1379,13 @@ def _emit_post_update_statements(
             )
 
             rows += c.rowcount
-            for state, state_dict, mapper_rec, connection, params in records:
+            for i, (
+                state,
+                state_dict,
+                mapper_rec,
+                connection,
+                params,
+            ) in enumerate(records):
                 _postfetch_post_update(
                     mapper_rec,
                     uowtransaction,
@@ -1387,7 +1393,7 @@ def _emit_post_update_statements(
                     state,
                     state_dict,
                     c,
-                    c.context.compiled_parameters[0],
+                    c.context.compiled_parameters[i],
                 )
 
         if check_rowcount:
index fb37185f53e3c3437389fa8c57b099480b9e69b1..b4ddd26e775b09358f5e6a3076194c3bebd2d5cc 100644 (file)
@@ -15,7 +15,9 @@ from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -1586,6 +1588,46 @@ class SelfReferentialPostUpdateTest3(fixtures.MappedTest):
         session.flush()
 
 
+class PostUpdatePrefetchTest(fixtures.DeclarativeMappedTest):
+    """test #12748"""
+
+    run_setup_classes = "each"
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        count = 0
+
+        def _counter():
+            nonlocal count
+            count += 1
+            return count
+
+        class Parent(Base):
+            __tablename__ = "parent"
+            id = mapped_column(Integer, primary_key=True)
+
+            related = relationship("Related", post_update=True)
+
+        class Related(Base):
+            __tablename__ = "related"
+
+            id = mapped_column(Integer, primary_key=True)
+            parent_id = mapped_column(ForeignKey("parent.id"))
+            counter = mapped_column(Integer, onupdate=_counter)
+
+    def test_update_counter(self, connection):
+        Parent, Related = self.classes("Parent", "Related")
+
+        p1 = Parent(related=[Related(), Related(), Related()])
+        with Session(connection, expire_on_commit=False) as sess:
+            sess.add(p1)
+            sess.commit()
+
+        eq_([rel.counter for rel in p1.related], [1, 2, 3])
+
+
 class PostUpdateBatchingTest(fixtures.MappedTest):
     """test that lots of post update cols batch together into a single
     UPDATE."""
index 46821fe0558f2fc2f4f97b9eca49982baa685e28..06fb1b2a5fc8482d5352cd7578e829ef1f1dec0f 100644 (file)
@@ -15,8 +15,10 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
+from sqlalchemy import Uuid
 from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import exc as orm_exc
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises
@@ -757,6 +759,48 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         )
 
 
+class PostUpdatePrefetchTest(fixtures.DeclarativeMappedTest):
+    """test #12748"""
+
+    run_setup_classes = "each"
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Parent(Base):
+            __tablename__ = "parent"
+            id = mapped_column(Integer, primary_key=True)
+
+            related = relationship(
+                "Related", post_update=True, cascade="all, delete-orphan"
+            )
+
+        class Related(Base):
+            __tablename__ = "related"
+
+            id = mapped_column(Integer, primary_key=True)
+            parent_id = mapped_column(ForeignKey("parent.id"))
+            version = mapped_column(Uuid)
+
+            __mapper_args__ = {
+                "version_id_col": version,
+                "version_id_generator": lambda v: uuid.uuid4(),
+            }
+
+    def test_random_versionids(self, connection):
+        Parent, Related = self.classes("Parent", "Related")
+
+        p1 = Parent(related=[Related(), Related(), Related()])
+        with Session(connection, expire_on_commit=False) as sess:
+            sess.add(p1)
+            sess.commit()
+
+        with Session(connection, expire_on_commit=False) as sess:
+            sess.delete(p1)
+            sess.commit()
+
+
 class NoBumpOnRelationshipTest(fixtures.MappedTest):
     __backend__ = True