git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Session.merge() is performance optimized, using half the
author Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 7 Jan 2010 22:09:17 +0000 (22:09 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 7 Jan 2010 22:09:17 +0000 (22:09 +0000)
call counts for "load=False" mode compared to 0.5 and
significantly fewer SQL queries in the case of collections
for "load=True" mode.

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
test/aaa_profiling/test_orm.py [new file with mode: 0644]
test/orm/test_merge.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index a24c5c333722eee038d0e04fa5667a23dbaf3dc3..e98855d0ee03a15023c967964986a12489479659 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -121,6 +121,11 @@ CHANGES
     
   - the "dont_load=True" flag on Session.merge() is deprecated
     and is now "load=False".
+    
+  - Session.merge() is performance optimized, using half the
+    call counts for "load=False" mode compared to 0.5 and
+    significantly fewer SQL queries in the case of collections
+    for "load=True" mode.
 
   - `expression.null()` is fully understood the same way
     None is when comparing an object/collection-referencing
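diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 1bb8504881b157bf168868d855d2418895779b99..e63c2b8677e296cd3b5665fa6c0b95d34b6964ce 100644 (file)
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py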
@@ -105,13 +105,17 @@ class ColumnProperty(StrategizedProperty):
     def setattr(self, state, value, column):
         state.get_impl(self.key).set(state, state.dict, value, None)
 
-    def merge(self, session, source, dest, load, _recursive):
-        value = attributes.instance_state(source).value_as_iterable(
-            self.key, passive=True)
-        if value:
-            setattr(dest, self.key, value[0])
+    def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
+        if self.key in source_dict:
+            value = source_dict[self.key]
+        
+            if not load:
+                dest_dict[self.key] = value
+            else:
+                impl = dest_state.get_impl(self.key)
+                impl.set(dest_state, dest_dict, value, None)
         else:
-            attributes.instance_state(dest).expire_attributes([self.key])
+            dest_state.expire_attributes([self.key])
 
     def get_col_value(self, column, value):
         return value
@@ -301,7 +305,7 @@ class SynonymProperty(MapperProperty):
             proxy_property=self.descriptor
             )
 
-    def merge(self, session, source, dest, load, _recursive):
+    def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
         pass
         
 log.class_logger(SynonymProperty)
@@ -334,7 +338,7 @@ class ComparableProperty(MapperProperty):
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         return (None, None)
 
-    def merge(self, session, source, dest, load, _recursive):
+    def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
         pass
 
 
@@ -624,50 +628,61 @@ class RelationProperty(StrategizedProperty):
     def __str__(self):
         return str(self.parent.class_.__name__) + "." + self.key
 
-    def merge(self, session, source, dest, load, _recursive):
+    def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
         if load:
             # TODO: no test coverage for recursive check
             for r in self._reverse_property:
-                if (source, r) in _recursive:
+                if (source_state, r) in _recursive:
                     return
 
-        source_state = attributes.instance_state(source)
-        dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest)
-
         if not "merge" in self.cascade:
             dest_state.expire_attributes([self.key])
             return
 
-        instances = source_state.value_as_iterable(self.key, passive=True)
-
-        if not instances:
+        if self.key not in source_dict:
             return
 
         if self.uselist:
+            instances = source_state.get_impl(self.key).\
+                            get(source_state, source_dict)
+            
+            if load:
+                # for a full merge, pre-load the destination collection,
+                # so that individual _merge of each item pulls from identity
+                # map for those already present.  
+                # also assumes CollectionAttributeImpl behavior of loading
+                # "old" list in any case
+                dest_state.get_impl(self.key).get(dest_state, dest_dict)
+                
             dest_list = []
             for current in instances:
-                _recursive[(current, self)] = True
-                obj = session._merge(current, load=load, _recursive=_recursive)
+                current_state = attributes.instance_state(current)
+                current_dict = attributes.instance_dict(current)
+                _recursive[(current_state, self)] = True
+                obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive)
                 if obj is not None:
                     dest_list.append(obj)
+                    
             if not load:
                 coll = attributes.init_state_collection(dest_state, dest_dict, self.key)
                 for c in dest_list:
                     coll.append_without_event(c)
             else:
-                getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list)
+                dest_state.get_impl(self.key)._set_iterable(dest_state, dest_dict, dest_list)
         else:
-            current = instances[0]
+            current = source_dict[self.key]
             if current is not None:
-                _recursive[(current, self)] = True
-                obj = session._merge(current, load=load, _recursive=_recursive)
+                current_state = attributes.instance_state(current)
+                current_dict = attributes.instance_dict(current)
+                _recursive[(current_state, self)] = True
+                obj = session._merge(current_state, current_dict, load=load, _recursive=_recursive)
             else:
                 obj = None
-            
+
             if not load:
-                dest_state.dict[self.key] = obj
+                dest_dict[self.key] = obj
             else:
-                setattr(dest, self.key, obj)
+                dest_state.get_impl(self.key).set(dest_state, dest_dict, obj, None)
 
     def cascade_iterator(self, type_, state, visited_instances, halt_on=None):
         if not type_ in self.cascade:
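diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index f171622d4aab74d3da642f4734455dcf0c2b9808..ee4286c67ba20a77bde0a7b013c14f064e75925b 100644 (file)
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py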
@@ -1103,24 +1103,25 @@ class Session(object):
             load = not kw['dont_load']
             util.warn_deprecated("dont_load=True has been renamed to load=False.")
         
-        # TODO: this should be an IdentityDict for instances, but will
-        # need a separate dict for PropertyLoader tuples
         _recursive = {}
         self._autoflush()
+        _object_mapper(instance) # verify mapped
         autoflush = self.autoflush
         try:
             self.autoflush = False
-            return self._merge(instance, load=load, _recursive=_recursive)
+            return self._merge(
+                            attributes.instance_state(instance), 
+                            attributes.instance_dict(instance), 
+                            load=load, _recursive=_recursive)
         finally:
             self.autoflush = autoflush
         
-    def _merge(self, instance, load=True, _recursive=None):
-        mapper = _object_mapper(instance)
-        if instance in _recursive:
-            return _recursive[instance]
+    def _merge(self, state, state_dict, load=True, _recursive=None):
+        mapper = _state_mapper(state)
+        if state in _recursive:
+            return _recursive[state]
 
         new_instance = False
-        state = attributes.instance_state(instance)
         key = state.key
         
         if key is None:
@@ -1134,6 +1135,7 @@ class Session(object):
 
         if key in self.identity_map:
             merged = self.identity_map[key]
+            
         elif not load:
             if state.modified:
                 raise sa_exc.InvalidRequestError(
@@ -1154,16 +1156,21 @@ class Session(object):
         if merged is None:
             merged = mapper.class_manager.new_instance()
             merged_state = attributes.instance_state(merged)
+            merged_dict = attributes.instance_dict(merged)
             new_instance = True
-            self.add(merged)
-
-        _recursive[instance] = merged
+            self._save_or_update_state(merged_state)
+        else:
+            merged_state = attributes.instance_state(merged)
+            merged_dict = attributes.instance_dict(merged)
+            
+        _recursive[state] = merged
 
         for prop in mapper.iterate_properties:
-            prop.merge(self, instance, merged, load, _recursive)
+            prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive)
 
         if not load:
-            attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map)  # remove any history
+            # remove any history
+            merged_state.commit_all(merged_dict, self.identity_map)  
 
         if new_instance:
             merged_state._run_on_load(merged)
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
new file mode 100644 (file)
index 0000000..d88e73c
--- /dev/null
+++ b/test/aaa_profiling/test_orm.py
@@ -0,0 +1,85 @@
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy import exc as sa_exc, util, Integer, String, ForeignKey
+from sqlalchemy.orm import exc as orm_exc, mapper, relation, sessionmaker
+
+from sqlalchemy.test import testing, profiling
+from test.orm import _base
+from sqlalchemy.test.schema import Table, Column
+
+
+class MergeTest(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        parent = Table('parent', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('data', String(20))
+        )
+
+        child = Table('child', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('data', String(20)),
+            Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False)
+        )
+
+
+    @classmethod
+    def setup_classes(cls):
+        class Parent(_base.BasicEntity):
+            pass
+        class Child(_base.BasicEntity):
+            pass
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Parent, parent, properties={
+            'children':relation(Child, backref='parent')
+        })
+        mapper(Child, child)
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        parent.insert().execute(
+            {'id':1, 'data':'p1'},
+        )
+        child.insert().execute(
+            {'id':1, 'data':'p1c1', 'parent_id':1},
+        )
+    
+    @testing.resolve_artifact_names
+    def test_merge_no_load(self):
+        sess = sessionmaker()()
+        sess2 = sessionmaker()()
+        
+        p1 = sess.query(Parent).get(1)
+        p1.children
+        
+        # down from 185 on this
+        # this is a small slice of a usually bigger
+        # operation so using a small variance
+        @profiling.function_call_count(106, variance=0.001)
+        def go():
+            p2 = sess2.merge(p1, load=False)
+            
+        go()
+
+    @testing.resolve_artifact_names
+    def test_merge_load(self):
+        sess = sessionmaker()()
+        sess2 = sessionmaker()()
+
+        p1 = sess.query(Parent).get(1)
+        p1.children
+
+        # preloading of collection took this down from 1728
+        # to 1192 using sqlite3
+        @profiling.function_call_count(1192)
+        def go():
+            p2 = sess2.merge(p1)
+        go()
+        
+        # one more time, count the SQL
+        sess2 = sessionmaker()()
+        self.assert_sql_count(testing.db, go, 2)
+            
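diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py
index 533c3ea5d12db5d3a0a17950d7bafb3789edbf56..c3b28386d37ab4cf77b812e0e026afb399989626 100644 (file)
--- a/test/orm/test_merge.py
+++ b/test/orm/test_merge.py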
@@ -4,7 +4,8 @@ from sqlalchemy import Integer, PickleType
 import operator
 from sqlalchemy.test import testing
 from sqlalchemy.util import OrderedSet
-from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property, sessionmaker
+from sqlalchemy.orm import mapper, relation, create_session, PropComparator, \
+                            synonym, comparable_property, sessionmaker, attributes
 from sqlalchemy.test.testing import eq_, ne_
 from test.orm import _base, _fixtures
 from sqlalchemy.test.schema import Table, Column
@@ -378,6 +379,43 @@ class MergeTest(_fixtures.FixtureTest):
         eq_(on_load.called, 6)
         eq_(u3.name, 'also fred')
 
+    @testing.resolve_artifact_names
+    def test_many_to_one_cascade(self):
+        mapper(Address, addresses, properties={
+            'user':relation(User)
+        })
+        mapper(User, users)
+        
+        u1 = User(id=1, name="u1")
+        a1 =Address(id=1, email_address="a1", user=u1)
+        u2 = User(id=2, name="u2")
+        
+        sess = create_session()
+        sess.add_all([a1, u2])
+        sess.flush()
+        
+        a1.user = u2
+        
+        sess2 = create_session()
+        a2 = sess2.merge(a1)
+        eq_(
+            attributes.get_history(a2, 'user'), 
+            ([u2], (), [attributes.PASSIVE_NO_RESULT])
+        )
+        assert a2 in sess2.dirty
+        
+        sess.refresh(a1)
+        
+        sess2 = create_session()
+        a2 = sess2.merge(a1, load=False)
+        eq_(
+            attributes.get_history(a2, 'user'), 
+            ((), [u1], ())
+        )
+        assert a2 not in sess2.dirty
+        
+        
+        
     @testing.resolve_artifact_names
     def test_many_to_many_cascade(self):
 
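diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 967d1ac6cbb51c93be036e1d34a132141f17d5bc..bc3b9e26d8fa1fe554eb8f8c16ed859d72c96028 100644 (file)
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py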
@@ -929,6 +929,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL):
             sess.query(User.id).from_self().\
                 add_column(func.count().label('foo')).\
                 group_by(User.id).\
+                order_by(User.id).\
                 from_self().all(),
             [
                 (7,1), (8, 1), (9, 1), (10, 1)