]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement _postfetch_post_update to expire/refresh onupdates in post_update
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2016 19:56:02 +0000 (15:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2017 22:39:08 +0000 (18:39 -0400)
Fixed bug involving the :paramref:`.relationship.post_update` feature
where a column "onupdate" value would not result in expiration or
refresh of the corresponding object attribute, if the UPDATE for the
row were a result of the "post update" feature.  Additionally, the
:meth:`.SessionEvents.refresh_flush` event is now emitted for these
attributes when refreshed within the flush.

Fixes: #3472
Change-Id: I5ee2d715e773a306ab1e8143e4382c228991ac78

doc/build/changelog/changelog_12.rst
lib/sqlalchemy/orm/persistence.py
test/orm/test_cycles.py

index ee6eb7f65ef01fd77b0cef7fe6503333e6d12c94..ba28d4832660b3f0b7cd098f6ae866dec0f2240e 100644 (file)
 
             :ref:`change_3948`
 
+    .. change:: 3472
+        :tags: bug, orm
+        :tickets: 3472
+
+        Fixed bug involving the :paramref:`.relationship.post_update` feature
+        where a column "onupdate" value would not result in expiration or
+        refresh of the corresponding object attribute, if the UPDATE for the
+        row were a result of the "post update" feature.  Additionally, the
+        :meth:`.SessionEvents.refresh_flush` event is now emitted for these
+        attributes when refreshed within the flush.
+
     .. change:: 3996
         :tags: bug, orm
         :tickets: 3996
index e8a7e4c33e028276ebc7f38591039d7db00094c4..5fa9701badf9398c0d6ed556fa645f52ff9be321 100644 (file)
@@ -601,7 +601,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
                     params[col.key] = value
                     hasdata = True
         if hasdata:
-            yield params, connection
+            yield state, state_dict, mapper, connection, params
 
 
 def _collect_delete_commands(base_mapper, uowtransaction, table,
@@ -887,15 +887,22 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
     # to support executemany().
     for key, grouper in groupby(
         update, lambda rec: (
-            rec[1],  # connection
-            set(rec[0])  # parameter keys
+            rec[3],  # connection
+            set(rec[4]),  # parameter keys
         )
     ):
+        grouper = list(grouper)
         connection = key[0]
-        multiparams = [params for params, conn in grouper]
-        cached_connections[connection].\
+        multiparams = [
+            params for state, state_dict, mapper_rec, conn, params in grouper]
+        c = cached_connections[connection].\
             execute(statement, multiparams)
 
+        for state, state_dict, mapper_rec, connection, params in grouper:
+            _postfetch_post_update(
+                mapper, uowtransaction, state, state_dict,
+                c, c.context.compiled_parameters[0])
+
 
 def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
                             mapper, table, delete):
@@ -1038,6 +1045,33 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
                     "Instance does not contain a non-NULL version value")
 
 
+def _postfetch_post_update(mapper, uowtransaction,
+                           state, dict_, result, params):
+    prefetch_cols = result.context.compiled.prefetch
+    postfetch_cols = result.context.compiled.postfetch
+
+    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
+    if refresh_flush:
+        load_evt_attrs = []
+
+    for c in prefetch_cols:
+        if c.key in params and c in mapper._columntoproperty:
+            dict_[mapper._columntoproperty[c].key] = params[c.key]
+            if refresh_flush:
+                load_evt_attrs.append(mapper._columntoproperty[c].key)
+
+    if refresh_flush and load_evt_attrs:
+        mapper.class_manager.dispatch.refresh_flush(
+            state, uowtransaction, load_evt_attrs)
+
+    if postfetch_cols:
+        state._expire_attributes(state.dict,
+                                 [mapper._columntoproperty[c].key
+                                  for c in postfetch_cols if c in
+                                  mapper._columntoproperty]
+                                 )
+
+
 def _postfetch(mapper, uowtransaction, table,
                state, dict_, result, params, value_params):
     """Expire attributes in need of newly persisted database state,
index f3a7dd141c55b9894c04539002a2f90925546734..a1be28d3f77219c533b0472638dcb197255be80d 100644 (file)
@@ -6,13 +6,16 @@ T1/T2.
 
 """
 from sqlalchemy import testing
+from sqlalchemy import event
+from sqlalchemy.testing import mock
 from sqlalchemy import Integer, String, ForeignKey
 from sqlalchemy.testing.schema import Table, Column
 from sqlalchemy.orm import mapper, relationship, backref, \
-    create_session, sessionmaker
+    create_session, sessionmaker, Session
 from sqlalchemy.testing import eq_, is_
 from sqlalchemy.testing.assertsql import RegexSQL, CompiledSQL, AllOf
 from sqlalchemy.testing import fixtures
+from itertools import count
 
 
 class SelfReferentialTest(fixtures.MappedTest):
@@ -1291,3 +1294,67 @@ class PostUpdateBatchingTest(fixtures.MappedTest):
                              'c1_id': None, 'c3_id': None}
             )
         )
+
+
+class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest):
+
+    counter = count()
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(Base):
+            __tablename__ = 'a'
+            id = Column(Integer, primary_key=True)
+            favorite_b_id = Column(ForeignKey('b.id', name="favorite_b_fk"))
+            bs = relationship("B", primaryjoin="A.id == B.a_id")
+            favorite_b = relationship(
+                "B", primaryjoin="A.favorite_b_id == B.id", post_update=True)
+            updated = Column(Integer, onupdate=lambda: next(cls.counter))
+
+        class B(Base):
+            __tablename__ = 'b'
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey('a.id', name="a_fk"))
+
+    def setup(self):
+        super(PostUpdateOnUpdateTest, self).setup()
+        PostUpdateOnUpdateTest.counter = count()
+
+    def test_update_defaults(self):
+        A, B = self.classes("A", "B")
+
+        s = Session()
+        a1 = A()
+        b1 = B()
+
+        a1.bs.append(b1)
+        a1.favorite_b = b1
+        s.add(a1)
+        s.flush()
+
+        eq_(a1.updated, 0)
+
+    def test_update_defaults_refresh_flush_event(self):
+        A, B = self.classes("A", "B")
+
+        canary = mock.Mock()
+        event.listen(A, "refresh_flush", canary)
+
+        s = Session()
+        a1 = A()
+        b1 = B()
+
+        a1.bs.append(b1)
+        a1.favorite_b = b1
+        s.add(a1)
+        s.flush()
+
+        eq_(a1.updated, 0)
+        eq_(
+            canary.mock_calls,
+            [
+                mock.call(a1, mock.ANY, ['updated'])
+            ]
+        )