git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Removed all* O(N) scanning behavior from the flush() process,
author Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2009 21:51:40 +0000 (21:51 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2009 21:51:40 +0000 (21:51 +0000)
i.e. operations that were scanning the full session,
including an extremely expensive one that was erroneously
assuming primary key values were changing when this
was not the case.

* one edge case remains which may invoke a full scan,
  if an existing primary key attribute is modified
  to a new value.

CHANGES
lib/sqlalchemy/__init__.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/naturalpks.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index 5130e886fee54b0a7e406393af02159e6d6e3935..ae3a6969bca827b028dc778956460ef5187b841f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,16 @@ CHANGES
       in conjunction with large mapper graphs, large numbers of 
       objects:
       
+      - Removed all* O(N) scanning behavior from the flush() process,
+        i.e. operations that were scanning the full session, 
+        including an extremely expensive one that was erroneously
+        assuming primary key values were changing when this 
+        was not the case.
+        
+        * one edge case remains which may invoke a full scan,
+          if an existing primary key attribute is modified
+          to a new value.
+      
       - The Session's "weak referencing" behavior is now *full* -
         no strong references whatsoever are made to a mapped object
         or related items/collections in its __dict__.  Backrefs and 
index 2dea27a0410d90b6e61bf1dc91f77a23e2f12bf8..b28de9bc8f562366ec585eb25364fd73c01fcbee 100644 (file)
@@ -107,6 +107,6 @@ from sqlalchemy.engine import create_engine, engine_from_config
 __all__ = sorted(name for name, obj in locals().items()
                  if not (name.startswith('_') or inspect.ismodule(obj)))
                  
-__version__ = '0.5.3'
+__version__ = '0.5.4'
 
 del inspect, sys
index 151c557d712420bef0b3e66e837e76bf43147a5d..f3820eb7cdae0b113584aa00f584d227ef1d2488 100644 (file)
@@ -265,12 +265,13 @@ class OneToManyDP(DependencyProcessor):
                                 uowcommit.register_object(
                                     attributes.instance_state(c),
                                     isdelete=True)
-                if not self.passive_updates and self._pks_changed(uowcommit, state):
+                if self._pks_changed(uowcommit, state):
                     if not history:
-                        history = uowcommit.get_attribute_history(state, self.key, passive=False)
-                    for child in history.unchanged:
-                        if child is not None:
-                            uowcommit.register_object(child)
+                        history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_updates)
+                    if history:
+                        for child in history.unchanged:
+                            if child is not None:
+                                uowcommit.register_object(child)
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         source = state
@@ -284,7 +285,7 @@ class OneToManyDP(DependencyProcessor):
             sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
-        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
+        return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class DetectKeySwitch(DependencyProcessor):
     """a special DP that works for many-to-one relations, fires off for
@@ -326,11 +327,11 @@ class DetectKeySwitch(DependencyProcessor):
                     elem.dict[self.key] is not None and 
                     attributes.instance_state(elem.dict[self.key]) in switchers
                 ]:
-                uowcommit.register_object(s, listonly=self.passive_updates)
+                uowcommit.register_object(s)
                 sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
-        return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
+        return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
 
 class ManyToOneDP(DependencyProcessor):
     def __init__(self, prop):
@@ -519,7 +520,7 @@ class ManyToManyDP(DependencyProcessor):
         sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
-        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
+        return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class MapperStub(object):
     """Represent a many-to-many dependency within a flush 
index aa041a5855547bdae65c5d82769557cc684eaf42..dc219e1eb897d405148b1e3f23030c7338803899 100644 (file)
@@ -12,8 +12,8 @@ from sqlalchemy.orm import attributes
 
 class IdentityMap(dict):
     def __init__(self):
-        self._mutable_attrs = {}
-        self.modified = False
+        self._mutable_attrs = set()
+        self._modified = set()
         self._wr = weakref.ref(self)
 
     def replace(self, state):
@@ -34,28 +34,29 @@ class IdentityMap(dict):
     def _manage_incoming_state(self, state):
         state._instance_dict = self._wr
         
-        if state.modified:  
-            self.modified = True
+        if state.modified:
+            self._modified.add(state)  
         if state.manager.mutable_attributes:
-            self._mutable_attrs[state] = True
+            self._mutable_attrs.add(state)
     
     def _manage_removed_state(self, state):
         del state._instance_dict
+        self._mutable_attrs.discard(state)
+        self._modified.discard(state)
+    
+    def _dirty_states(self):
+        return self._modified.union(s for s in self._mutable_attrs if s.modified)
         
-        if state in self._mutable_attrs:
-            del self._mutable_attrs[state]
-            
     def check_modified(self):
         """return True if any InstanceStates present have been marked as 'modified'."""
         
-        if not self.modified:
-            for state in list(self._mutable_attrs):
-                if state.check_modified():
-                    return True
-            else:
-                return False
-        else:
+        if self._modified:
             return True
+        else:
+            for state in self._mutable_attrs:
+                if state.modified:
+                    return True
+        return False
             
     def has_key(self, key):
         return key in self
index 00a7d55e5ecdc32c72744d7f2bf97cafd1eb6ecb..cbfb0c1d643a3a2d670b3a29b4968a7dbc38c7c8 100644 (file)
@@ -299,14 +299,14 @@ class SessionTransaction(object):
             self.session._expunge_state(s)
 
         for s in self.session.identity_map.all_states():
-            _expire_state(s, None)
+            _expire_state(s, None, instance_dict=self.session.identity_map)
 
     def _remove_snapshot(self):
         assert self._is_transaction_boundary
 
         if not self.nested and self.session.expire_on_commit:
             for s in self.session.identity_map.all_states():
-                _expire_state(s, None)
+                _expire_state(s, None, instance_dict=self.session.identity_map)
 
     def _connection_for_bind(self, bind):
         self._assert_is_active()
@@ -900,7 +900,7 @@ class Session(object):
 
     def _finalize_loaded(self, states):
         for state, dict_ in states.items():
-            state.commit_all(dict_)
+            state.commit_all(dict_, self.identity_map)
 
     def refresh(self, instance, attribute_names=None):
         """Refresh the attributes on the given instance.
@@ -935,7 +935,7 @@ class Session(object):
         """Expires all persistent instances within this Session."""
 
         for state in self.identity_map.all_states():
-            _expire_state(state, None)
+            _expire_state(state, None, instance_dict=self.identity_map)
 
     def expire(self, instance, attribute_names=None):
         """Expire the attributes on an instance.
@@ -956,14 +956,14 @@ class Session(object):
             raise exc.UnmappedInstanceError(instance)
         self._validate_persistent(state)
         if attribute_names:
-            _expire_state(state, attribute_names=attribute_names)
+            _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map)
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
             cascaded = list(_cascade_state_iterator('refresh-expire', state))
-            _expire_state(state, None)
+            _expire_state(state, None, instance_dict=self.identity_map)
             for (state, m, o) in cascaded:
-                _expire_state(state, None)
+                _expire_state(state, None, instance_dict=self.identity_map)
 
     def prune(self):
         """Remove unreferenced instances cached in the identity map.
@@ -1022,8 +1022,8 @@ class Session(object):
                 state.key = instance_key
             
             self.identity_map.replace(state)
-            state.commit_all(state.dict)
-
+            state.commit_all(state.dict, self.identity_map)
+            
         # remove from new last, might be the last strong ref
         if state in self._new:
             if self._enable_transaction_accounting and self.transaction:
@@ -1211,7 +1211,7 @@ class Session(object):
             prop.merge(self, instance, merged, dont_load, _recursive)
 
         if dont_load:
-            attributes.instance_state(merged).commit_all(attributes.instance_dict(merged))  # remove any history
+            attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map)  # remove any history
 
         if new_instance:
             merged_state._run_on_load(merged)
@@ -1360,10 +1360,9 @@ class Session(object):
             not self._deleted and not self._new):
             return
 
-        
         dirty = self._dirty_states
         if not dirty and not self._deleted and not self._new:
-            self.identity_map.modified = False
+            self.identity_map._modified.clear()
             return
 
         flush_context = UOWTransaction(self)
@@ -1389,15 +1388,19 @@ class Session(object):
                     raise exc.UnmappedInstanceError(o)
                 objset.add(state)
         else:
-            # or just everything
-            objset = set(self.identity_map.all_states()).union(new)
+            objset = None
 
         # store objects whose fate has been decided
         processed = set()
 
         # put all saves/updates into the flush context.  detect top-level
         # orphans and throw them into deleted.
-        for state in new.union(dirty).intersection(objset).difference(deleted):
+        if objset:
+            proc = new.union(dirty).intersection(objset).difference(deleted)
+        else:
+            proc = new.union(dirty).difference(deleted)
+            
+        for state in proc:
             is_orphan = _state_mapper(state)._is_orphan(state)
             if is_orphan and not _state_has_identity(state):
                 path = ", nor ".join(
@@ -1413,7 +1416,11 @@ class Session(object):
             processed.add(state)
 
         # put all remaining deletes into the flush context.
-        for state in deleted.intersection(objset).difference(processed):
+        if objset:
+            proc = deleted.intersection(objset).difference(processed)
+        else:
+            proc = deleted.difference(processed)
+        for state in proc:
             flush_context.register_object(state, isdelete=True)
 
         if len(flush_context.tasks) == 0:
@@ -1433,9 +1440,13 @@ class Session(object):
         
         flush_context.finalize_flush_changes()
 
-        if not objects:
-            self.identity_map.modified = False
-
+        # useful assertions:
+        #if not objects:
+        #    assert not self.identity_map._modified
+        #else:
+        #    assert self.identity_map._modified == self.identity_map._modified.difference(objects)
+        #self.identity_map._modified.clear()
+        
         for ext in self.extensions:
             ext.after_flush_postexec(self, flush_context)
 
@@ -1484,10 +1495,7 @@ class Session(object):
         those that were possibly deleted.
 
         """
-        return util.IdentitySet(
-            [state
-             for state in self.identity_map.all_states()
-             if state.modified])
+        return self.identity_map._dirty_states()
 
     @property
     def dirty(self):
index c99dfe73c7f7dba71d474624ba749dec2b415971..1b73a1bb62bb76a5a455feebfcb9cfdb3e7a7a86 100644 (file)
@@ -193,12 +193,20 @@ class InstanceState(object):
             key for key in self.manager.iterkeys()
             if key not in self.committed_state and key not in self.dict)
 
-    def expire_attributes(self, attribute_names):
+    def expire_attributes(self, attribute_names, instance_dict=None):
         self.expired_attributes = set(self.expired_attributes)
 
         if attribute_names is None:
             attribute_names = self.manager.keys()
             self.expired = True
+            if self.modified:
+                if not instance_dict:
+                    instance_dict = self._instance_dict()
+                    if instance_dict:
+                        instance_dict._modified.discard(self)
+                else:
+                    instance_dict._modified.discard(self)
+                    
             self.modified = False
             filter_deferred = True
         else:
@@ -248,13 +256,14 @@ class InstanceState(object):
             if needs_committed:
                 self.committed_state[attr.key] = previous
 
+        if not self.modified:
+            instance_dict = self._instance_dict()
+            if instance_dict:
+                instance_dict._modified.add(self)
+
         self.modified = True
         self._strong_obj = self.obj()
 
-        instance_dict = self._instance_dict()
-        if instance_dict:
-            instance_dict.modified = True
-        
     def commit(self, dict_, keys):
         """Commit attributes.
 
@@ -279,7 +288,7 @@ class InstanceState(object):
                 self.expired_attributes.remove(key)
                 self.callables.pop(key, None)
 
-    def commit_all(self, dict_):
+    def commit_all(self, dict_, instance_dict=None):
         """commit all attributes unconditionally.
 
         This is used after a flush() or a full load/refresh
@@ -308,6 +317,9 @@ class InstanceState(object):
             if key in dict_:
                 self.manager[key].impl.commit_to_state(self, dict_, self.committed_state)
 
+        if instance_dict and self.modified:
+            instance_dict._modified.discard(self)
+
         self.modified = self.expired = False
         self._strong_obj = None
 
index dd979e1a808e7bc698cf903a217ba47131abb1e2..c12f17aff5e9c5f971ae7e33d490e702904fa137 100644 (file)
@@ -50,26 +50,18 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
 
         dict_[r.key] = value
 
-def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
+def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
+    """return true if the source object has changes from an old to a new value on the given
+    synchronize pairs
+    
+    """
     for l, r in synchronize_pairs:
         try:
             prop = source_mapper._get_col_to_prop(l)
         except exc.UnmappedColumnError:
             _raise_col_to_prop(False, source_mapper, l, None, r)
         history = uowcommit.get_attribute_history(source, prop.key, passive=True)
-        if history.has_changes():
-            return True
-    else:
-        return False
-
-def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
-    for l, r in synchronize_pairs:
-        try:
-            prop = dest_mapper._get_col_to_prop(r)
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, None, l, dest_mapper, r)
-        history = uowcommit.get_attribute_history(dest, prop.key, passive=True)
-        if history.has_changes():
+        if len(history.deleted):
             return True
     else:
         return False
index 407b702a8bc09f468407f41aa07b11de466c9bea..da26c8d7b38f464a5017eb09e790167b9a7cd6e2 100644 (file)
@@ -121,6 +121,7 @@ class UOWTransaction(object):
             return history.as_state()
 
     def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None):
+        
         # if object is not in the overall session, do nothing
         if not self.session._contains_state(state):
             if self._should_log_debug:
index 57c0757720f010f5ac42efd401e981dbe1170681..8efce660c37a777f8e4828409bd850fb856526ec 100644 (file)
@@ -220,12 +220,13 @@ class NaturalPKTest(_base.MappedTest):
         u1.address = a1
         sess.add(a1)
         sess.flush()
-        
+
         u1.username = 'ed'
 
         def go():
             sess.flush()
         if passive_updates:
+            sess.expire(u1, ['address'])
             self.assert_sql_count(testing.db, go, 1)
         else:
             self.assert_sql_count(testing.db, go, 2)
@@ -234,7 +235,6 @@ class NaturalPKTest(_base.MappedTest):
             sess.flush()
         self.assert_sql_count(testing.db, go, 0)
 
-        assert a1.username == 'ed'
         sess.expunge_all()
         self.assertEquals([Address(username='ed')], sess.query(Address).all())
         
@@ -269,6 +269,7 @@ class NaturalPKTest(_base.MappedTest):
         def go():
             sess.flush()
         if passive_updates:
+            sess.expire(u1, ['addresses'])
             self.assert_sql_count(testing.db, go, 1)
         else:
             self.assert_sql_count(testing.db, go, 3)
@@ -279,11 +280,11 @@ class NaturalPKTest(_base.MappedTest):
         u1 = sess.query(User).get('ed')
         assert len(u1.addresses) == 2    # load addresses
         u1.username = 'fred'
-        print "--------------------------------"
         def go():
             sess.flush()
         # check that the passive_updates is on on the other side
         if passive_updates:
+            sess.expire(u1, ['addresses'])
             self.assert_sql_count(testing.db, go, 1)
         else:
             self.assert_sql_count(testing.db, go, 3)
index 6ae05c77b05e0b057d1f6be212e346b116f59fd7..1729354077d58715e860f8cc9046cd2c4e308204 100644 (file)
@@ -852,9 +852,9 @@ class SessionTest(_fixtures.FixtureTest):
         assert len(s.identity_map) == 1
 
         user = s.query(User).one()
-        assert not s.identity_map.modified
+        assert not s.identity_map._modified
         user.name = 'u2'
-        assert s.identity_map.modified
+        assert s.identity_map._modified
         s.flush()
         eq_(users.select().execute().fetchall(), [(user.id, 'u2')])