From: Jason Kirtland Date: Wed, 31 Oct 2007 09:21:22 +0000 (+0000) Subject: Fixed a truncation error when re-assigning a subset of a collection X-Git-Tag: rel_0_4_1~91 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2d3f907ac0a23d410ecc3c74afc6d63bd2abc186;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixed a truncation error when re-assigning a subset of a collection (obj.relation = obj.relation[1:]) [ticket:834] --- diff --git a/CHANGES b/CHANGES index 16bb5aeef9..d4f49e5d84 100644 --- a/CHANGES +++ b/CHANGES @@ -56,6 +56,9 @@ CHANGES - Fixed __hash__ for association proxy- these collections are unhashable, just like their mutable Python counterparts. +- Fixed a truncation error when re-assigning a subset of a collection + (obj.relation = obj.relation[1:]) [ticket:834] + 0.4.0 ----- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 8d035d5682..6d9c092a6b 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -453,15 +453,25 @@ class CollectionAttributeImpl(AttributeImpl): old_collection = self.get_collection(state, old) new_collection, user_data = self._build_collection(state) - self._load_collection(state, value or [], emit_events=True, - collection=new_collection) + + idset = util.IdentitySet + constants = idset(old_collection or []).intersection(value or []) + additions = idset(value or []).difference(constants) + removals = idset(old_collection or []).difference(constants) + + for member in value or []: + if member in additions: + new_collection.append_with_event(member) + elif member in constants: + new_collection.append_without_event(member) state.dict[self.key] = user_data state.modified = True - # mark all the old elements as detached from the parent + # mark all the orphaned elements as detached from the parent if old_collection: - old_collection.clear_with_event() + for member in removals: + old_collection.remove_with_event(member) old_collection.unlink(old) def set_committed_value(self, state, value): @@ -494,7 +504,7 @@ class CollectionAttributeImpl(AttributeImpl): else: for item in values: collection.append_without_event(item) - + def get_collection(self, state, user_data=None): if user_data is None: user_data = self.get(state) diff --git a/test/orm/cascade.py b/test/orm/cascade.py index 8ab27c2b20..e24fbbdbab 100644 --- a/test/orm/cascade.py +++ b/test/orm/cascade.py @@ -113,7 +113,7 @@ class O2MCascadeTest(AssertMixin): def testdelete(self): sess = create_session() - l = sess.query(tables.User).select() + l = sess.query(tables.User).all() for u in l: print repr(u.orders) self.assert_result(l, data[0], *data[1:]) @@ -161,7 +161,7 @@ class O2MCascadeTest(AssertMixin): def testorphan(self): sess = create_session() - l = sess.query(tables.User).select() + l = sess.query(tables.User).all() jack = l[1] jack.orders[:] = [] @@ -525,6 +525,60 @@ class DoubleParentOrphanTest(AssertMixin): except exceptions.FlushError, e: assert True +class CollectionAssignmentOrphanTest(AssertMixin): + def setUpAll(self): + global metadata, table_a, table_b + + metadata = MetaData(testbase.db) + table_a = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('foo', String(30))) + table_b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('foo', String(30)), + Column('a_id', Integer, ForeignKey('a.id'))) + metadata.create_all() + + def tearDown(self): + clear_mappers() + def tearDownAll(self): + metadata.drop_all() + + def test_basic(self): + class A(object): + def __init__(self, foo): + self.foo = foo + class B(object): + def __init__(self, foo): + self.foo = foo + + mapper(A, table_a, properties={ + 'bs':relation(B, cascade="all, delete-orphan") + }) + mapper(B, table_b) + + a1 = A('a1') + a1.bs.append(B('b1')) + a1.bs.append(B('b2')) + a1.bs.append(B('b3')) + + sess = create_session() + sess.save(a1) + sess.flush() + + assert table_b.count(table_b.c.a_id == None).scalar() == 0 + + assert table_b.count().scalar() == 3 + + a1 = sess.query(A).get(a1.id) + assert len(a1.bs) == 3 + a1.bs = list(a1.bs) + assert not class_mapper(B)._is_orphan(a1.bs[0]) + a1.bs[0].foo='b2modified' + a1.bs[1].foo='b3modified' + sess.flush() + + assert table_b.count().scalar() == 3 if __name__ == "__main__": - testbase.main() + testbase.main() diff --git a/test/orm/collection.py b/test/orm/collection.py index d421952b53..504a4d0cb7 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -257,7 +257,7 @@ class CollectionsTest(PersistTest): self.assert_(set(obj.attr) == set([e2])) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) - + e3 = creator() real_list = [e3] obj.attr = real_list @@ -265,7 +265,7 @@ class CollectionsTest(PersistTest): self.assert_(set(obj.attr) == set([e3])) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) - + e4 = creator() try: obj.attr = set([e4]) @@ -274,6 +274,21 @@ class CollectionsTest(PersistTest): self.assert_(e4 not in canary.data) self.assert_(e3 in canary.data) + e5 = creator() + e6 = creator() + e7 = creator() + obj.attr = [e5, e6, e7] + self.assert_(e5 in canary.added) + self.assert_(e6 in canary.added) + self.assert_(e7 in canary.added) + + obj.attr = [e6, e7] + self.assert_(e5 in canary.removed) + self.assert_(e6 in canary.added) + self.assert_(e7 in canary.added) + self.assert_(e6 not in canary.removed) + self.assert_(e7 not in canary.removed) + def test_list(self): self._test_adapter(list) self._test_list(list) diff --git a/test/orm/query.py b/test/orm/query.py index 775f7357e0..e2362d6c18 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -871,7 +871,7 @@ class ExternalColumnsTest(QueryTest): }) sess = create_session() - l = sess.query(User).select() + l = sess.query(User).all() assert [ User(id=7, concat=14, count=1), User(id=8, concat=16, count=3),