]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- A critical fix to dynamic relations allows the
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Jul 2008 21:33:58 +0000 (21:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Jul 2008 21:33:58 +0000 (21:33 +0000)
"modified" history to be properly cleared after
a flush().

CHANGES
lib/sqlalchemy/orm/dynamic.py
test/orm/dynamic.py

diff --git a/CHANGES b/CHANGES
index 1d34eedbea6623b0c8e7b701aef465fb4e27f051..03082e3bd9ffd14e411930bc864cde90cad6a1f8 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,10 @@ CHANGES
     - The 'cascade' parameter to relation() accepts None
       as a value, which is equivalent to no cascades.
     
+    - A critical fix to dynamic relations allows the 
+      "modified" history to be properly cleared after
+      a flush().
+      
     - Added a new SessionExtension hook called after_attach().
       This is called at the point of attachment for objects
       via add(), add_all(), delete(), and merge().
index 424ef85b7cb3a5637c7dbef888d446bc91453186..3d139dff123e75b4cab714ce59454975f5200073 100644 (file)
@@ -18,7 +18,7 @@ from sqlalchemy.orm import (
     attributes, object_session, util as mapperutil, strategies,
     )
 from sqlalchemy.orm.query import Query
-from sqlalchemy.orm.util import has_identity
+from sqlalchemy.orm.util import _state_has_identity, has_identity
 
 
 class DynaLoader(strategies.AbstractRelationLoader):
@@ -55,7 +55,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
             return history.added_items + history.unchanged_items
 
     def fire_append_event(self, state, value, initiator):
-        state.modified = True
+        collection_history = self._modified_event(state)
+        collection_history.added_items.append(value)
 
         if self.trackparent and value is not None:
             self.sethasparent(attributes.instance_state(value), True)
@@ -63,22 +64,36 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
             ext.append(state, value, initiator or self)
 
     def fire_remove_event(self, state, value, initiator):
-        state.modified = True
+        collection_history = self._modified_event(state)
+        collection_history.deleted_items.append(value)
 
         if self.trackparent and value is not None:
             self.sethasparent(attributes.instance_state(value), False)
 
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
+    
+    def _modified_event(self, state):
+        state.modified = True
+        if self.key not in state.committed_state:
+            state.committed_state[self.key] = CollectionHistory(self, state)
+
+        # this is a hack to allow the _base.ComparableEntity fixture
+        # to work
+        state.dict[self.key] = True
+        
+        return state.committed_state[self.key]
         
     def set(self, state, value, initiator):
         if initiator is self:
             return
-
-        old_collection = self.get(state).assign(value)
-
-        # TODO: emit events ???
-        state.modified = True
+        
+        collection_history = self._modified_event(state)
+        if _state_has_identity(state):
+            old_collection = list(self.get(state))
+        else:
+            old_collection = []
+        collection_history.replace(old_collection, value)
 
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
@@ -88,11 +103,11 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         return (c.added_items, c.unchanged_items, c.deleted_items)
         
     def _get_collection_history(self, state, passive=False):
-        try:
-            c = state.dict[self.key]
-        except KeyError:
-            state.dict[self.key] = c = CollectionHistory(self, state)
-
+        if self.key in state.committed_state:
+            c = state.committed_state[self.key]
+        else:
+            c = CollectionHistory(self, state)
+            
         if not passive:
             return CollectionHistory(self, state, apply_to=c)
         else:
@@ -100,15 +115,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         
     def append(self, state, value, initiator, passive=False):
         if initiator is not self:
-            self._get_collection_history(state, passive=True).added_items.append(value)
             self.fire_append_event(state, value, initiator)
     
     def remove(self, state, value, initiator, passive=False):
         if initiator is not self:
-            self._get_collection_history(state, passive=True).deleted_items.append(value)
             self.fire_remove_event(state, value, initiator)
 
-            
+        
 class AppenderQuery(Query):
     def __init__(self, attr, state):
         super(AppenderQuery, self).__init__(attr.target_mapper, None)
@@ -170,15 +183,6 @@ class AppenderQuery(Query):
             q = q.order_by(self.attr.order_by)
         return q
 
-    def assign(self, collection):
-        instance = self.instance
-        if has_identity(instance):
-            oldlist = list(self)
-        else:
-            oldlist = []
-        self.attr._get_collection_history(attributes.instance_state(self.instance), passive=True).replace(oldlist, collection)
-        return oldlist
-        
     def append(self, item):
         self.attr.append(attributes.instance_state(self.instance), item, None)
 
index a4d0f396f69cec3253d70a54121881a2034f952c..5d9b54b15cad8aa947c67036034163197cea9b27 100644 (file)
@@ -129,7 +129,10 @@ class FlushTest(_fixtures.FixtureTest):
         u1.addresses.append(Address(email_address='lala@hoho.com'))
         sess.add_all((u1, u2))
         sess.flush()
-
+        
+        from sqlalchemy.orm import attributes
+        self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
+        
         sess.clear()
 
         # test the test fixture a little bit
@@ -140,7 +143,18 @@ class FlushTest(_fixtures.FixtureTest):
             User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
             User(name='ed', addresses=[Address(email_address='foo@bar.com')])
         ] == sess.query(User).all()
-
+    
+    @testing.resolve_artifact_names
+    def test_hasattr(self):
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses))
+        })
+        u1 = User(name='jack')
+        
+        assert 'addresses' not in u1.__dict__.keys()
+        u1.addresses = [Address(email_address='test')]
+        assert 'addresses' in dir(u1)
+        
     @testing.resolve_artifact_names
     def test_rollback(self):
         mapper(User, users, properties={