]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed a truncation error when re-assigning a subset of a collection
authorJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 09:21:22 +0000 (09:21 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 09:21:22 +0000 (09:21 +0000)
(obj.relation = obj.relation[1:]) [ticket:834]

CHANGES
lib/sqlalchemy/orm/attributes.py
test/orm/cascade.py
test/orm/collection.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 16bb5aeef9074ad44863413984adc77e8eaa621b..d4f49e5d84d42cc336f3c2e80bf20fd4b61675b3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -56,6 +56,9 @@ CHANGES
 - Fixed __hash__ for association proxy- these collections are unhashable,
   just like their mutable Python counterparts.
 
+- Fixed a truncation error when re-assigning a subset of a collection
+  (obj.relation = obj.relation[1:]) [ticket:834]
+
 0.4.0
 -----
 
index 8d035d568298d3c3b5be3e1f4c89bbf83b346a29..6d9c092a6b42d653933e760d95d877b065862651 100644 (file)
@@ -453,15 +453,25 @@ class CollectionAttributeImpl(AttributeImpl):
         old_collection = self.get_collection(state, old)
 
         new_collection, user_data = self._build_collection(state)
-        self._load_collection(state, value or [], emit_events=True,
-                              collection=new_collection)
+
+        idset = util.IdentitySet
+        constants = idset(old_collection or []).intersection(value or [])
+        additions = idset(value or []).difference(constants)
+        removals  = idset(old_collection or []).difference(constants)
+
+        for member in value or []:
+            if member in additions:
+                new_collection.append_with_event(member)
+            elif member in constants:
+                new_collection.append_without_event(member)
 
         state.dict[self.key] = user_data
         state.modified = True
 
-        # mark all the old elements as detached from the parent
+        # mark all the orphaned elements as detached from the parent
         if old_collection:
-            old_collection.clear_with_event()
+            for member in removals:
+                old_collection.remove_with_event(member)
             old_collection.unlink(old)
 
     def set_committed_value(self, state, value):
@@ -494,7 +504,7 @@ class CollectionAttributeImpl(AttributeImpl):
         else:
             for item in values:
                 collection.append_without_event(item)
-            
+
     def get_collection(self, state, user_data=None):
         if user_data is None:
             user_data = self.get(state)
index 8ab27c2b2098371b09a5e561b5637f20ab814b49..e24fbbdbab95655b7b3b168101ddd5fc9c308ca4 100644 (file)
@@ -113,7 +113,7 @@ class O2MCascadeTest(AssertMixin):
         
     def testdelete(self):
         sess = create_session()
-        l = sess.query(tables.User).select()
+        l = sess.query(tables.User).all()
         for u in l:
             print repr(u.orders)
         self.assert_result(l, data[0], *data[1:])
@@ -161,7 +161,7 @@ class O2MCascadeTest(AssertMixin):
 
     def testorphan(self):
         sess = create_session()
-        l = sess.query(tables.User).select()
+        l = sess.query(tables.User).all()
         jack = l[1]
         jack.orders[:] = []
 
@@ -525,6 +525,60 @@ class DoubleParentOrphanTest(AssertMixin):
         except exceptions.FlushError, e:
             assert True
 
+class CollectionAssignmentOrphanTest(AssertMixin):
+    def setUpAll(self):
+        global metadata, table_a, table_b
+
+        metadata = MetaData(testbase.db)
+        table_a = Table('a', metadata,
+                        Column('id', Integer, primary_key=True),
+                        Column('foo', String(30)))
+        table_b = Table('b', metadata,
+                        Column('id', Integer, primary_key=True),
+                        Column('foo', String(30)),
+                        Column('a_id', Integer, ForeignKey('a.id')))
+        metadata.create_all()
+
+    def tearDown(self):
+        clear_mappers()
+    def tearDownAll(self):
+        metadata.drop_all()
+
+    def test_basic(self):
+        class A(object):
+            def __init__(self, foo):
+                self.foo = foo
+        class B(object):
+            def __init__(self, foo):
+                self.foo = foo
+
+        mapper(A, table_a, properties={
+            'bs':relation(B, cascade="all, delete-orphan")
+            })
+        mapper(B, table_b)
+
+        a1 = A('a1')
+        a1.bs.append(B('b1'))
+        a1.bs.append(B('b2'))
+        a1.bs.append(B('b3'))
+
+        sess = create_session()
+        sess.save(a1)
+        sess.flush()
+
+        assert table_b.count(table_b.c.a_id == None).scalar() == 0
+
+        assert table_b.count().scalar() == 3
+
+        a1 = sess.query(A).get(a1.id)
+        assert len(a1.bs) == 3
+        a1.bs = list(a1.bs)
+        assert not class_mapper(B)._is_orphan(a1.bs[0])
+        a1.bs[0].foo='b2modified'
+        a1.bs[1].foo='b3modified'
+        sess.flush()
+
+        assert table_b.count().scalar() == 3
 
 if __name__ == "__main__":
-    testbase.main()        
+    testbase.main()
index d421952b53b49bfbe0e48a97123f1e2d543b1a82..504a4d0cb770582fc6b5bcb1c6e4feea49e0a92f 100644 (file)
@@ -257,7 +257,7 @@ class CollectionsTest(PersistTest):
         self.assert_(set(obj.attr) == set([e2]))
         self.assert_(e1 in canary.removed)
         self.assert_(e2 in canary.added)
+
         e3 = creator()
         real_list = [e3]
         obj.attr = real_list
@@ -265,7 +265,7 @@ class CollectionsTest(PersistTest):
         self.assert_(set(obj.attr) == set([e3]))
         self.assert_(e2 in canary.removed)
         self.assert_(e3 in canary.added)
-       
+
         e4 = creator()
         try:
             obj.attr = set([e4])
@@ -274,6 +274,21 @@ class CollectionsTest(PersistTest):
             self.assert_(e4 not in canary.data)
             self.assert_(e3 in canary.data)
 
+        e5 = creator()
+        e6 = creator()
+        e7 = creator()
+        obj.attr = [e5, e6, e7]
+        self.assert_(e5 in canary.added)
+        self.assert_(e6 in canary.added)
+        self.assert_(e7 in canary.added)
+
+        obj.attr = [e6, e7]
+        self.assert_(e5 in canary.removed)
+        self.assert_(e6 in canary.added)
+        self.assert_(e7 in canary.added)
+        self.assert_(e6 not in canary.removed)
+        self.assert_(e7 not in canary.removed)
+
     def test_list(self):
         self._test_adapter(list)
         self._test_list(list)
index 775f7357e0e4a176840ec275d0daaf908eb8874d..e2362d6c18ea00cf1404a24ade9bc4f88c7b828c 100644 (file)
@@ -871,7 +871,7 @@ class ExternalColumnsTest(QueryTest):
         })    
 
         sess = create_session()
-        l = sess.query(User).select()
+        l = sess.query(User).all()
         assert [
             User(id=7, concat=14, count=1),
             User(id=8, concat=16, count=3),