From: Mike Bayer Date: Tue, 12 Apr 2016 19:56:02 +0000 (-0400) Subject: Implement _postfetch_post_update to expire/refresh onupdates in post_update X-Git-Tag: rel_1_2_0b1~34^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9dee44ae2f8b113d23f8a6e192f77fb4e3837894;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement _postfetch_post_update to expire/refresh onupdates in post_update Fixed bug involving the :paramref:`.relationship.post_update` feature where a column "onupdate" value would not result in expiration or refresh of the corresponding object attribute, if the UPDATE for the row were a result of the "post update" feature. Additionally, the :meth:`.SessionEvents.refresh_flush` event is now emitted for these attributes when refreshed within the flush. Fixes: #3472 Change-Id: I5ee2d715e773a306ab1e8143e4382c228991ac78 --- diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst index ee6eb7f65e..ba28d48326 100644 --- a/doc/build/changelog/changelog_12.rst +++ b/doc/build/changelog/changelog_12.rst @@ -43,6 +43,17 @@ :ref:`change_3948` + .. change:: 3472 + :tags: bug, orm + :tickets: 3472 + + Fixed bug involving the :paramref:`.relationship.post_update` feature + where a column "onupdate" value would not result in expiration or + refresh of the corresponding object attribute, if the UPDATE for the + row were a result of the "post update" feature. Additionally, the + :meth:`.SessionEvents.refresh_flush` event is now emitted for these + attributes when refreshed within the flush. + .. change:: 3996 :tags: bug, orm :tickets: 3996 diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index e8a7e4c33e..5fa9701bad 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -601,7 +601,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, params[col.key] = value hasdata = True if hasdata: - yield params, connection + yield state, state_dict, mapper, connection, params def _collect_delete_commands(base_mapper, uowtransaction, table, @@ -887,15 +887,22 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # to support executemany(). for key, grouper in groupby( update, lambda rec: ( - rec[1], # connection - set(rec[0]) # parameter keys + rec[3], # connection + set(rec[4]), # parameter keys ) ): + grouper = list(grouper) connection = key[0] - multiparams = [params for params, conn in grouper] - cached_connections[connection].\ + multiparams = [ + params for state, state_dict, mapper_rec, conn, params in grouper] + c = cached_connections[connection].\ execute(statement, multiparams) + for state, state_dict, mapper_rec, connection, params in grouper: + _postfetch_post_update( + mapper, uowtransaction, state, state_dict, + c, c.context.compiled_parameters[0]) + def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, mapper, table, delete): @@ -1038,6 +1045,33 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): "Instance does not contain a non-NULL version value") +def _postfetch_post_update(mapper, uowtransaction, + state, dict_, result, params): + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) + if refresh_flush: + load_evt_attrs = [] + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + dict_[mapper._columntoproperty[c].key] = params[c.key] + if refresh_flush: + load_evt_attrs.append(mapper._columntoproperty[c].key) + + if refresh_flush and load_evt_attrs: + mapper.class_manager.dispatch.refresh_flush( + state, uowtransaction, load_evt_attrs) + + if postfetch_cols: + state._expire_attributes(state.dict, + [mapper._columntoproperty[c].key + for c in postfetch_cols if c in + mapper._columntoproperty] + ) + + def _postfetch(mapper, uowtransaction, table, state, dict_, result, params, value_params): """Expire attributes in need of newly persisted database state, diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index f3a7dd141c..a1be28d3f7 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -6,13 +6,16 @@ T1/T2. """ from sqlalchemy import testing +from sqlalchemy import event +from sqlalchemy.testing import mock from sqlalchemy import Integer, String, ForeignKey from sqlalchemy.testing.schema import Table, Column from sqlalchemy.orm import mapper, relationship, backref, \ - create_session, sessionmaker + create_session, sessionmaker, Session from sqlalchemy.testing import eq_, is_ from sqlalchemy.testing.assertsql import RegexSQL, CompiledSQL, AllOf from sqlalchemy.testing import fixtures +from itertools import count class SelfReferentialTest(fixtures.MappedTest): @@ -1291,3 +1294,67 @@ class PostUpdateBatchingTest(fixtures.MappedTest): 'c1_id': None, 'c3_id': None} ) ) + + +class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): + + counter = count() + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class A(Base): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + favorite_b_id = Column(ForeignKey('b.id', name="favorite_b_fk")) + bs = relationship("B", primaryjoin="A.id == B.a_id") + favorite_b = relationship( + "B", primaryjoin="A.favorite_b_id == B.id", post_update=True) + updated = Column(Integer, onupdate=lambda: next(cls.counter)) + + class B(Base): + __tablename__ = 'b' + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey('a.id', name="a_fk")) + + def setup(self): + super(PostUpdateOnUpdateTest, self).setup() + PostUpdateOnUpdateTest.counter = count() + + def test_update_defaults(self): + A, B = self.classes("A", "B") + + s = Session() + a1 = A() + b1 = B() + + a1.bs.append(b1) + a1.favorite_b = b1 + s.add(a1) + s.flush() + + eq_(a1.updated, 0) + + def test_update_defaults_refresh_flush_event(self): + A, B = self.classes("A", "B") + + canary = mock.Mock() + event.listen(A, "refresh_flush", canary) + + s = Session() + a1 = A() + b1 = B() + + a1.bs.append(b1) + a1.favorite_b = b1 + s.add(a1) + s.flush() + + eq_(a1.updated, 0) + eq_( + canary.mock_calls, + [ + mock.call(a1, mock.ANY, ['updated']) + ] + )