]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- reduced a bit of overhead in attribute expiration, particularly
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Feb 2010 21:16:08 +0000 (21:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Feb 2010 21:16:08 +0000 (21:16 +0000)
the version called by column loaders on an incomplete row (i.e.
joined table inheritance).  there are more dramatic changes
that can be made here but this one is conservative so far
as far as how much we're altering how InstanceState tracks
"expired" attributes.

lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_attributes.py
test/orm/test_extendedattr.py
test/perf/objselectspeed.py

index 5a6e24dfaae59804aff65dcf250ce657aa4d14c2..a8c525657f09d67680b13820ce7836bebddf8284 100644 (file)
@@ -1484,7 +1484,7 @@ class Mapper(object):
                 )
 
                 if readonly:
-                    _expire_state(state, readonly)
+                    _expire_state(state, state.dict, readonly)
 
                 # if specified, eagerly refresh whatever has
                 # been expired.
@@ -1524,7 +1524,7 @@ class Mapper(object):
         deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
 
         if deferred_props:
-            _expire_state(state, deferred_props)
+            _expire_state(state, state.dict, deferred_props)
 
         # synchronize newly inserted ids from one table to the next
         # TODO: this still goes a little too often.  would be nice to
index 8dbc6b3db7881776a93b058d05b7d331e47d2d9b..bb92f39e497e435864abee60c74e969fc10f6dd8 100644 (file)
@@ -115,7 +115,7 @@ class ColumnProperty(StrategizedProperty):
                 impl = dest_state.get_impl(self.key)
                 impl.set(dest_state, dest_dict, value, None)
         else:
-            dest_state.expire_attributes([self.key])
+            dest_state.expire_attributes(dest_dict, [self.key])
 
     def get_col_value(self, column, value):
         return value
@@ -636,7 +636,7 @@ class RelationProperty(StrategizedProperty):
                     return
 
         if not "merge" in self.cascade:
-            dest_state.expire_attributes([self.key])
+            dest_state.expire_attribute(dest_dict, [self.key])
             return
 
         if self.key not in source_dict:
index ed55351514fe885bbb59ab92fd09bee1c3741fc2..456f1f19c56f0b40a8a95bfa53485ee81f9c3878 100644 (file)
@@ -1873,8 +1873,9 @@ class Query(object):
 
                     state.commit(dict_, list(to_evaluate))
 
-                    # expire attributes with pending changes (there was no autoflush, so they are overwritten)
-                    state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
+                    # expire attributes with pending changes 
+                    # (there was no autoflush, so they are overwritten)
+                    state.expire_attributes(dict_, set(evaluated_keys).difference(to_evaluate))
 
         elif synchronize_session == 'fetch':
             target_mapper = self._mapper_zero()
index 0e5e939b1386cb9ad42344cf45dce261a96ce5de..d5246bee0ab66cb4197d00facae4f10878e96724 100644 (file)
@@ -289,14 +289,14 @@ class SessionTransaction(object):
         assert not self.session._deleted
 
         for s in self.session.identity_map.all_states():
-            _expire_state(s, None, instance_dict=self.session.identity_map)
+            _expire_state(s, s.dict, None, instance_dict=self.session.identity_map)
 
     def _remove_snapshot(self):
         assert self._is_transaction_boundary
 
         if not self.nested and self.session.expire_on_commit:
             for s in self.session.identity_map.all_states():
-                _expire_state(s, None, instance_dict=self.session.identity_map)
+                _expire_state(s, s.dict, None, instance_dict=self.session.identity_map)
 
     def _connection_for_bind(self, bind):
         self._assert_is_active()
@@ -915,7 +915,7 @@ class Session(object):
         """Expires all persistent instances within this Session."""
 
         for state in self.identity_map.all_states():
-            _expire_state(state, None, instance_dict=self.identity_map)
+            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
 
     def expire(self, instance, attribute_names=None):
         """Expire the attributes on an instance.
@@ -936,14 +936,15 @@ class Session(object):
             raise exc.UnmappedInstanceError(instance)
         self._validate_persistent(state)
         if attribute_names:
-            _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map)
+            _expire_state(state, state.dict, 
+                                attribute_names=attribute_names, instance_dict=self.identity_map)
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
             cascaded = list(_cascade_state_iterator('refresh-expire', state))
-            _expire_state(state, None, instance_dict=self.identity_map)
+            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
             for (state, m, o) in cascaded:
-                _expire_state(state, None, instance_dict=self.identity_map)
+                _expire_state(state, state.dict, None, instance_dict=self.identity_map)
 
     def prune(self):
         """Remove unreferenced instances cached in the identity map.
index 4bb9219f4eae88fc6deff96c046c01a2b29a9312..a9494a50e125358d0dbdf38fcd46346606e58b91 100644 (file)
@@ -239,9 +239,30 @@ class InstanceState(object):
         return set(
             key for key in self.manager.iterkeys()
             if key not in self.committed_state and key not in self.dict)
-
-    def expire_attributes(self, attribute_names, instance_dict=None):
-        self.expired_attributes = set(self.expired_attributes)
+    
+    def expire_attribute_pre_commit(self, dict_, key):
+        """a fast expire that can be called by column loaders during a load.
+        
+        The additional bookkeeping is finished up in commit_all().
+        
+        This method is actually called a lot with joined-table
+        loading, when the second table isn't present in the result.
+        
+        """
+        # TODO: yes, this is still a little too busy.
+        # need to more cleanly separate out handling 
+        # for the various AttributeImpls and the contracts 
+        # they wish to maintain with their strategies
+        if not self.expired_attributes:
+            self.expired_attributes = set(self.expired_attributes)
+            
+        dict_.pop(key, None)
+        self.callables[key] = self
+        self.expired_attributes.add(key)
+        
+    def expire_attributes(self, dict_, attribute_names, instance_dict=None):
+        if not self.expired_attributes:
+            self.expired_attributes = set(self.expired_attributes)
 
         if attribute_names is None:
             attribute_names = self.manager.keys()
@@ -258,7 +279,6 @@ class InstanceState(object):
             filter_deferred = True
         else:
             filter_deferred = False
-        dict_ = self.dict
         
         for key in attribute_names:
             impl = self.manager[key].impl
@@ -354,8 +374,7 @@ class InstanceState(object):
         
         self.committed_state = {}
         self.pending = {}
-        
-        # unexpire attributes which have loaded
+            
         if self.expired_attributes:
             for key in self.expired_attributes.intersection(dict_):
                 self.callables.pop(key, None)
index 5e81d33ca6f3e5f483e2fc23deae8db3929f311c..4d5ec3da4fb17ed64ebc818dcad81b1d2f0ca9c9 100644 (file)
@@ -121,7 +121,7 @@ class ColumnLoader(LoaderStrategy):
         else:
             def new_execute(state, dict_, row, isnew):
                 if isnew:
-                    state.expire_attributes([key])
+                    state.expire_attribute_pre_commit(dict_, key)
         return new_execute, None
 
 log.class_logger(ColumnLoader)
@@ -168,7 +168,7 @@ class CompositeColumnLoader(ColumnLoader):
             if c not in row:
                 def new_execute(state, dict_, row, isnew):
                     if isnew:
-                        state.expire_attributes([key])
+                        state.expire_attribute_pre_commit(dict_, key)
                 break
         else:
             def new_execute(state, dict_, row, isnew):
index c69021aa3d7fa3db1707512cd4df0158da650bb8..e6041d5663988dfc63d0b588916de2f11f2259a4 100644 (file)
@@ -142,21 +142,21 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
 
         f = Foo()
-        attributes.instance_state(f).expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
         eq_(f.a, "this is a")
         eq_(f.b, 12)
 
         f.a = "this is some new a"
-        attributes.instance_state(f).expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
         eq_(f.a, "this is a")
         eq_(f.b, 12)
 
-        attributes.instance_state(f).expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
         f.a = "this is another new a"
         eq_(f.a, "this is another new a")
         eq_(f.b, 12)
 
-        attributes.instance_state(f).expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
         eq_(f.a, "this is a")
         eq_(f.b, 12)
 
@@ -182,7 +182,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
 
         m = MyTest()
-        attributes.instance_state(m).expire_attributes(None)
+        attributes.instance_state(m).expire_attributes(attributes.instance_dict(m), None)
         assert 'a' not in m.__dict__
         m2 = pickle.loads(pickle.dumps(m))
         assert 'a' not in m2.__dict__
@@ -355,7 +355,7 @@ class AttributesTest(_base.ORMTest):
         x.bars
         b = Bar(id=4)
         b.foos.append(x)
-        attributes.instance_state(x).expire_attributes(['bars'])
+        attributes.instance_state(x).expire_attributes(attributes.instance_dict(x), ['bars'])
         assert_raises(AssertionError, b.foos.remove, x)
         
         
@@ -1294,7 +1294,7 @@ class HistoryTest(_base.ORMTest):
         eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), ([bar4], [], []))
 
         lazy_load = [bar1, bar2, bar3]
-        attributes.instance_state(f).expire_attributes(['bars'])
+        attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), ['bars'])
         eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), ((), [bar1, bar2, bar3], ()))
 
     def test_collections_via_lazyload(self):
index 685be3a5f64fc874daad7b889ac514dd84429c6d..4374b9ecb97eeb0733f5fd5085f56206b6d13016 100644 (file)
@@ -161,21 +161,21 @@ class UserDefinedExtensionTest(_base.ORMTest):
             
             assert Foo in attributes.instrumentation_registry._state_finders
             f = Foo()
-            attributes.instance_state(f).expire_attributes(None)
+            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
             f.a = "this is some new a"
-            attributes.instance_state(f).expire_attributes(None)
+            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
-            attributes.instance_state(f).expire_attributes(None)
+            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
             f.a = "this is another new a"
             eq_(f.a, "this is another new a")
             eq_(f.b, 12)
 
-            attributes.instance_state(f).expire_attributes(None)
+            attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
index e04ef4efbedd3c1fd1cff798d14d8f81423783bb..d3fd34046687e8b6927f29e4a302224601fe054f 100644 (file)
@@ -8,21 +8,44 @@ db = create_engine('sqlite://')
 metadata = MetaData(db)
 Person_table = Table('Person', metadata,
                      Column('id', Integer, primary_key=True),
+                     Column('type', String(10)),
                      Column('name', String(40)),
                      Column('sex', Integer),
                      Column('age', Integer))
 
+
+Employee_table = Table('Employee', metadata,
+                  Column('id', Integer, ForeignKey('Person.id'), primary_key=True),
+                  Column('foo', String(40)),
+                  Column('bar', Integer),
+                  Column('bat', Integer))
+
 class RawPerson(object): pass
 class Person(object): pass
 mapper(Person, Person_table)
+
+class JoinedPerson(object):pass
+class Employee(JoinedPerson):pass
+mapper(JoinedPerson, Person_table, \
+                polymorphic_on=Person_table.c.type, polymorphic_identity='person')
+mapper(Employee, Employee_table, \
+                inherits=JoinedPerson, polymorphic_identity='employee')
 compile_mappers()
 
 def setup():
     metadata.create_all()
     i = Person_table.insert()
-    data = [{'name':'John Doe','sex':1,'age':35}] * 100
+    data = [{'name':'John Doe','sex':1,'age':35, 'type':'employee'}] * 100
     for j in xrange(500):
         i.execute(data)
+        
+    # note we arent fetching from employee_table,
+    # so we can leave it empty even though its "incorrect"
+    #i = Employee_table.insert()
+    #data = [{'foo':'foo', 'bar':'bar':'bat':'bat'}] * 100
+    #for j in xrange(500):
+    #    i.execute(data)
+        
     print "Inserted 50,000 rows"
 
 def sqlite_select(entity_cls):
@@ -55,6 +78,11 @@ def orm_select():
     session = create_session()
     people = session.query(Person).all()
 
+#@profiling.profiled(report=True, always=True)
+def joined_orm_select():
+    session = create_session()
+    people = session.query(JoinedPerson).all()
+
 def all():
     setup()
     try:
@@ -103,6 +131,13 @@ def all():
         orm_select()
         t2 = time.clock()
         usage('sqlalchemy.orm fetch')
+
+        gc_collect()
+        usage.snap()
+        t = time.clock()
+        joined_orm_select()
+        t2 = time.clock()
+        usage('sqlalchemy.orm "joined" fetch')
     finally:
         metadata.drop_all()