]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed the "set collection" function on "dynamic" relations
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Mar 2009 19:54:10 +0000 (19:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Mar 2009 19:54:10 +0000 (19:54 +0000)
to initiate events correctly.  Previously a collection
could only be assigned to a pending parent instance,
otherwise modified events would not be fired correctly.
Set collection is now compatible with merge(),
fixes [ticket:1352].

CHANGES
lib/sqlalchemy/orm/dynamic.py
test/orm/dynamic.py

diff --git a/CHANGES b/CHANGES
index bef324894cc3f3c11401394ea3c0bb7c73d87ea9..ab09d7902f7b249dc1cfa01c29b504e014e03050 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,6 +3,16 @@
 =======
 CHANGES
 =======
+0.5.4
+=====
+- orm
+    - Fixed the "set collection" function on "dynamic" relations
+      to initiate events correctly.  Previously a collection
+      could only be assigned to a pending parent instance,
+      otherwise modified events would not be fired correctly.
+      Set collection is now compatible with merge(), 
+      fixes [ticket:1352].
+      
 0.5.3
 =====
 - orm
index 319364910c280d3f1033d6f605a370f9734f0172..4bc3f58c2e63efe17cb72b233272b2351771e884 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.orm import (
     )
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.util import _state_has_identity, has_identity
-from sqlalchemy.orm import attributes
+from sqlalchemy.orm import attributes, collections
 
 class DynaLoader(strategies.AbstractRelationLoader):
     def init_class_attribute(self, mapper):
@@ -102,12 +102,19 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         if initiator is self:
             return
 
+        self._set_iterable(state, value)
+
+    def _set_iterable(self, state, iterable, adapter=None):
+
         collection_history = self._modified_event(state)
+        new_values = list(iterable)
+        
         if _state_has_identity(state):
             old_collection = list(self.get(state))
         else:
             old_collection = []
-        collection_history.replace(old_collection, value)
+
+        collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values))
 
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
@@ -135,6 +142,28 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         if initiator is not self:
             self.fire_remove_event(state, value, initiator)
 
+class DynCollectionAdapter(object):
+    """the dynamic analogue to orm.collections.CollectionAdapter"""
+    
+    def __init__(self, attr, owner_state, data):
+        self.attr = attr
+        self.state = owner_state
+        self.data = data
+    
+    def __iter__(self):
+        return iter(self.data)
+        
+    def append_with_event(self, item, initiator=None):
+        self.attr.append(self.state, item, initiator)
+
+    def remove_with_event(self, item, initiator=None):
+        self.attr.remove(self.state, item, initiator)
+
+    def append_without_event(self, item):
+        pass
+    
+    def remove_without_event(self, item):
+        pass
         
 class AppenderMixin(object):
     query_class = None
@@ -239,8 +268,4 @@ class CollectionHistory(object):
             self.deleted_items = []
             self.added_items = []
             self.unchanged_items = []
-            
-    def replace(self, olditems, newitems):
-        self.added_items = newitems
-        self.deleted_items = olditems
         
index 0a9e57831546e2c0245fb19c0874825afbb81a35..f975f762f8b0fc81a7267cc05e47d8dc4d9161e9 100644 (file)
@@ -3,7 +3,7 @@ import operator
 from sqlalchemy.orm import dynamic_loader, backref
 from testlib import testing
 from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func
-from testlib.sa.orm import mapper, relation, create_session, Query
+from testlib.sa.orm import mapper, relation, create_session, Query, attributes
 from testlib.testing import eq_
 from testlib.compat import _function_named
 from orm import _base, _fixtures
@@ -152,7 +152,7 @@ class DynamicTest(_fixtures.FixtureTest):
         assert type(q).__name__ == 'MyQuery'
 
 
-class FlushTest(_fixtures.FixtureTest):
+class SessionTest(_fixtures.FixtureTest):
     run_inserts = None
 
     @testing.resolve_artifact_names
@@ -193,9 +193,43 @@ class FlushTest(_fixtures.FixtureTest):
             (a2.id, u1.id, 'bar')
         ]
         
+
+    @testing.resolve_artifact_names
+    def test_merge(self):
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses), order_by=addresses.c.email_address)
+        })
+        sess = create_session()
+        u1 = User(name='jack')
+        a1 = Address(email_address='a1')
+        a2 = Address(email_address='a2')
+        a3 = Address(email_address='a3')
+        
+        u1.addresses.append(a2)
+        u1.addresses.append(a3)
+        
+        sess.add_all([u1, a1])
+        sess.flush()
+        
+        u1 = User(id=u1.id, name='jack')
+        u1.addresses.append(a1)
+        u1.addresses.append(a3)
+        u1 = sess.merge(u1)
+        assert attributes.get_history(u1, 'addresses') == (
+            [a1], 
+            [a3], 
+            [a2]
+        )
+
+        sess.flush()
+        
+        eq_(
+            list(u1.addresses),
+            [a1, a3]
+        )
         
     @testing.resolve_artifact_names
-    def test_basic(self):
+    def test_flush(self):
         mapper(User, users, properties={
             'addresses':dynamic_loader(mapper(Address, addresses))
         })
@@ -231,6 +265,31 @@ class FlushTest(_fixtures.FixtureTest):
         assert 'addresses' not in u1.__dict__.keys()
         u1.addresses = [Address(email_address='test')]
         assert 'addresses' in dir(u1)
+    
+    @testing.resolve_artifact_names
+    def test_collection_set(self):
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses), order_by=addresses.c.email_address)
+        })
+        sess = create_session(autoflush=True, autocommit=False)
+        u1 = User(name='jack')
+        a1 = Address(email_address='a1')
+        a2 = Address(email_address='a2')
+        a3 = Address(email_address='a3')
+        a4 = Address(email_address='a4')
+        
+        sess.add(u1)
+        u1.addresses = [a1, a3]
+        assert list(u1.addresses) == [a1, a3]
+        u1.addresses = [a1, a2, a4]
+        assert list(u1.addresses) == [a1, a2, a4]
+        u1.addresses = [a2, a3]
+        assert list(u1.addresses) == [a2, a3]
+        u1.addresses = []
+        assert list(u1.addresses) == []
+        
+        
+
         
     @testing.resolve_artifact_names
     def test_rollback(self):
@@ -392,7 +451,7 @@ def create_backref_test(autoflush, saveuser):
     test_backref = _function_named(
         test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""),
                                     (saveuser and "_saveuser" or "_savead")))
-    setattr(FlushTest, test_backref.__name__, test_backref)
+    setattr(SessionTest, test_backref.__name__, test_backref)
 
 for autoflush in (False, True):
     for saveuser in (False, True):