From: Mike Bayer Date: Mon, 6 Dec 2010 18:29:13 +0000 (-0500) Subject: - shave about a millisecond off of moderately complex save casades. X-Git-Tag: rel_0_7b1~194 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8423dbcf62284e669c65afc258b0b993f8a66b6e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - shave about a millisecond off of moderately complex save casades. --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 232f0737c2..482f2be506 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -75,7 +75,7 @@ class QueryableAttribute(interfaces.PropComparator): def get_history(self, instance, **kwargs): return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs) - + def __selectable__(self): # TODO: conditionally attach this method based on clause_element ? return self @@ -306,7 +306,10 @@ class AttributeImpl(object): def get_history(self, state, dict_, passive=PASSIVE_OFF): raise NotImplementedError() - + + def get_all_pending(self, state, dict_): + raise NotImplementedError() + def _get_callable(self, state): if self.key in state.callables: return state.callables[self.key] @@ -533,6 +536,20 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): else: return History.from_attribute(self, state, current) + def get_all_pending(self, state, dict_): + if self.key in dict_: + current = dict_[self.key] + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original not in (NEVER_SET, None) and \ + original is not current: + return [current, original] + + return [current] + else: + return [] + def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF): """Set a value on the given InstanceState. @@ -622,6 +639,34 @@ class CollectionAttributeImpl(AttributeImpl): else: return History.from_attribute(self, state, current) + def get_all_pending(self, state, dict_): + # this is basically an inline + # of self.get_history().sum() + + if self.key not in dict_: + return [] + else: + current = dict_[self.key] + + current = self.get_collection(state, dict_, current) + + if self.key not in state.committed_state: + return list(current) + + original = state.committed_state[self.key] + + if original is NO_VALUE: + return list(current) + else: + current_set = util.IdentitySet(current) + original_set = util.IdentitySet(original) + + # ensure ordering is maintained + return \ + [x for x in current if x not in original_set] + \ + [x for x in current if x in original_set] + \ + [x for x in original if x not in current_set] + def fire_append_event(self, state, dict_, value, initiator): for fn in self.dispatch.on_append: value = fn(state, value, initiator or self) @@ -995,6 +1040,22 @@ def get_history(obj, key, **kwargs): def get_state_history(state, key, **kwargs): return state.get_history(key, **kwargs) +def get_all_pending(state, dict_, key): + """Return a list of all objects currently in memory + involving the given key on the given state. + + This should be equivalent to:: + + get_state_history( + state, + key, + passive=PASSIVE_NO_INITIALIZE).sum() + + """ + + return state.manager.get_impl(key).get_all_pending(state, dict_) + + def has_parent(cls, obj, key, optimistic=False): """TODO""" manager = manager_of_class(cls) diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 95a58ee847..710b710a91 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -141,7 +141,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl): c = self._get_collection_history(state, passive) return attributes.History(c.added_items, c.unchanged_items, c.deleted_items) - + + def get_all_pending(self, state, dict_): + c = self._get_collection_history(state, True) + return (c.added_items or []) +\ + (c.unchanged_items or []) +\ + (c.deleted_items or []) + def _get_collection_history(self, state, passive=False): if self.key in state.committed_state: c = state.committed_state[self.key] @@ -304,4 +310,4 @@ class CollectionHistory(object): self.deleted_items = [] self.added_items = [] self.unchanged_items = [] - + diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 2be2495942..5bd4e0f41a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1369,10 +1369,10 @@ class Mapper(object): visited_instances = util.IdentitySet() prp, mpp = object(), object() - visitables = [(deque(self._props.values()), prp, state)] + visitables = [(deque(self._props.values()), prp, state, state.dict)] while visitables: - iterator, item_type, parent_state = visitables[-1] + iterator, item_type, parent_state, parent_dict = visitables[-1] if not iterator: visitables.pop() continue @@ -1382,15 +1382,15 @@ class Mapper(object): if type_ not in prop.cascade: continue queue = deque(prop.cascade_iterator(type_, parent_state, - visited_instances, halt_on)) + parent_dict, visited_instances, halt_on)) if queue: - visitables.append((queue,mpp, None)) + visitables.append((queue,mpp, None, None)) elif item_type is mpp: - instance, instance_mapper, corresponding_state = \ - iterator.popleft() + instance, instance_mapper, corresponding_state, \ + corresponding_dict = iterator.popleft() yield (instance, instance_mapper) visitables.append((deque(instance_mapper._props.values()), - prp, corresponding_state)) + prp, corresponding_state, corresponding_dict)) @_memoized_configured_property def _compiled_cache(self): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 81ac9262ce..953974af37 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -834,7 +834,7 @@ class RelationshipProperty(StrategizedProperty): dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None) - def cascade_iterator(self, type_, state, visited_instances, halt_on=None): + def cascade_iterator(self, type_, state, dict_, visited_instances, halt_on=None): if not type_ in self.cascade: return @@ -845,10 +845,10 @@ class RelationshipProperty(StrategizedProperty): passive = attributes.PASSIVE_OFF if type_ == 'save-update': - instances = attributes.get_state_history(state, self.key, - passive=passive).sum() + instances = attributes.get_all_pending(state, dict_, self.key) + else: - instances = state.value_as_iterable(self.key, + instances = state.value_as_iterable(dict_, self.key, passive=passive) skip_pending = type_ == 'refresh-expire' and 'delete-orphan' \ not in self.cascade @@ -857,28 +857,33 @@ class RelationshipProperty(StrategizedProperty): for c in instances: if c is not None and \ c is not attributes.PASSIVE_NO_RESULT and \ - c not in visited_instances and \ - (halt_on is None or not halt_on(c)): + c not in visited_instances: + + instance_state = attributes.instance_state(c) + instance_dict = attributes.instance_dict(c) + + if halt_on and halt_on(instance_state): + continue - if not isinstance(c, self.mapper.class_): + if skip_pending and not instance_state.key: + continue + + instance_mapper = instance_state.manager.mapper + + if not instance_mapper.isa(self.mapper.class_manager.mapper): raise AssertionError("Attribute '%s' on class '%s' " "doesn't handle objects " "of type '%s'" % ( self.key, - str(self.parent.class_), - str(c.__class__) + self.parent.class_, + c.__class__ )) - instance_state = attributes.instance_state(c) - - if skip_pending and not instance_state.key: - continue - + visited_instances.add(c) # cascade using the mapper local to this # object, so that its individual properties are located - instance_mapper = instance_state.manager.mapper - yield c, instance_mapper, instance_state + yield c, instance_mapper, instance_state, instance_dict def _add_reverse_property(self, key): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f0e534bfcd..3517eab2b3 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1108,7 +1108,8 @@ class Session(object): def _cascade_save_or_update(self, state): for state, mapper in _cascade_unknown_state_iterator( - 'save-update', state, halt_on=self.__contains__): + 'save-update', state, + halt_on=self._contains_state): self._save_or_update_impl(state) def delete(self, instance): diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 059788bac1..668c04f508 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -125,7 +125,7 @@ class InstanceState(object): self.pending[key] = PendingCollection() return self.pending[key] - def value_as_iterable(self, key, passive=PASSIVE_OFF): + def value_as_iterable(self, dict_, key, passive=PASSIVE_OFF): """return an InstanceState attribute as a list, regardless of it being a scalar or collection-based attribute. @@ -135,7 +135,6 @@ class InstanceState(object): """ impl = self.get_impl(key) - dict_ = self.dict x = impl.get(self, dict_, passive=passive) if x is PASSIVE_NO_RESULT: return None diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 23154bd4e1..c59dbed692 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -21,7 +21,7 @@ all_cascades = frozenset(("delete", "delete-orphan", "all", "merge", _INSTRUMENTOR = ('mapper', 'instrumentor') -class CascadeOptions(object): +class CascadeOptions(dict): """Keeps track of the options sent to relationship().cascade""" def __init__(self, arg=""): @@ -29,13 +29,17 @@ class CascadeOptions(object): values = set() else: values = set(c.strip() for c in arg.split(',')) + + for name in ['save-update', 'delete', 'refresh-expire', + 'merge', 'expunge']: + boolean = name in values or 'all' in values + setattr(self, name.replace('-', '_'), boolean) + if boolean: + self[name] = True self.delete_orphan = "delete-orphan" in values - self.delete = "delete" in values or "all" in values - self.save_update = "save-update" in values or "all" in values - self.merge = "merge" in values or "all" in values - self.expunge = "expunge" in values or "all" in values - self.refresh_expire = "refresh-expire" in values or "all" in values - + if self.delete_orphan: + self['delete-orphan'] = True + if self.delete_orphan and not self.delete: util.warn("The 'delete-orphan' cascade option requires " "'delete'. This will raise an error in 0.6.") @@ -44,9 +48,6 @@ class CascadeOptions(object): if x not in all_cascades: raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x) - def __contains__(self, item): - return getattr(self, item.replace("-", "_"), False) - def __repr__(self): return "CascadeOptions(%s)" % repr(",".join( [x for x in ['delete', 'save_update', 'merge', 'expunge', diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 093674d48e..bea66c0723 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -53,8 +53,8 @@ class MergeTest(_base.MappedTest): # down from 185 on this this is a small slice of a usually # bigger operation so using a small variance - @profiling.function_call_count(97, variance=0.05, - versions={'2.4': 73, '3': 96}) + @profiling.function_call_count(91, variance=0.05, + versions={'2.4': 68, '3': 89}) def go(): return sess2.merge(p1, load=False) p2 = go()