]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added expire_all() method to Session. Calls expire()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Feb 2008 19:22:34 +0000 (19:22 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Feb 2008 19:22:34 +0000 (19:22 +0000)
for all persistent instances.  This is handy in conjunction
with .....

- instances which have been partially or fully expired
will have their expired attributes populated during a regular
Query operation which affects those objects, preventing
a needless second SQL statement for each instance.

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
test/orm/expire.py

diff --git a/CHANGES b/CHANGES
index e94826405a15de84511f154a30b1bda36ee4c8d5..e28563201ef44b73580e203bc18258b747b20b33 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -105,7 +105,16 @@ CHANGES
 
     - The proper error message is raised when trying to access
       expired instance attributes with no session present
-
+    
+    - added expire_all() method to Session.  Calls expire()
+      for all persistent instances.  This is handy in conjunction
+      with .....
+      
+    - instances which have been partially or fully expired
+      will have their expired attributes populated during a regular
+      Query operation which affects those objects, preventing
+      a needless second SQL statement for each instance.
+      
     - Dynamic relations, when referenced, create a strong
       reference to the parent object so that the query still has a
       parent to call against even if the parent is only created
index e08a1a0c2ca36c20b14c6a916bb1d3a4086df8d2..5ae79e4323b3b2cffd18994be5a03d8839997d18 100644 (file)
@@ -775,7 +775,14 @@ class InstanceState(object):
         serializable.
         """
         instance = self.obj()
-        self.class_._class_state.deferred_scalar_loader(instance, [k for k in self.expired_attributes if k in self.unmodified])
+        
+        unmodified = self.unmodified
+        self.class_._class_state.deferred_scalar_loader(instance, [
+            attr.impl.key for attr in _managed_attributes(self.class_) if 
+                attr.impl.accepts_scalar_loader and 
+                attr.impl.key in self.expired_attributes and 
+                attr.impl.key in unmodified
+            ])
         for k in self.expired_attributes:
             self.callables.pop(k, None)
         self.expired_attributes.clear()
@@ -798,20 +805,18 @@ class InstanceState(object):
         if attribute_names is None:
             for attr in _managed_attributes(self.class_):
                 self.dict.pop(attr.impl.key, None)
-
+                self.expired_attributes.add(attr.impl.key)
                 if attr.impl.accepts_scalar_loader:
                     self.callables[attr.impl.key] = self
-                    self.expired_attributes.add(attr.impl.key)
 
             self.committed_state = {}
         else:
             for key in attribute_names:
                 self.dict.pop(key, None)
                 self.committed_state.pop(key, None)
-
+                self.expired_attributes.add(key)
                 if getattr(self.class_, key).impl.accepts_scalar_loader:
                     self.callables[key] = self
-                    self.expired_attributes.add(key)
 
     def reset(self, key):
         """remove the given attribute and any callables associated with it."""
index d78973e9424d92339da28d79f08e62cd8f340e37..85aec2f4476fc9a4011361d7483ba339b88e4e81 100644 (file)
@@ -1371,13 +1371,22 @@ class Mapper(object):
 
             if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                 self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
+        
+        else:
+            attrs = getattr(state, 'expired_attributes', None)
+            # populate attributes on non-loading instances which have been expired
+            # TODO: also support deferred attributes here [ticket:870]
+            if attrs: 
+                if state in context.partials:
+                    isnew = False
+                    attrs = context.partials[state]
+                else:
+                    isnew = True
+                    attrs = state.expired_attributes.intersection(state.unmodified)
+                    context.partials[state] = attrs  #<-- allow query.instances to commit the subset of attrs
 
-#       NOTYET: populate attributes on non-loading instances which have been expired, deferred, etc.
-#        elif getattr(state, 'expired_attributes', None):   # TODO: base off total set of unloaded attributes, not just exp
-#            attrs = state.expired_attributes.intersection(state.unmodified)
-#            if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-#                self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew)
-#            context.partials.add((state, attrs))  <-- allow query.instances to commit the subset of attrs
+                if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                    self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew)
 
         if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
             result.append(instance)
@@ -1448,10 +1457,10 @@ class Mapper(object):
         if self.non_primary:
             selectcontext.attributes[('populating_mapper', instance._state)] = self
 
-    def _post_instance(self, selectcontext, state):
+    def _post_instance(self, selectcontext, state, **kwargs):
         post_processors = selectcontext.attributes[('post_processors', self, None)]
         for p in post_processors:
-            p(state.obj())
+            p(state.obj(), **kwargs)
 
     def _get_poly_select_loader(self, selectcontext, row):
         """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
@@ -1475,11 +1484,13 @@ class Mapper(object):
 
                 identitykey = self.identity_key_from_instance(instance)
 
+                only_load_props = flags.get('only_load_props', None)
+
                 params = {}
                 for c, bind in param_names:
                     params[bind] = self._get_attr_by_column(instance, c)
                 row = selectcontext.session.connection(self).execute(statement, params).fetchone()
-                self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+                self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props)
             return post_execute
         elif hosted_mapper.polymorphic_fetch == 'deferred':
             from sqlalchemy.orm.strategies import DeferredColumnLoader
@@ -1494,6 +1505,12 @@ class Mapper(object):
 
                 props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
                 keys = [p.key for p in props]
+                
+                only_load_props = flags.get('only_load_props', None)
+                if only_load_props:
+                    keys = util.Set(keys).difference(only_load_props)
+                    props = [p for p in props if p.key in only_load_props]
+                    
                 for prop in props:
                     strategy = prop._get_strategy(DeferredColumnLoader)
                     instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
index 35f632f1e711baa41138b3adcb26a4f551e2b846..2bb87ea715a39028cf580ddf7f301ddd3ccf57eb 100644 (file)
@@ -903,6 +903,7 @@ class Query(object):
 
         while True:
             context.progress = util.Set()
+            context.partials = {}
 
             if self._yield_per:
                 fetch = cursor.fetchmany(self._yield_per)
@@ -927,7 +928,11 @@ class Query(object):
             for ii in context.progress:
                 context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii)
                 ii.commit_all()
-
+                
+            for ii, attrs in context.partials.items():
+                context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii, only_load_props=attrs)
+                ii.commit(attrs)
+                
             for row in rows:
                 yield row
 
index c75b786644cdcb80796742c2f22305f7e10de0da..8f85a496c4fc63543d45dabdb1fe6d9a86e12200 100644 (file)
@@ -820,7 +820,14 @@ class Session(object):
 
         if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
             raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
-
+    
+    def expire_all(self):
+        """Expires all persistent instances within this Session.  
+        
+        """
+        for state in self.identity_map.all_states():
+            _expire_state(state, None)
+        
     def expire(self, instance, attribute_names=None):
         """Expire the attributes on the given instance.
 
@@ -829,13 +836,6 @@ class Session(object):
         to the database which will refresh all attributes with their
         current value.
 
-        Lazy-loaded relational attributes will remain lazily loaded, so that
-        triggering one will incur the instance-wide refresh operation, followed
-        immediately by the lazy load of that attribute.
-
-        Eagerly-loaded relational attributes will eagerly load within the
-        single refresh operation.
-
         The ``attribute_names`` argument is an iterable collection
         of attribute names indicating a subset of attributes to be
         expired.
index 3394c751bd53b90484ce4d1c1ca6bb26d35a3cd7..545f01234d5c831dcefac37d9a10fd25ed7be74d 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import exceptions
 from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
+import gc
 
 class ExpireTest(FixtureTest):
     keep_mappers = False
@@ -39,16 +40,13 @@ class ExpireTest(FixtureTest):
         # object isnt refreshed yet, using dict to bypass trigger
         assert u.__dict__.get('name') != 'jack'
 
-        if False:
-            # NOTYET: need to implement unconditional population
-            # of expired attriutes in mapper._instances()
-            sess.query(User).all()
-            # test that it refreshed
-            assert u.__dict__['name'] == 'jack'
+        sess.query(User).all()
+        # test that it refreshed
+        assert u.__dict__['name'] == 'jack'
 
-            def go():
-                assert u.name == 'jack'
-            self.assert_sql_count(testing.db, go, 0)
+        def go():
+            assert u.name == 'jack'
+        self.assert_sql_count(testing.db, go, 0)
 
     def test_expire_doesntload_on_set(self):
         mapper(User, users)
@@ -122,16 +120,21 @@ class ExpireTest(FixtureTest):
         assert o.isopen == 1
         assert o.description == 'some new description'
 
-        if False:
-            # NOTYET: need to implement unconditional population
-            # of expired attriutes in mapper._instances()
-            sess.expire(o, ['isopen', 'description'])
-            sess.query(Order).all()
-            del o.isopen
-            def go():
-                assert o.isopen is None
-            self.assert_sql_count(testing.db, go, 0)
+        sess.expire(o, ['isopen', 'description'])
+        sess.query(Order).all()
+        del o.isopen
+        def go():
+            assert o.isopen is None
+        self.assert_sql_count(testing.db, go, 0)
 
+        o.isopen=14
+        sess.expire(o)
+        o.description = 'another new description'
+        sess.query(Order).all()
+        assert o.isopen == 1
+        assert o.description == 'another new description'
+        
+        
     def test_expire_committed(self):
         """test that the committed state of the attribute receives the most recent DB data"""
         mapper(Order, orders)
@@ -200,7 +203,7 @@ class ExpireTest(FixtureTest):
             assert u.addresses[0].email_address == 'jack@bean.com'
             assert u.name == 'jack'
         # two loads, since relation() + scalar are
-        # separate right now
+        # separate right now on per-attribute load
         self.assert_sql_count(testing.db, go, 2)
         assert 'name' in u.__dict__
         assert 'addresses' in u.__dict__
@@ -209,6 +212,50 @@ class ExpireTest(FixtureTest):
         assert 'name' not in u.__dict__
         assert 'addresses' not in u.__dict__
     
+        def go():
+            sess.query(User).filter_by(id=7).one()
+            assert u.addresses[0].email_address == 'jack@bean.com'
+            assert u.name == 'jack'
+        # one load, since relation() + scalar are
+        # together when eager load used with Query
+        self.assert_sql_count(testing.db, go, 1)
+            
+    def test_relation_changes_preserved(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', lazy=False),
+            })
+        mapper(Address, addresses)
+        sess = create_session()
+        u = sess.query(User).get(8)
+        sess.expire(u, ['name', 'addresses'])
+        u.addresses
+        assert 'name' not in u.__dict__
+        del u.addresses[1]
+        u.name
+        assert 'name' in u.__dict__
+        assert len(u.addresses) == 2
+
+    def test_eagerload_props_dontload(self):
+        # relations currently have to load separately from scalar instances.  the use case is:
+        # expire "addresses".  then access it.  lazy load fires off to load "addresses", but needs
+        # foreign key or primary key attributes in order to lazy load; hits those attributes,
+        # such as below it hits "u.id".  "u.id" triggers full unexpire operation, eagerloads
+        # addresses since lazy=False.  this is all wihtin lazy load which fires unconditionally;
+        # so an unnecessary eagerload (or lazyload) was issued.  would prefer not to complicate
+        # lazyloading to "figure out" that the operation should be aborted right now.
+        
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', lazy=False),
+            })
+        mapper(Address, addresses)
+        sess = create_session()
+        u = sess.query(User).get(8)
+        sess.expire(u)
+        u.id
+        assert 'addresses' not in u.__dict__
+        u.addresses
+        assert 'addresses' in u.__dict__
+        
     def test_expire_synonym(self):
         mapper(User, users, properties={
             'uname':synonym('name')
@@ -361,6 +408,25 @@ class ExpireTest(FixtureTest):
         # doing it that way right now
         #self.assert_sql_count(testing.db, go, 0)
 
+    def test_relations_load_on_query(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u = sess.query(User).get(8)
+        assert 'name' in u.__dict__
+        u.addresses
+        assert 'addresses' in u.__dict__
+
+        sess.expire(u, ['name', 'addresses'])
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+        sess.query(User).options(eagerload('addresses')).filter_by(id=8).all()
+        assert 'name' in u.__dict__
+        assert 'addresses' in u.__dict__
+        
     def test_partial_expire_deferred(self):
         mapper(Order, orders, properties={
             'description':deferred(orders.c.description)
@@ -426,8 +492,149 @@ class ExpireTest(FixtureTest):
             assert o.description == 'order 3'
             assert o.isopen == 1
         self.assert_sql_count(testing.db, go, 1)
+    
+    def test_eagerload_query_refreshes(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', lazy=False),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u = sess.query(User).get(8)
+        assert len(u.addresses) == 3
+        sess.expire(u)
+        assert 'addresses' not in u.__dict__
+        print "-------------------------------------------"
+        sess.query(User).filter_by(id=8).all()
+        assert 'addresses' in u.__dict__
+        assert len(u.addresses) == 3
+        
+    def test_expire_all(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', lazy=False),
+            })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        userlist = sess.query(User).all()
+        assert fixtures.user_address_result == userlist
+        assert len(list(sess)) == 9
+        sess.expire_all()
+        gc.collect()
+        assert len(list(sess)) == 4 # since addresses were gc'ed
+        
+        userlist = sess.query(User).all()
+        u = userlist[1]
+        assert fixtures.user_address_result == userlist
+        assert len(list(sess)) == 9
+        
+class PolymorphicExpireTest(ORMTest):
+    keep_data = True
+    
+    def define_tables(self, metadata):
+        global people, engineers, Person, Engineer
+
+        people = Table('people', metadata,
+           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('name', String(50)),
+           Column('type', String(30)))
+
+        engineers = Table('engineers', metadata,
+           Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+           Column('status', String(30)),
+          )
+        
+        class Person(Base):
+            pass
+        class Engineer(Person):
+            pass
+            
+    def insert_data(self):
+        people.insert().execute(
+            {'person_id':1, 'name':'person1', 'type':'person'},
+            {'person_id':2, 'name':'engineer1', 'type':'engineer'},
+            {'person_id':3, 'name':'engineer2', 'type':'engineer'},
+        )
+        engineers.insert().execute(
+            {'person_id':2, 'status':'new engineer'},
+            {'person_id':3, 'status':'old engineer'},
+        )
+
+    def test_poly_select(self):
+        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
+        mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
+        
+        sess = create_session()
+        [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
+        
+        sess.expire(p1)
+        sess.expire(e1, ['status'])
+        sess.expire(e2)
+        
+        for p in [p1, e2]:
+            assert 'name' not in p.__dict__
+        
+        assert 'name' in e1.__dict__
+        assert 'status' not in e2.__dict__
+        assert 'status' not in e1.__dict__
+        
+        e1.name = 'new engineer name'
+        
+        def go():
+            sess.query(Person).all()
+        self.assert_sql_count(testing.db, go, 3)
+        
+        for p in [p1, e1, e2]:
+            assert 'name' in p.__dict__
+        
+        assert 'status' in e2.__dict__
+        assert 'status' in e1.__dict__
+        def go():
+            assert e1.name == 'new engineer name'
+            assert e2.name == 'engineer2'
+            assert e1.status == 'new engineer'
+        self.assert_sql_count(testing.db, go, 0)
+        self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1']))
+        
+    def test_poly_deferred(self):
+        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person', polymorphic_fetch='deferred')
+        mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
+
+        sess = create_session()
+        [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
+
+        sess.expire(p1)
+        sess.expire(e1, ['status'])
+        sess.expire(e2)
+
+        for p in [p1, e2]:
+            assert 'name' not in p.__dict__
+
+        assert 'name' in e1.__dict__
+        assert 'status' not in e2.__dict__
+        assert 'status' not in e1.__dict__
+
+        e1.name = 'new engineer name'
+
+        def go():
+            sess.query(Person).all()
+        self.assert_sql_count(testing.db, go, 1)
+        
+        for p in [p1, e1, e2]:
+            assert 'name' in p.__dict__
 
+        assert 'status' not in e2.__dict__
+        assert 'status' not in e1.__dict__
 
+        def go():
+            assert e1.name == 'new engineer name'
+            assert e2.name == 'engineer2'
+            assert e1.status == 'new engineer'
+            assert e2.status == 'old engineer'
+        self.assert_sql_count(testing.db, go, 2)
+        self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1']))
+        
+    
 class RefreshTest(FixtureTest):
     keep_mappers = False
     refresh_data = True