From: Mike Bayer Date: Tue, 21 Dec 2010 16:08:12 +0000 (-0500) Subject: - refactor expire_attributes into two simpler methods X-Git-Tag: rel_0_7b1~130 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=29e5a23aee38e4d832b6952a9009bb4a8ada4df8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - refactor expire_attributes into two simpler methods --- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8c1db3f1ae..ca23edf3c5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1891,7 +1891,7 @@ class Mapper(object): [p.key for p in mapper._readonly_props] ) if readonly: - sessionlib._expire_state(state, state.dict, readonly) + state.expire_attributes(state.dict, readonly) # if eager_defaults option is enabled, # refresh whatever has been expired. @@ -1921,7 +1921,7 @@ class Mapper(object): self._set_state_attr_by_column(state, dict_, c, params[c.key]) if postfetch_cols: - sessionlib._expire_state(state, state.dict, + state.expire_attributes(state.dict, [self._columntoproperty[c].key for c in postfetch_cols] ) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 88a8a8ea66..325a94643a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -307,16 +307,14 @@ class SessionTransaction(object): assert not self.session._deleted for s in self.session.identity_map.all_states(): - _expire_state(s, s.dict, None, - instance_dict=self.session.identity_map) + s.expire(s.dict, self.session.identity_map._modified) def _remove_snapshot(self): assert self._is_transaction_boundary if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): - _expire_state(s, s.dict, None, - instance_dict=self.session.identity_map) + s.expire(s.dict, self.session.identity_map._modified) def _connection_for_bind(self, bind): self._assert_is_active() @@ -937,7 +935,7 @@ class Session(object): """ for state in self.identity_map.all_states(): - _expire_state(state, state.dict, None, instance_dict=self.identity_map) + state.expire(state.dict, self.identity_map._modified) def expire(self, instance, attribute_names=None): """Expire the attributes on an instance. @@ -975,9 +973,7 @@ class Session(object): def _expire_state(self, state, attribute_names): self._validate_persistent(state) if attribute_names: - _expire_state(state, state.dict, - attribute_names=attribute_names, - instance_dict=self.identity_map) + state.expire_attributes(state.dict, attribute_names) else: # pre-fetch the full cascade since the expire is going to # remove associations @@ -991,8 +987,7 @@ class Session(object): """Expire a state if persistent, else expunge if pending""" if state.key: - _expire_state(state, state.dict, None, - instance_dict=self.identity_map) + state.expire(state.dict, self.identity_map._modified) elif state in self._new: self._new.pop(state) state.detach() @@ -1603,8 +1598,6 @@ class Session(object): return util.IdentitySet(self._new.values()) -_expire_state = state.InstanceState.expire_attributes - _sessions = weakref.WeakValueDictionary() def make_transient(instance): diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index fd55cfcb87..909977cc47 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -224,43 +224,35 @@ class InstanceState(object): dict_.pop(key, None) self.callables[key] = callable_ - def expire_attributes(self, dict_, attribute_names, instance_dict=None): - """Expire all or a group of attributes. - - If all attributes are expired, the "expired" flag is set to True. - - """ - # we would like to assert that 'self.key is not None' here, - # but there are many cases where the mapper will expire - # a newly persisted instance within the flush, before the - # key is assigned, and even cases where the attribute refresh - # occurs fully, within the flush(), before this key is assigned. - # the key is assigned late within the flush() to assist in - # "key switch" bookkeeping scenarios. - - if attribute_names is None: - attribute_names = self.manager.keys() - self.expired = True - if self.modified: - if not instance_dict: - instance_dict = self._instance_dict() - if instance_dict: - instance_dict._modified.discard(self) - else: - instance_dict._modified.discard(self) + def expire(self, dict_, modified_set): + self.expired = True + if self.modified: + modified_set.discard(self) - self.modified = False - filter_deferred = True - else: - filter_deferred = False + self.modified = False pending = self.__dict__.get('pending', None) mutable_dict = self.mutable_dict + self.committed_state.clear() + if mutable_dict: + mutable_dict.clear() + if pending: + pending.clear() - for key in attribute_names: + for key in self.manager: impl = self.manager[key].impl if impl.accepts_scalar_loader and \ - (not filter_deferred or impl.expire_missing or key in dict_): + (impl.expire_missing or key in dict_): + self.callables[key] = self + dict_.pop(key, None) + + def expire_attributes(self, dict_, attribute_names): + pending = self.__dict__.get('pending', None) + mutable_dict = self.mutable_dict + + for key in attribute_names: + impl = self.manager[key].impl + if impl.accepts_scalar_loader: self.callables[key] = self dict_.pop(key, None) diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 91f61c05c0..9bbfa98eb9 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -155,21 +155,21 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) f = Foo() - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) f.a = "this is some new a" - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) f.a = "this is another new a" eq_(f.a, "this is another new a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) @@ -177,7 +177,7 @@ class AttributesTest(_base.ORMTest): eq_(f.a, None) eq_(f.b, 12) - attributes.instance_state(f).commit_all(attributes.instance_dict(f)) + attributes.instance_state(f).commit_all(attributes.instance_dict(f), set()) eq_(f.a, None) eq_(f.b, 12) @@ -195,7 +195,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False) m = MyTest() - attributes.instance_state(m).expire_attributes(attributes.instance_dict(m), None) + attributes.instance_state(m).expire(attributes.instance_dict(m), set()) assert 'a' not in m.__dict__ m2 = pickle.loads(pickle.dumps(m)) assert 'a' not in m2.__dict__ diff --git a/test/orm/test_extendedattr.py b/test/orm/test_extendedattr.py index 2eca1ac387..c20cad0da7 100644 --- a/test/orm/test_extendedattr.py +++ b/test/orm/test_extendedattr.py @@ -161,21 +161,21 @@ class UserDefinedExtensionTest(_base.ORMTest): assert Foo in instrumentation.instrumentation_registry._state_finders f = Foo() - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) f.a = "this is some new a" - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) f.a = "this is another new a" eq_(f.a, "this is another new a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) + attributes.instance_state(f).expire(attributes.instance_dict(f), set()) eq_(f.a, "this is a") eq_(f.b, 12)