]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- shave about a millisecond off of moderately complex save casades.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Dec 2010 18:29:13 +0000 (13:29 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Dec 2010 18:29:13 +0000 (13:29 -0500)
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/util.py
test/aaa_profiling/test_orm.py

index 232f0737c226cbb58ad07f5aacf6e65e85970572..482f2be50638dd4bccee895852ab615f9dd5d64c 100644 (file)
@@ -75,7 +75,7 @@ class QueryableAttribute(interfaces.PropComparator):
     def get_history(self, instance, **kwargs):
         return self.impl.get_history(instance_state(instance),
                                         instance_dict(instance), **kwargs)
-
+    
     def __selectable__(self):
         # TODO: conditionally attach this method based on clause_element ?
         return self
@@ -306,7 +306,10 @@ class AttributeImpl(object):
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
         raise NotImplementedError()
-
+    
+    def get_all_pending(self, state, dict_):
+        raise NotImplementedError()
+        
     def _get_callable(self, state):
         if self.key in state.callables:
             return state.callables[self.key]
@@ -533,6 +536,20 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             else:
                 return History.from_attribute(self, state, current)
 
+    def get_all_pending(self, state, dict_):
+        if self.key in dict_:
+            current = dict_[self.key]
+            
+            if self.key in state.committed_state:
+                original = state.committed_state[self.key]
+                if original not in (NEVER_SET, None) and \
+                    original is not current:
+                    return [current, original]
+                    
+            return [current]
+        else:
+            return []
+
     def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         """Set a value on the given InstanceState.
 
@@ -622,6 +639,34 @@ class CollectionAttributeImpl(AttributeImpl):
         else:
             return History.from_attribute(self, state, current)
 
+    def get_all_pending(self, state, dict_):
+        # this is basically an inline 
+        # of self.get_history().sum()
+        
+        if self.key not in dict_:
+            return []
+        else:
+            current = dict_[self.key]
+            
+        current = self.get_collection(state, dict_, current)
+
+        if self.key not in state.committed_state:
+            return list(current)
+
+        original = state.committed_state[self.key]
+        
+        if original is NO_VALUE:
+            return list(current)
+        else:
+            current_set = util.IdentitySet(current)
+            original_set = util.IdentitySet(original)
+
+            # ensure ordering is maintained
+            return \
+                [x for x in current if x not in original_set] + \
+                [x for x in current if x in original_set] + \
+                [x for x in original if x not in current_set]
+        
     def fire_append_event(self, state, dict_, value, initiator):
         for fn in self.dispatch.on_append:
             value = fn(state, value, initiator or self)
@@ -995,6 +1040,22 @@ def get_history(obj, key, **kwargs):
 def get_state_history(state, key, **kwargs):
     return state.get_history(key, **kwargs)
 
+def get_all_pending(state, dict_, key):
+    """Return a list of all objects currently in memory 
+    involving the given key on the given state.
+    
+    This should be equivalent to::
+    
+        get_state_history(
+                    state, 
+                    key, 
+                    passive=PASSIVE_NO_INITIALIZE).sum()
+    
+    """
+    
+    return state.manager.get_impl(key).get_all_pending(state, dict_)
+    
+    
 def has_parent(cls, obj, key, optimistic=False):
     """TODO"""
     manager = manager_of_class(cls)
index 95a58ee8473ac061010e0436d15b189aa0cac429..710b710a9124de548a8c17e64e849256d2a0c95e 100644 (file)
@@ -141,7 +141,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         c = self._get_collection_history(state, passive)
         return attributes.History(c.added_items, c.unchanged_items,
                                   c.deleted_items)
-
+    
+    def get_all_pending(self, state, dict_):
+        c = self._get_collection_history(state, True)
+        return (c.added_items or []) +\
+                (c.unchanged_items or []) +\
+                (c.deleted_items or [])
+        
     def _get_collection_history(self, state, passive=False):
         if self.key in state.committed_state:
             c = state.committed_state[self.key]
@@ -304,4 +310,4 @@ class CollectionHistory(object):
             self.deleted_items = []
             self.added_items = []
             self.unchanged_items = []
-
+    
index 2be24959424ffb7b70f357f1fa3856dc9d17cdf9..5bd4e0f41ace2bf9c5963c790c4b970378cc33c1 100644 (file)
@@ -1369,10 +1369,10 @@ class Mapper(object):
         visited_instances = util.IdentitySet()
         prp, mpp = object(), object()
 
-        visitables = [(deque(self._props.values()), prp, state)]
+        visitables = [(deque(self._props.values()), prp, state, state.dict)]
 
         while visitables:
-            iterator, item_type, parent_state = visitables[-1]
+            iterator, item_type, parent_state, parent_dict = visitables[-1]
             if not iterator:
                 visitables.pop()
                 continue
@@ -1382,15 +1382,15 @@ class Mapper(object):
                 if type_ not in prop.cascade:
                     continue
                 queue = deque(prop.cascade_iterator(type_, parent_state, 
-                            visited_instances, halt_on))
+                            parent_dict, visited_instances, halt_on))
                 if queue:
-                    visitables.append((queue,mpp, None))
+                    visitables.append((queue,mpp, None, None))
             elif item_type is mpp:
-                instance, instance_mapper, corresponding_state  = \
-                                iterator.popleft()
+                instance, instance_mapper, corresponding_state, \
+                                corresponding_dict = iterator.popleft()
                 yield (instance, instance_mapper)
                 visitables.append((deque(instance_mapper._props.values()), 
-                                        prp, corresponding_state))
+                                        prp, corresponding_state, corresponding_dict))
 
     @_memoized_configured_property
     def _compiled_cache(self):
index 81ac9262ce3d8aa09dc78599b2e2baf3101aa363..953974af3713f99c3d1684b4af796089acb57eb1 100644 (file)
@@ -834,7 +834,7 @@ class RelationshipProperty(StrategizedProperty):
                 dest_state.get_impl(self.key).set(dest_state,
                         dest_dict, obj, None)
 
-    def cascade_iterator(self, type_, state, visited_instances, halt_on=None):
+    def cascade_iterator(self, type_, state, dict_, visited_instances, halt_on=None):
         if not type_ in self.cascade:
             return
 
@@ -845,10 +845,10 @@ class RelationshipProperty(StrategizedProperty):
             passive = attributes.PASSIVE_OFF
 
         if type_ == 'save-update':
-            instances = attributes.get_state_history(state, self.key,
-                    passive=passive).sum()
+            instances = attributes.get_all_pending(state, dict_, self.key)
+            
         else:
-            instances = state.value_as_iterable(self.key,
+            instances = state.value_as_iterable(dict_, self.key,
                     passive=passive)
         skip_pending = type_ == 'refresh-expire' and 'delete-orphan' \
             not in self.cascade
@@ -857,28 +857,33 @@ class RelationshipProperty(StrategizedProperty):
             for c in instances:
                 if c is not None and \
                     c is not attributes.PASSIVE_NO_RESULT and \
-                    c not in visited_instances and \
-                    (halt_on is None or not halt_on(c)):
+                    c not in visited_instances:
+                    
+                    instance_state = attributes.instance_state(c)
+                    instance_dict = attributes.instance_dict(c)
+                    
+                    if halt_on and halt_on(instance_state):
+                        continue
                     
-                    if not isinstance(c, self.mapper.class_):
+                    if skip_pending and not instance_state.key:
+                        continue
+                    
+                    instance_mapper = instance_state.manager.mapper
+                    
+                    if not instance_mapper.isa(self.mapper.class_manager.mapper):
                         raise AssertionError("Attribute '%s' on class '%s' "
                                             "doesn't handle objects "
                                             "of type '%s'" % (
                                                 self.key, 
-                                                str(self.parent.class_)
-                                                str(c.__class__)
+                                                self.parent.class_
+                                                c.__class__
                                             ))
-                    instance_state = attributes.instance_state(c)
-                    
-                    if skip_pending and not instance_state.key:
-                        continue
-                        
+
                     visited_instances.add(c)
 
                     # cascade using the mapper local to this 
                     # object, so that its individual properties are located
-                    instance_mapper = instance_state.manager.mapper
-                    yield c, instance_mapper, instance_state
+                    yield c, instance_mapper, instance_state, instance_dict
             
 
     def _add_reverse_property(self, key):
index f0e534bfcd9d1516e439eee53bd147ef8174e7a7..3517eab2b3c0f2760c14e3dc678d630236d08e60 100644 (file)
@@ -1108,7 +1108,8 @@ class Session(object):
 
     def _cascade_save_or_update(self, state):
         for state, mapper in _cascade_unknown_state_iterator(
-                                    'save-update', state, halt_on=self.__contains__):
+                                    'save-update', state, 
+                                    halt_on=self._contains_state):
             self._save_or_update_impl(state)
 
     def delete(self, instance):
index 059788bac1ad02a99d30a3d04f2e214ec4eb9a83..668c04f508d1f0c0d21c7f7c60f218f8da16fa01 100644 (file)
@@ -125,7 +125,7 @@ class InstanceState(object):
             self.pending[key] = PendingCollection()
         return self.pending[key]
 
-    def value_as_iterable(self, key, passive=PASSIVE_OFF):
+    def value_as_iterable(self, dict_, key, passive=PASSIVE_OFF):
         """return an InstanceState attribute as a list,
         regardless of it being a scalar or collection-based
         attribute.
@@ -135,7 +135,6 @@ class InstanceState(object):
         """
 
         impl = self.get_impl(key)
-        dict_ = self.dict
         x = impl.get(self, dict_, passive=passive)
         if x is PASSIVE_NO_RESULT:
             return None
index 23154bd4e18bb67c099d63dbcadc09c6a1e3991c..c59dbed692efb8469eded27dfcfbab72fabde24d 100644 (file)
@@ -21,7 +21,7 @@ all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
 
 _INSTRUMENTOR = ('mapper', 'instrumentor')
 
-class CascadeOptions(object):
+class CascadeOptions(dict):
     """Keeps track of the options sent to relationship().cascade"""
 
     def __init__(self, arg=""):
@@ -29,13 +29,17 @@ class CascadeOptions(object):
             values = set()
         else:
             values = set(c.strip() for c in arg.split(','))
+            
+        for name in ['save-update', 'delete', 'refresh-expire', 
+                            'merge', 'expunge']:
+            boolean = name in values or 'all' in values
+            setattr(self, name.replace('-', '_'), boolean)
+            if boolean:
+                self[name] = True
         self.delete_orphan = "delete-orphan" in values
-        self.delete = "delete" in values or "all" in values
-        self.save_update = "save-update" in values or "all" in values
-        self.merge = "merge" in values or "all" in values
-        self.expunge = "expunge" in values or "all" in values
-        self.refresh_expire = "refresh-expire" in values or "all" in values
-
+        if self.delete_orphan:
+            self['delete-orphan'] = True
+        
         if self.delete_orphan and not self.delete:
             util.warn("The 'delete-orphan' cascade option requires "
                         "'delete'.  This will raise an error in 0.6.")
@@ -44,9 +48,6 @@ class CascadeOptions(object):
             if x not in all_cascades:
                 raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
 
-    def __contains__(self, item):
-        return getattr(self, item.replace("-", "_"), False)
-
     def __repr__(self):
         return "CascadeOptions(%s)" % repr(",".join(
             [x for x in ['delete', 'save_update', 'merge', 'expunge',
index 093674d48eba8b263bd2e610dfd644ee19f4b6ad..bea66c07235bace4e0d9961967a252973761b22b 100644 (file)
@@ -53,8 +53,8 @@ class MergeTest(_base.MappedTest):
         # down from 185 on this this is a small slice of a usually
         # bigger operation so using a small variance
 
-        @profiling.function_call_count(97, variance=0.05,
-                versions={'2.4': 73, '3': 96})
+        @profiling.function_call_count(91, variance=0.05,
+                versions={'2.4': 68, '3': 89})
         def go():
             return sess2.merge(p1, load=False)
         p2 = go()