From: Mike Bayer Date: Tue, 25 Nov 2008 04:43:04 +0000 (+0000) Subject: - Duplicate items in a list-based collection will X-Git-Tag: rel_0_5_0~156 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=332f5ee2662835ed1ca008043d40c37d7cddc270;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Duplicate items in a list-based collection will be maintained when issuing INSERTs to a "secondary" table in a many-to-many relation. Assuming the m2m table has a unique or primary key constraint on it, this will raise the expected constraint violation instead of silently dropping the duplicate entries. Note that the old behavior remains for a one-to-many relation since collection entries in that case don't result in INSERT statements and SQLA doesn't manually police collections. [ticket:1232] --- diff --git a/CHANGES b/CHANGES index 5041ee9326..9a0c1868d3 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,18 @@ CHANGES that the given argument is a FromClause, or Text/Select/Union, respectively. + - Duplicate items in a list-based collection will + be maintained when issuing INSERTs to + a "secondary" table in a many-to-many relation. + Assuming the m2m table has a unique or primary key + constraint on it, this will raise the expected + constraint violation instead of silently + dropping the duplicate entries. Note that the + old behavior remains for a one-to-many relation + since collection entries in that case + don't result in INSERT statements and SQLA doesn't + manually police collections. [ticket:1232] + - Query.add_column() can accept FromClause objects in the same manner as session.query() can. diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 662ea05d3c..79be76c3ad 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1427,11 +1427,15 @@ class History(tuple): elif original is NEVER_SET: return cls((), list(current), ()) else: - collection = util.OrderedIdentitySet(current) - s = util.OrderedIdentitySet(original) - return cls(list(collection.difference(s)), - list(collection.intersection(s)), - list(s.difference(collection))) + current_set = util.IdentitySet(current) + original_set = util.IdentitySet(original) + + # ensure duplicates are maintained + return cls( + [x for x in current if x not in original_set], + [x for x in current if x in original_set], + [x for x in original if x not in current_set] + ) else: if current is NO_VALUE: if original not in [None, NEVER_SET, NO_VALUE]: diff --git a/test/orm/attributes.py b/test/orm/attributes.py index e2484d17d4..074c236f70 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -1047,6 +1047,19 @@ class HistoryTest(_base.ORMTest): f.someattr = [there] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [])) + # case 4. ensure duplicates show up, order is maintained + f = Foo() + f.someattr.append(hi) + f.someattr.append(there) + f.someattr.append(hi) + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], [])) + + attributes.instance_state(f).commit_all() + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ())) + + f.someattr = [] + eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [hi, there, hi])) + def test_collections_via_backref(self): class Foo(_base.BasicEntity): pass diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py index 7a60f01c6f..61409c7169 100644 --- a/test/orm/manytomany.py +++ b/test/orm/manytomany.py @@ -63,7 +63,7 @@ class M2MTest(_base.MappedTest): repr(self.outputs))) @testing.resolve_artifact_names - def testerror(self): + def test_error(self): mapper(Place, place, properties={ 'transitions':relation(Transition, secondary=place_input, backref='places') }) @@ -74,8 +74,8 @@ class M2MTest(_base.MappedTest): sa.orm.compile_mappers) @testing.resolve_artifact_names - def testcircular(self): - """tests a many-to-many relationship from a table to itself.""" + def test_circular(self): + """test a many-to-many relationship from a table to itself.""" Place.mapper = mapper(Place, place) @@ -124,8 +124,8 @@ class M2MTest(_base.MappedTest): sess.flush() @testing.resolve_artifact_names - def testdouble(self): - """tests that a mapper can have two eager relations to the same table, via + def test_double(self): + """test that a mapper can have two eager relations to the same table, via two different association tables. aliases are required.""" Place.mapper = mapper(Place, place, properties = { @@ -155,7 +155,7 @@ class M2MTest(_base.MappedTest): }) @testing.resolve_artifact_names - def testbidirectional(self): + def test_bidirectional(self): """tests a many-to-many backrefs""" Place.mapper = mapper(Place, place) Transition.mapper = mapper(Transition, transition, properties = dict( @@ -200,18 +200,20 @@ class M2MTest2(_base.MappedTest): Column('course_id', String(20), ForeignKey('course.name'), primary_key=True)) - @testing.resolve_artifact_names - def testcircular(self): - class Student(object): + def setup_classes(self): + class Student(_base.BasicEntity): def __init__(self, name=''): self.name = name - class Course(object): + class Course(_base.BasicEntity): def __init__(self, name=''): self.name = name + @testing.resolve_artifact_names + def test_circular(self): + mapper(Student, student) mapper(Course, course, properties={ - 'students': relation(Student, enroll, lazy=True, backref='courses')}) + 'students': relation(Student, enroll, backref='courses')}) sess = create_session() s1 = Student('Student1') @@ -232,15 +234,25 @@ class M2MTest2(_base.MappedTest): del s.courses[1] self.assert_(len(s.courses) == 2) + @testing.resolve_artifact_names + def test_dupliates_raise(self): + """test constraint error is raised for dupe entries in a list""" + + mapper(Student, student) + mapper(Course, course, properties={ + 'students': relation(Student, enroll, backref='courses')}) + + sess = create_session() + s1 = Student("s1") + c1 = Course('c1') + s1.courses.append(c1) + s1.courses.append(c1) + sess.add(s1) + self.assertRaises(sa.exc.DBAPIError, sess.flush) + @testing.resolve_artifact_names def test_delete(self): """A many-to-many table gets cleared out with deletion from the backref side""" - class Student(object): - def __init__(self, name=''): - self.name = name - class Course(object): - def __init__(self, name=''): - self.name = name mapper(Student, student) mapper(Course, course, properties = { @@ -286,7 +298,7 @@ class M2MTest3(_base.MappedTest): Column('b2', sa.Boolean)) @testing.resolve_artifact_names - def testbasic(self): + def test_basic(self): class C(object):pass class A(object):pass class B(object):pass