]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Duplicate items in a list-based collection will
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 25 Nov 2008 04:43:04 +0000 (04:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 25 Nov 2008 04:43:04 +0000 (04:43 +0000)
be maintained when issuing INSERTs to
a "secondary" table in a many-to-many relation.
Assuming the m2m table has a unique or primary key
constraint on it, this will raise the expected
constraint violation instead of silently
dropping the duplicate entries. Note that the
old behavior remains for a one-to-many relation
since collection entries in that case
don't result in INSERT statements and SQLA doesn't
manually police collections. [ticket:1232]

CHANGES
lib/sqlalchemy/orm/attributes.py
test/orm/attributes.py
test/orm/manytomany.py

diff --git a/CHANGES b/CHANGES
index 5041ee9326ba4c67166e332f82bee88ba336a9a6..9a0c1868d3fdb3b4bc79d4b37a5348df6485d6be 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,18 @@ CHANGES
       that the given argument is a FromClause,
       or Text/Select/Union, respectively.
 
+    - Duplicate items in a list-based collection will
+      be maintained when issuing INSERTs to
+      a "secondary" table in a many-to-many relation.
+      Assuming the m2m table has a unique or primary key 
+      constraint on it, this will raise the expected 
+      constraint violation instead of silently
+      dropping the duplicate entries. Note that the 
+      old behavior remains for a one-to-many relation
+      since collection entries in that case
+      don't result in INSERT statements and SQLA doesn't
+      manually police collections. [ticket:1232]
+      
     - Query.add_column() can accept FromClause objects
       in the same manner as session.query() can.
 
index 662ea05d3c5b73354fdd5865e00921dda2a7f348..79be76c3adc1f0a894da05df69e825678972b120 100644 (file)
@@ -1427,11 +1427,15 @@ class History(tuple):
             elif original is NEVER_SET:
                 return cls((), list(current), ())
             else:
-                collection = util.OrderedIdentitySet(current)
-                s = util.OrderedIdentitySet(original)
-                return cls(list(collection.difference(s)),
-                           list(collection.intersection(s)),
-                           list(s.difference(collection)))
+                current_set = util.IdentitySet(current)
+                original_set = util.IdentitySet(original)
+
+                # ensure duplicates are maintained
+                return cls(
+                    [x for x in current if x not in original_set],
+                    [x for x in current if x in original_set],
+                    [x for x in original if x not in current_set]
+                )
         else:
             if current is NO_VALUE:
                 if original not in [None, NEVER_SET, NO_VALUE]:
index e2484d17d45df9a325e380137ab67b795277a107..074c236f70933bbcd8e79b8024bdb7e8a8df2871 100644 (file)
@@ -1047,6 +1047,19 @@ class HistoryTest(_base.ORMTest):
         f.someattr = [there]
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], []))
 
+        # case 4.  ensure duplicates show up, order is maintained
+        f = Foo()
+        f.someattr.append(hi)
+        f.someattr.append(there)
+        f.someattr.append(hi)
+        eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], []))
+
+        attributes.instance_state(f).commit_all()
+        eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ()))
+        
+        f.someattr = []
+        eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [hi, there, hi]))
+        
     def test_collections_via_backref(self):
         class Foo(_base.BasicEntity):
             pass
index 7a60f01c6f3680df081df38b16167fc5e5fa1bf7..61409c7169a2871c5c948a74c2fb8ee6ddb82bdd 100644 (file)
@@ -63,7 +63,7 @@ class M2MTest(_base.MappedTest):
                                  repr(self.outputs)))
 
     @testing.resolve_artifact_names
-    def testerror(self):
+    def test_error(self):
         mapper(Place, place, properties={
             'transitions':relation(Transition, secondary=place_input, backref='places')
         })
@@ -74,8 +74,8 @@ class M2MTest(_base.MappedTest):
                                  sa.orm.compile_mappers)
 
     @testing.resolve_artifact_names
-    def testcircular(self):
-        """tests a many-to-many relationship from a table to itself."""
+    def test_circular(self):
+        """test a many-to-many relationship from a table to itself."""
 
         Place.mapper = mapper(Place, place)
 
@@ -124,8 +124,8 @@ class M2MTest(_base.MappedTest):
         sess.flush()
 
     @testing.resolve_artifact_names
-    def testdouble(self):
-        """tests that a mapper can have two eager relations to the same table, via
+    def test_double(self):
+        """test that a mapper can have two eager relations to the same table, via
         two different association tables.  aliases are required."""
 
         Place.mapper = mapper(Place, place, properties = {
@@ -155,7 +155,7 @@ class M2MTest(_base.MappedTest):
             })
 
     @testing.resolve_artifact_names
-    def testbidirectional(self):
+    def test_bidirectional(self):
         """tests a many-to-many backrefs"""
         Place.mapper = mapper(Place, place)
         Transition.mapper = mapper(Transition, transition, properties = dict(
@@ -200,18 +200,20 @@ class M2MTest2(_base.MappedTest):
             Column('course_id', String(20), ForeignKey('course.name'),
                    primary_key=True))
 
-    @testing.resolve_artifact_names
-    def testcircular(self):
-        class Student(object):
+    def setup_classes(self):
+        class Student(_base.BasicEntity):
             def __init__(self, name=''):
                 self.name = name
-        class Course(object):
+        class Course(_base.BasicEntity):
             def __init__(self, name=''):
                 self.name = name
 
+    @testing.resolve_artifact_names
+    def test_circular(self):
+
         mapper(Student, student)
         mapper(Course, course, properties={
-            'students': relation(Student, enroll, lazy=True, backref='courses')})
+            'students': relation(Student, enroll, backref='courses')})
 
         sess = create_session()
         s1 = Student('Student1')
@@ -232,15 +234,25 @@ class M2MTest2(_base.MappedTest):
         del s.courses[1]
         self.assert_(len(s.courses) == 2)
 
+    @testing.resolve_artifact_names
+    def test_dupliates_raise(self):
+        """test constraint error is raised for dupe entries in a list"""
+        
+        mapper(Student, student)
+        mapper(Course, course, properties={
+            'students': relation(Student, enroll, backref='courses')})
+
+        sess = create_session()
+        s1 = Student("s1")
+        c1 = Course('c1')
+        s1.courses.append(c1)
+        s1.courses.append(c1)
+        sess.add(s1)
+        self.assertRaises(sa.exc.DBAPIError, sess.flush)
+        
     @testing.resolve_artifact_names
     def test_delete(self):
         """A many-to-many table gets cleared out with deletion from the backref side"""
-        class Student(object):
-            def __init__(self, name=''):
-                self.name = name
-        class Course(object):
-            def __init__(self, name=''):
-                self.name = name
 
         mapper(Student, student)
         mapper(Course, course, properties = {
@@ -286,7 +298,7 @@ class M2MTest3(_base.MappedTest):
             Column('b2', sa.Boolean))
 
     @testing.resolve_artifact_names
-    def testbasic(self):
+    def test_basic(self):
         class C(object):pass
         class A(object):pass
         class B(object):pass