From: Mike Bayer Date: Sun, 8 Nov 2009 21:54:56 +0000 (+0000) Subject: moved modified_event() calls below the attribute extension fires. this basically... X-Git-Tag: rel_0_6beta1~181 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8a282a9b60dabd86bb16b6241055a619d61dc09b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git moved modified_event() calls below the attribute extension fires. this basically has no difference in any case except that where an extension is calling commit() on the attribute - in that case it usually, but not always, maintains the same history. [ticket:1601] --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index dfc415d2a9..9fbcf3d20e 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -431,10 +431,9 @@ class ScalarAttributeImpl(AttributeImpl): else: old = dict_.get(self.key, NO_VALUE) - state.modified_event(dict_, self, False, old) - if self.extensions: self.fire_remove_event(state, dict_, old, None) + state.modified_event(dict_, self, False, old) del dict_[self.key] def get_history(self, state, dict_, passive=PASSIVE_OFF): @@ -450,10 +449,9 @@ class ScalarAttributeImpl(AttributeImpl): else: old = dict_.get(self.key, NO_VALUE) - state.modified_event(dict_, self, False, old) - if self.extensions: value = self.fire_replace_event(state, dict_, value, old, initiator) + state.modified_event(dict_, self, False, old) dict_[self.key] = value def fire_replace_event(self, state, dict_, value, previous, initiator): @@ -520,14 +518,12 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): if initiator is self: return - state.modified_event(dict_, self, True, NEVER_SET) - if self.extensions: old = self.get(state, dict_) value = self.fire_replace_event(state, dict_, value, old, initiator) - dict_[self.key] = value - else: - dict_[self.key] = value + + state.modified_event(dict_, self, True, NEVER_SET) + dict_[self.key] = value state.mutable_dict[self.key] = value @@ -584,17 +580,15 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): dict_[self.key] = value def fire_remove_event(self, state, dict_, value, initiator): - state.modified_event(dict_, self, False, value) - if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) for ext in self.extensions: ext.remove(state, value, initiator or self) - def fire_replace_event(self, state, dict_, value, previous, initiator): - state.modified_event(dict_, self, False, previous) + state.modified_event(dict_, self, False, value) + def fire_replace_event(self, state, dict_, value, previous, initiator): if self.trackparent: if previous is not value and previous not in (None, PASSIVE_NO_RESULT): self.sethasparent(instance_state(previous), False) @@ -602,6 +596,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for ext in self.extensions: value = ext.set(state, value, previous, initiator or self) + state.modified_event(dict_, self, False, previous) + if self.trackparent: if value is not None: self.sethasparent(instance_state(value), True) @@ -649,11 +645,11 @@ class CollectionAttributeImpl(AttributeImpl): return History.from_attribute(self, state, current) def fire_append_event(self, state, dict_, value, initiator): - state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) - for ext in self.extensions: value = ext.append(state, value, initiator or self) + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + if self.trackparent and value is not None: self.sethasparent(instance_state(value), True) @@ -663,14 +659,14 @@ class CollectionAttributeImpl(AttributeImpl): state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) def fire_remove_event(self, state, dict_, value, initiator): - state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) - if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) for ext in self.extensions: ext.remove(state, value, initiator or self) + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def delete(self, state, dict_): if self.key not in dict_: return diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 1aec6a02e0..c312b37aa9 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -4,7 +4,7 @@ from sqlalchemy.orm.collections import collection from sqlalchemy.orm.interfaces import AttributeExtension from sqlalchemy import exc as sa_exc from sqlalchemy.test import * -from sqlalchemy.test.testing import eq_, assert_raises +from sqlalchemy.test.testing import eq_, ne_, assert_raises from test.orm import _base from sqlalchemy.test.util import gc_collect from sqlalchemy.util import cmp, jython @@ -221,6 +221,103 @@ class AttributesTest(_base.ORMTest): u.addresses.append(a) self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com') + def test_extension_commit_attr(self): + """test that an extension which commits attribute history + maintains the end-result history. + + This won't work in conjunction with some unitofwork extensions. + + """ + + class Foo(_base.ComparableEntity): + pass + class Bar(_base.ComparableEntity): + pass + + class ReceiveEvents(AttributeExtension): + def __init__(self, key): + self.key = key + + def append(self, state, child, initiator): + if commit: + state.commit_all(state.dict) + return child + + def remove(self, state, child, initiator): + if commit: + state.commit_all(state.dict) + return child + + def set(self, state, child, oldchild, initiator): + if commit: + state.commit_all(state.dict) + return child + + attributes.register_class(Foo) + attributes.register_class(Bar) + + b1, b2, b3, b4 = Bar(id='b1'), Bar(id='b2'), Bar(id='b3'), Bar(id='b4') + + def loadcollection(**kw): + if kw.get('passive') is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + return [b1, b2] + + def loadscalar(**kw): + if kw.get('passive') is attributes.PASSIVE_NO_FETCH: + return attributes.PASSIVE_NO_RESULT + return b2 + + attributes.register_attribute(Foo, 'bars', + uselist=True, + useobject=True, + callable_=lambda o:loadcollection, + extension=[ReceiveEvents('bars')]) + + attributes.register_attribute(Foo, 'bar', + uselist=False, + useobject=True, + callable_=lambda o:loadscalar, + extension=[ReceiveEvents('bar')]) + + attributes.register_attribute(Foo, 'scalar', + uselist=False, + useobject=False, extension=[ReceiveEvents('scalar')]) + + + def create_hist(): + def hist(key, shouldmatch, fn, *arg): + attributes.instance_state(f1).commit_all(attributes.instance_dict(f1)) + fn(*arg) + histories.append((shouldmatch, attributes.get_history(f1, key))) + + f1 = Foo() + hist('bars', True, f1.bars.append, b3) + hist('bars', True, f1.bars.append, b4) + hist('bars', False, f1.bars.remove, b2) + hist('bar', True, setattr, f1, 'bar', b3) + hist('bar', True, setattr, f1, 'bar', None) + hist('bar', True, setattr, f1, 'bar', b4) + hist('scalar', True, setattr, f1, 'scalar', 5) + hist('scalar', True, setattr, f1, 'scalar', None) + hist('scalar', True, setattr, f1, 'scalar', 4) + + histories = [] + commit = False + create_hist() + without_commit = list(histories) + histories[:] = [] + commit = True + create_hist() + with_commit = histories + for without, with_ in zip(without_commit, with_commit): + shouldmatch, woc = without + shouldmatch, wic = with_ + if shouldmatch: + eq_(woc, wic) + else: + ne_(woc, wic) + def test_extension_lazyload_assertion(self): class Foo(_base.BasicEntity): pass