]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- state.commit() and state.commit_all() now reconcile the current dict against expire...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Mar 2008 22:30:02 +0000 (22:30 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Mar 2008 22:30:02 +0000 (22:30 +0000)
and unset the expired flag for those attributes.  This is partially so that attributes are not
needlessly marked as expired after a two-phase inheritance load.
- fixed bug which was introduced in 0.4.3, whereby loading an
already-persistent instance mapped with joined table inheritance
would trigger a useless "secondary" load from its joined
table, when using the default "select" polymorphic_fetch.
This was due to attributes being marked as expired
during its first load and not getting unmarked from the
previous "secondary" load.  Attributes are now unexpired
based on presence in __dict__ after any load or commit
operation succeeds.

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
test/orm/expire.py
test/orm/inheritance/polymorph.py

diff --git a/CHANGES b/CHANGES
index 3e5d36c20fba6ddb765f3253d25806f09adcadd3..4faa8ee8665f902280f5e94e7984501de730bb6f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -45,6 +45,16 @@ CHANGES
     - Fixed potential generative bug when the same Query was used
       to generate multiple Query objects using join().
 
+    - fixed bug which was introduced in 0.4.3, whereby loading an
+      already-persistent instance mapped with joined table inheritance
+      would trigger a useless "secondary" load from its joined 
+      table, when using the default "select" polymorphic_fetch.  
+      This was due to attributes being marked as expired
+      during its first load and not getting unmarked from the 
+      previous "secondary" load.  Attributes are now unexpired
+      based on presence in __dict__ after any load or commit
+      operation succeeds.
+      
     - deprecated Query methods apply_sum(), apply_max(), apply_min(),
       apply_avg().  Better methodologies are coming....
       
index 298a7f51193e848c4d47dfc35662fa987c35c8d7..5c5781d4e6bce08088e652cc1ef7c3d927c6e610 100644 (file)
@@ -257,9 +257,6 @@ class AttributeImpl(object):
         """set an attribute value on the given instance and 'commit' it."""
 
         state.commit_attr(self, value)
-        # remove per-instance callable, if any
-        state.callables.pop(self.key, None)
-        state.dict[self.key] = value
         return value
 
 class ScalarAttributeImpl(AttributeImpl):
@@ -672,6 +669,9 @@ class ClassState(object):
         self.attrs = {}
         self.has_mutable_scalars = False
 
+import sets
+_empty_set = sets.ImmutableSet()
+
 class InstanceState(object):
     """tracks state information at the instance level."""
 
@@ -687,6 +687,7 @@ class InstanceState(object):
         self.appenders = {}
         self.instance_dict = None
         self.runid = None
+        self.expired_attributes = _empty_set
 
     def __cleanup(self, ref):
         # tiptoe around Python GC unpredictableness
@@ -751,7 +752,7 @@ class InstanceState(object):
             return None
 
     def __getstate__(self):
-        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':getattr(self, 'expired_attributes', None), 'callables':self.callables}
+        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':self.expired_attributes, 'callables':self.callables}
 
     def __setstate__(self, state):
         self.committed_state = state['committed_state']
@@ -764,8 +765,7 @@ class InstanceState(object):
         self.callables = state['callables']
         self.runid = None
         self.appenders = {}
-        if state['expired_attributes'] is not None:
-            self.expire_attributes(state['expired_attributes'])
+        self.expired_attributes = state['expired_attributes']
 
     def initialize(self, key):
         getattr(self.class_, key).impl.initialize(self)
@@ -780,7 +780,6 @@ class InstanceState(object):
         serializable.
         """
         instance = self.obj()
-        
         unmodified = self.unmodified
         self.class_._class_state.deferred_scalar_loader(instance, [
             attr.impl.key for attr in _managed_attributes(self.class_) if 
@@ -804,8 +803,7 @@ class InstanceState(object):
     unmodified = property(unmodified)
 
     def expire_attributes(self, attribute_names):
-        if not hasattr(self, 'expired_attributes'):
-            self.expired_attributes = util.Set()
+        self.expired_attributes = util.Set(self.expired_attributes)
 
         if attribute_names is None:
             for attr in _managed_attributes(self.class_):
@@ -829,18 +827,29 @@ class InstanceState(object):
         self.callables.pop(key, None)
 
     def commit_attr(self, attr, value):
+        """set the value of an attribute and mark it 'committed'."""
+
         if hasattr(attr, 'commit_to_state'):
             attr.commit_to_state(self, value)
         else:
             self.committed_state.pop(attr.key, None)
+        self.dict[attr.key] = value
         self.pending.pop(attr.key, None)
         self.appenders.pop(attr.key, None)
+        
+        # we have a value so we can also unexpire it
+        self.callables.pop(attr.key, None)
+        if attr.key in self.expired_attributes:
+            self.expired_attributes.remove(attr.key)
 
     def commit(self, keys):
         """commit all attributes named in the given list of key names.
 
         This is used by a partial-attribute load operation to mark committed those attributes
         which were refreshed from the database.
+
+        Attributes marked as "expired" can potentially remain "expired" after this step
+        if a value was not populated in state.dict.
         """
 
         if self.class_._class_state.has_mutable_scalars:
@@ -857,12 +866,22 @@ class InstanceState(object):
                 self.committed_state.pop(key, None)
                 self.pending.pop(key, None)
                 self.appenders.pop(key, None)
-
+                
+        # unexpire attributes which have loaded
+        for key in self.expired_attributes.intersection(keys):
+            if key in self.dict:
+                self.expired_attributes.remove(key)
+                self.callables.pop(key, None)
+                    
+                
     def commit_all(self):
         """commit all attributes unconditionally.
 
         This is used after a flush() or a regular instance load or refresh operation
         to mark committed all populated attributes.
+        
+        Attributes marked as "expired" can potentially remain "expired" after this step
+        if a value was not populated in state.dict.
         """
 
         self.committed_state = {}
@@ -870,6 +889,12 @@ class InstanceState(object):
         self.pending = {}
         self.appenders = {}
 
+        # unexpire attributes which have loaded
+        for key in list(self.expired_attributes):
+            if key in self.dict:
+                self.expired_attributes.remove(key)
+                self.callables.pop(key, None)
+
         if self.class_._class_state.has_mutable_scalars:
             for attr in _managed_attributes(self.class_):
                 if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict:
index 297d222466b334b420407fa7edbbefd405255799..f89830c02c65f4e7832ecdc128421af883ac6202 100644 (file)
@@ -1373,10 +1373,9 @@ class Mapper(object):
                 self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
         
         else:
-            attrs = getattr(state, 'expired_attributes', None)
             # populate attributes on non-loading instances which have been expired
             # TODO: also support deferred attributes here [ticket:870]
-            if attrs: 
+            if state.expired_attributes: 
                 if state in context.partials:
                     isnew = False
                     attrs = context.partials[state]
@@ -1483,7 +1482,7 @@ class Mapper(object):
                     self.__log_debug("Post query loading instance " + instance_str(instance))
 
                 identitykey = self.identity_key_from_instance(instance)
-
+                
                 only_load_props = flags.get('only_load_props', None)
 
                 params = {}
@@ -1563,7 +1562,6 @@ object_session = None
 
 def _load_scalar_attributes(instance, attribute_names):
     mapper = object_mapper(instance)
-
     global object_session
     if not object_session:
         from sqlalchemy.orm.session import object_session
index 545f01234d5c831dcefac37d9a10fd25ed7be74d..dca56dfb8fa592dfd962f397432c0aaa2ec0cc34 100644 (file)
@@ -39,10 +39,12 @@ class ExpireTest(FixtureTest):
         sess.expire(u)
         # object isnt refreshed yet, using dict to bypass trigger
         assert u.__dict__.get('name') != 'jack'
+        assert 'name' in u._state.expired_attributes
 
         sess.query(User).all()
         # test that it refreshed
         assert u.__dict__['name'] == 'jack'
+        assert 'name' not in u._state.expired_attributes
 
         def go():
             assert u.name == 'jack'
index faee633601b3784676fb7ed8418d52935a275182..4b468e227c020be8b87ed7b6558e7632ef2d483a 100644 (file)
@@ -268,7 +268,7 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'}))
         c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'}))
         session.save(c)
-        print session.new
+
         session.flush()
         session.clear()
 
@@ -284,7 +284,6 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         def go():
             c = session.query(Company).get(id)
             for e in c.employees:
-                print e, e._instance_key, e.company
                 assert e._instance_key[0] == Person
             if include_base:
                 assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')])
@@ -307,25 +306,31 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         # test selecting from the query, using the base mapped table (people) as the selection criterion.
         # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
         dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
-        dilbert2 = session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first()
-        assert dilbert is dilbert2
+        assert dilbert is session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first()
 
         # test selecting from the query, joining against an alias of the base "people" table.  test that
         # the "palias" alias does *not* get sucked up into the "person_join" conversion.
         palias = people.alias("palias")
-        session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
-        dilbert2 = session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
-        assert dilbert is dilbert2
-
-        session.query(Person).filter((Engineer.engineer_name=="engineer1") & (Engineer.person_id==people.c.person_id)).first()
-
-        dilbert2 = session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0]
-        assert dilbert is dilbert2
-
+        assert dilbert is session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
+        assert dilbert is session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
+        assert dilbert is session.query(Person).filter((Engineer.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)).first()
+        assert dilbert is session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0]
+        
         dilbert.engineer_name = 'hes dibert!'
 
         session.flush()
         session.clear()
+        
+        if polymorphic_fetch == 'select':
+            def go():
+                session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+            self.assert_sql_count(testing.db, go, 2)
+            session.clear()
+            dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+            def go():
+                # assert that only primary table is queried for already-present-in-session
+                d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+            self.assert_sql_count(testing.db, go, 1)
 
         # save/load some managers/bosses
         b = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})