]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- A critical fix to dynamic relations allows the
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Jul 2008 21:40:34 +0000 (21:40 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Jul 2008 21:40:34 +0000 (21:40 +0000)
"modified" history to be properly cleared after
a flush() (backported from 0.5).

CHANGES
lib/sqlalchemy/orm/dynamic.py
test/orm/dynamic.py

diff --git a/CHANGES b/CHANGES
index 102e71c94f14dbae401f8f9449e4e692aa772bdf..4973436de0f6591dcb3c68f9c1eec02a920e966b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,10 @@ CHANGES
       that multiple contains() calls will not conflict
       with each other [ticket:1058]
 
+    - A critical fix to dynamic relations allows the 
+      "modified" history to be properly cleared after
+      a flush() (backported from 0.5).
+
     - fixed bug preventing merge() from functioning in 
       conjunction with a comparable_property()
 
index 133ad99c897912be7c37e5b654901a604f9c7a4e..19bdeab1ddf7776eed16dcaa963327d41ffce9dd 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy import exceptions, util, logging
 from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.mapper import has_identity, object_mapper
-
+from sqlalchemy.orm.util import _state_has_identity
 
 class DynaLoader(strategies.AbstractRelationLoader):
     def init_class_attribute(self):
@@ -38,7 +38,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
             return history.added_items + history.unchanged_items
 
     def fire_append_event(self, state, value, initiator):
-        state.modified = True
+        collection_history = self._modified_event(state)
+        collection_history.added_items.append(value)
 
         if self.trackparent and value is not None:
             self.sethasparent(value._state, True)
@@ -47,7 +48,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
             ext.append(instance, value, initiator or self)
 
     def fire_remove_event(self, state, value, initiator):
-        state.modified = True
+        collection_history = self._modified_event(state)
+        collection_history.deleted_items.append(value)
 
         if self.trackparent and value is not None:
             self.sethasparent(value._state, False)
@@ -55,15 +57,28 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         instance = state.obj()
         for ext in self.extensions:
             ext.remove(instance, value, initiator or self)
+    
+    def _modified_event(self, state):
+        state.modified = True
+        if self.key not in state.committed_state:
+            state.committed_state[self.key] = CollectionHistory(self, state)
+
+        # this is a hack to allow the _base.ComparableEntity fixture
+        # to work
+        state.dict[self.key] = True
+        
+        return state.committed_state[self.key]
         
     def set(self, state, value, initiator):
         if initiator is self:
             return
-
-        old_collection = self.get(state).assign(value)
-
-        # TODO: emit events ???
-        state.modified = True
+        
+        collection_history = self._modified_event(state)
+        if _state_has_identity(state):
+            old_collection = list(self.get(state))
+        else:
+            old_collection = []
+        collection_history.replace(old_collection, value)
 
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
@@ -73,11 +88,11 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         return (c.added_items, c.unchanged_items, c.deleted_items)
         
     def _get_collection_history(self, state, passive=False):
-        try:
-            c = state.dict[self.key]
-        except KeyError:
-            state.dict[self.key] = c = CollectionHistory(self, state)
-
+        if self.key in state.committed_state:
+            c = state.committed_state[self.key]
+        else:
+            c = CollectionHistory(self, state)
+            
         if not passive:
             return CollectionHistory(self, state, apply_to=c)
         else:
@@ -85,15 +100,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         
     def append(self, state, value, initiator, passive=False):
         if initiator is not self:
-            self._get_collection_history(state, passive=True).added_items.append(value)
             self.fire_append_event(state, value, initiator)
     
     def remove(self, state, value, initiator, passive=False):
         if initiator is not self:
-            self._get_collection_history(state, passive=True).deleted_items.append(value)
             self.fire_remove_event(state, value, initiator)
 
-            
+        
 class AppenderQuery(Query):
     def __init__(self, attr, state):
         super(AppenderQuery, self).__init__(attr.target_mapper, None)
@@ -152,15 +165,6 @@ class AppenderQuery(Query):
             q = q.order_by(self.attr.order_by)
         return q
 
-    def assign(self, collection):
-        instance = self.instance
-        if has_identity(instance):
-            oldlist = list(self)
-        else:
-            oldlist = []
-        self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection)
-        return oldlist
-        
     def append(self, item):
         self.attr.append(self.instance._state, item, None)
 
index c38b27823806422ec30bbb5e5c9c9de9f7fa8c34..3a851495e4e5c399ddebea812dafee238eb545ac 100644 (file)
@@ -118,7 +118,10 @@ class FlushTest(FixtureTest):
         sess.save(u1)
         sess.save(u2)
         sess.flush()
-
+        
+        from sqlalchemy.orm import attributes
+        self.assertEquals(attributes.get_history(u1._state, 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
+        
         sess.clear()
 
         # test the test fixture a little bit
@@ -129,6 +132,31 @@ class FlushTest(FixtureTest):
             User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
             User(name='ed', addresses=[Address(email_address='foo@bar.com')])
         ] == sess.query(User).all()
+    
+    def test_hasattr(self):
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses))
+        })
+        u1 = User(name='jack')
+        
+        assert 'addresses' not in u1.__dict__.keys()
+        u1.addresses = [Address(email_address='test')]
+        assert 'addresses' in dir(u1)
+        
+    def test_rollback(self):
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses))
+        })
+        sess = create_session(transactional=True, autoflush=True)
+        u1 = User(name='jack')
+        u1.addresses.append(Address(email_address='lala@hoho.com'))
+        sess.save(u1)
+        sess.flush()
+        sess.commit()
+        u1.addresses.append(Address(email_address='foo@bar.com'))
+        self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
+        sess.rollback()
+        self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
 
     @testing.fails_on('maxdb')
     def test_delete_nocascade(self):