]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed in-place set mutation operator support [ticket:920]
authorJason Kirtland <jek@discorporate.us>
Fri, 4 Jan 2008 20:17:42 +0000 (20:17 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 4 Jan 2008 20:17:42 +0000 (20:17 +0000)
CHANGES
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/collections.py
test/ext/associationproxy.py
test/orm/collection.py

diff --git a/CHANGES b/CHANGES
index b839ec7df14e36b57773bf061d08e3e46f88e418..4deba3b494cf3123b8333ce43ff622a203448bd0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,22 +4,23 @@ CHANGES
 0.4.3
 -----
 - orm
-    - added very rudimentary yielding iterator behavior to Query.  Call
-      query.yield_per(<number of rows>) and evaluate the Query in an 
+    - Added very rudimentary yielding iterator behavior to Query.  Call
+      query.yield_per(<number of rows>) and evaluate the Query in an
       iterative context; every collection of N rows will be packaged up
-      and yielded.  Use this method with extreme caution since it does 
+      and yielded.  Use this method with extreme caution since it does
       not attempt to reconcile eagerly loaded collections across
       result batch boundaries, nor will it behave nicely if the same
-      instance occurs in more than one batch.  This means that an eagerly 
+      instance occurs in more than one batch.  This means that an eagerly
       loaded collection will get cleared out if it's referenced in more than
       one batch, and in all cases attributes will be overwritten on instances
       that occur in more than one batch.
 
-- dialects
+   - Fixed in-place set mutation operators for set collections and association
+     proxied sets. [ticket:920]
 
-    - PostgreSQL
-       - Fixed the missing call to subtype result processor for the PGArray
-         type. [ticket:913]
+- dialects
+    - Fixed the missing call to subtype result processor for the PGArray
+      type. [ticket:913]
 
 0.4.2
 -----
index 472bd1b2cc790fc6f878f2bab634eb1fede2587f..c5a2b4d073c010a5c9ae4d6d844e516c75a5eb0a 100644 (file)
@@ -176,8 +176,9 @@ class AssociationProxy(object):
                 self._scalar_set(target, values)
         else:
             proxy = self.__get__(obj, None)
-            proxy.clear()
-            self._set(proxy, values)
+            if proxy is not values:
+                proxy.clear()
+                self._set(proxy, values)
 
     def __delete__(self, obj):
         delattr(obj, self.key)
@@ -653,7 +654,12 @@ class _AssociationSet(object):
         for value in other:
             self.add(value)
 
-    __ior__ = update
+    def __ior__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        for value in other:
+            self.add(value)
+        return self
 
     def _set(self):
         return util.Set(iter(self))
@@ -672,7 +678,12 @@ class _AssociationSet(object):
         for value in other:
             self.discard(value)
 
-    __isub__ = difference_update
+    def __isub__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        for value in other:
+            self.discard(value)
+        return self
 
     def intersection(self, other):
         return util.Set(self).intersection(other)
@@ -689,7 +700,18 @@ class _AssociationSet(object):
         for value in add:
             self.add(value)
 
-    __iand__ = intersection_update
+    def __iand__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        want, have = self.intersection(other), util.Set(self)
+
+        remove, add = have - want, want - have
+
+        for value in remove:
+            self.remove(value)
+        for value in add:
+            self.add(value)
+        return self
 
     def symmetric_difference(self, other):
         return util.Set(self).symmetric_difference(other)
@@ -706,7 +728,18 @@ class _AssociationSet(object):
         for value in add:
             self.add(value)
 
-    __ixor__ = symmetric_difference_update
+    def __ixor__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        want, have = self.symmetric_difference(other), util.Set(self)
+
+        remove, add = have - want, want - have
+
+        for value in remove:
+            self.remove(value)
+        for value in add:
+            self.add(value)
+        return self
 
     def issubset(self, other):
         return util.Set(self).issubset(other)
index ddbf6f0051680a46a43c4788eac8607e026e0281..106601640029496459b7f9d2f85f11af5c81e1ef 100644 (file)
@@ -1138,7 +1138,17 @@ def _set_decorators():
                     self.add(item)
         _tidy(update)
         return update
-    __ior__ = update
+
+    def __ior__(fn):
+        def __ior__(self, value):
+            if sautil.duck_type_collection(value) is not set:
+                return NotImplemented
+            for item in value:
+                if item not in self:
+                    self.add(item)
+            return self
+        _tidy(__ior__)
+        return __ior__
 
     def difference_update(fn):
         def difference_update(self, value):
@@ -1146,7 +1156,16 @@ def _set_decorators():
                 self.discard(item)
         _tidy(difference_update)
         return difference_update
-    __isub__ = difference_update
+
+    def __isub__(fn):
+        def __isub__(self, value):
+            if sautil.duck_type_collection(value) is not set:
+                return NotImplemented
+            for item in value:
+                self.discard(item)
+            return self
+        _tidy(__isub__)
+        return __isub__
 
     def intersection_update(fn):
         def intersection_update(self, other):
@@ -1159,7 +1178,21 @@ def _set_decorators():
                 self.add(item)
         _tidy(intersection_update)
         return intersection_update
-    __iand__ = intersection_update
+
+    def __iand__(fn):
+        def __iand__(self, other):
+            if sautil.duck_type_collection(other) is not set:
+                return NotImplemented
+            want, have = self.intersection(other), sautil.Set(self)
+            remove, add = have - want, want - have
+
+            for item in remove:
+                self.remove(item)
+            for item in add:
+                self.add(item)
+            return self
+        _tidy(__iand__)
+        return __iand__
 
     def symmetric_difference_update(fn):
         def symmetric_difference_update(self, other):
@@ -1172,7 +1205,21 @@ def _set_decorators():
                 self.add(item)
         _tidy(symmetric_difference_update)
         return symmetric_difference_update
-    __ixor__ = symmetric_difference_update
+
+    def __ixor__(fn):
+        def __ixor__(self, other):
+            if sautil.duck_type_collection(other) is not set:
+                return NotImplemented
+            want, have = self.symmetric_difference(other), sautil.Set(self)
+            remove, add = have - want, want - have
+
+            for item in remove:
+                self.remove(item)
+            for item in add:
+                self.add(item)
+            return self
+        _tidy(__ixor__)
+        return __ixor__
 
     l = locals().copy()
     l.pop('_tidy')
index fe8b40255c8d59b69ca2ad3fbcb710e078626a34..b3ce69a97d33957db312dea0c7b80c04e3cd831a 100644 (file)
@@ -485,6 +485,38 @@ class SetTest(_CollectionOperations):
                         print 'got', repr(p.children)
                         raise
 
+        # in-place mutations
+        for op in ('|=', '-=', '&=', '^='):
+            for base in (['a', 'b', 'c'], []):
+                for other in (set(['a','b','c']), set(['a','b','c','d']),
+                              set(['a']), set(['a','b']),
+                              set(['c','d']), set(['e', 'f', 'g']),
+                              set()):
+                    p = Parent('p')
+                    p.children = base[:]
+                    control = set(base[:])
+
+                    exec "p.children %s other" % op
+                    exec "control %s other" % op
+
+                    try:
+                        self.assert_(p.children == control)
+                    except:
+                        print 'Test %s %s %s:' % (set(base), op, other)
+                        print 'want', repr(control)
+                        print 'got', repr(p.children)
+                        raise
+
+                    p = self.roundtrip(p)
+
+                    try:
+                        self.assert_(p.children == control)
+                    except:
+                        print 'Test %s %s %s:' % (base, op, other)
+                        print 'want', repr(control)
+                        print 'got', repr(p.children)
+                        raise
+
 
 class CustomSetTest(SetTest):
     def __init__(self, *args, **kw):
index 43b2f41e25314d79deef4faf79954bcdefe79912..6e50a85125fea11db59db98ecbd181f7d0406f5c 100644 (file)
@@ -74,12 +74,12 @@ class CollectionsTest(PersistTest):
 
         adapter.append_with_event(e1)
         assert_eq()
-        
+
         adapter.append_without_event(e2)
         assert_ne()
         canary.data.add(e2)
         assert_eq()
-        
+
         adapter.remove_without_event(e2)
         assert_ne()
         canary.data.remove(e2)
@@ -91,7 +91,7 @@ class CollectionsTest(PersistTest):
     def _test_list(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
-        
+
         canary = Canary()
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'attr', True, extension=canary,
@@ -106,7 +106,7 @@ class CollectionsTest(PersistTest):
             self.assert_(set(direct) == canary.data)
             self.assert_(set(adapter) == canary.data)
             self.assert_(direct == control)
-        
+
         # assume append() is available for list tests
         e = creator()
         direct.append(e)
@@ -122,7 +122,7 @@ class CollectionsTest(PersistTest):
             e = creator()
             direct.append(e)
             control.append(e)
-            
+
             e = creator()
             direct[0] = e
             control[0] = e
@@ -174,7 +174,7 @@ class CollectionsTest(PersistTest):
             e = creator()
             direct.append(e)
             control.append(e)
-            
+
             direct.remove(e)
             control.remove(e)
             assert_eq()
@@ -204,7 +204,7 @@ class CollectionsTest(PersistTest):
             direct[1::2] = values
             control[1::2] = values
             assert_eq()
-            
+
         if hasattr(direct, '__delslice__'):
             for i in range(1, 4):
                 e = creator()
@@ -212,7 +212,7 @@ class CollectionsTest(PersistTest):
                 control.append(e)
 
             del direct[-1:]
-            del control[-1:] 
+            del control[-1:]
             assert_eq()
 
             del direct[1:2]
@@ -321,7 +321,7 @@ class CollectionsTest(PersistTest):
                 return self.data == other
             def __repr__(self):
                 return 'ListLike(%s)' % repr(self.data)
-            
+
         self._test_adapter(ListLike)
         self._test_list(ListLike)
         self._test_list_bulk(ListLike)
@@ -348,7 +348,7 @@ class CollectionsTest(PersistTest):
                 return self.data == other
             def __repr__(self):
                 return 'ListIsh(%s)' % repr(self.data)
-            
+
         self._test_adapter(ListIsh)
         self._test_list(ListIsh)
         self._test_list_bulk(ListIsh)
@@ -382,7 +382,7 @@ class CollectionsTest(PersistTest):
             for item in list(direct):
                 direct.remove(item)
             control.clear()
-        
+
         # assume add() is available for list tests
         addall(creator())
 
@@ -420,17 +420,35 @@ class CollectionsTest(PersistTest):
             direct.discard(e)
             self.assert_(e not in canary.removed)
             assert_eq()
-            
+
         if hasattr(direct, 'update'):
+            zap()
             e = creator()
             addall(e)
-            
+
             values = set([e, creator(), creator()])
 
             direct.update(values)
             control.update(values)
             assert_eq()
 
+        if hasattr(direct, '__ior__'):
+            zap()
+            e = creator()
+            addall(e)
+
+            values = set([e, creator(), creator()])
+
+            direct |= values
+            control |= values
+            assert_eq()
+
+            try:
+                direct |= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
         if hasattr(direct, 'clear'):
             addall(creator(), creator())
             direct.clear()
@@ -439,6 +457,7 @@ class CollectionsTest(PersistTest):
 
         if hasattr(direct, 'difference_update'):
             zap()
+            e = creator()
             addall(creator(), creator())
             values = set([creator()])
 
@@ -450,6 +469,26 @@ class CollectionsTest(PersistTest):
             control.difference_update(values)
             assert_eq()
 
+        if hasattr(direct, '__isub__'):
+            zap()
+            e = creator()
+            addall(creator(), creator())
+            values = set([creator()])
+
+            direct -= values
+            control -= values
+            assert_eq()
+            values.update(set([e, creator()]))
+            direct -= values
+            control -= values
+            assert_eq()
+
+            try:
+                direct -= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
         if hasattr(direct, 'intersection_update'):
             zap()
             e = creator()
@@ -465,6 +504,27 @@ class CollectionsTest(PersistTest):
             control.intersection_update(values)
             assert_eq()
 
+        if hasattr(direct, '__iand__'):
+            zap()
+            e = creator()
+            addall(e, creator(), creator())
+            values = set(control)
+
+            direct &= values
+            control &= values
+            assert_eq()
+
+            values.update(set([e, creator()]))
+            direct &= values
+            control &= values
+            assert_eq()
+
+            try:
+                direct &= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
         if hasattr(direct, 'symmetric_difference_update'):
             zap()
             e = creator()
@@ -487,6 +547,34 @@ class CollectionsTest(PersistTest):
             control.symmetric_difference_update(values)
             assert_eq()
 
+        if hasattr(direct, '__ixor__'):
+            zap()
+            e = creator()
+            addall(e, creator(), creator())
+
+            values = set([e, creator()])
+            direct ^= values
+            control ^= values
+            assert_eq()
+
+            e = creator()
+            addall(e)
+            values = set([e])
+            direct ^= values
+            control ^= values
+            assert_eq()
+
+            values = set()
+            direct ^= values
+            control ^= values
+            assert_eq()
+
+            try:
+                direct ^= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
     def _test_set_bulk(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
@@ -513,7 +601,7 @@ class CollectionsTest(PersistTest):
         self.assert_(obj.attr == set([e2]))
         self.assert_(e1 in canary.removed)
         self.assert_(e2 in canary.added)
+
         e3 = creator()
         real_set = set([e3])
         obj.attr = real_set
@@ -521,7 +609,7 @@ class CollectionsTest(PersistTest):
         self.assert_(obj.attr == set([e3]))
         self.assert_(e2 in canary.removed)
         self.assert_(e3 in canary.added)
-       
+
         e4 = creator()
         try:
             obj.attr = [e4]
@@ -620,7 +708,7 @@ class CollectionsTest(PersistTest):
             for item in list(adapter):
                 direct.remove(item)
             control.clear()
-        
+
         # assume an 'set' method is available for tests
         addall(creator())
 
@@ -655,7 +743,7 @@ class CollectionsTest(PersistTest):
             direct.clear()
             control.clear()
             assert_eq()
-            
+
             direct.clear()
             control.clear()
             assert_eq()
@@ -678,7 +766,7 @@ class CollectionsTest(PersistTest):
             zap()
             e = creator()
             addall(e)
-            
+
             direct.popitem()
             control.popitem()
             assert_eq()
@@ -907,7 +995,7 @@ class CollectionsTest(PersistTest):
     def _test_object(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
-        
+
         canary = Canary()
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'attr', True, extension=canary,
@@ -933,7 +1021,7 @@ class CollectionsTest(PersistTest):
         direct.zark(e)
         control.remove(e)
         assert_eq()
-        
+
         e = creator()
         direct.maybe_zark(e)
         control.discard(e)
@@ -1035,7 +1123,7 @@ class CollectionsTest(PersistTest):
             @collection.removes_return()
             def pop(self, key):
                 return self.data.pop()
-            
+
             @collection.iterator
             def __iter__(self):
                 return iter(self.data)
@@ -1136,14 +1224,14 @@ class CollectionsTest(PersistTest):
         col1.append(e3)
         self.assert_(e3 not in canary.data)
         self.assert_(collections.collection_adapter(col1) is None)
-        
+
         obj.attr[0] = e3
         self.assert_(e3 in canary.data)
 
 class DictHelpersTest(ORMTest):
     def define_tables(self, metadata):
         global parents, children, Parent, Child
-        
+
         parents = Table('parents', metadata,
                         Column('id', Integer, primary_key=True),
                         Column('label', String))
@@ -1170,7 +1258,7 @@ class DictHelpersTest(ORMTest):
             'children': relation(Child, collection_class=collection_class,
                                  cascade="all, delete-orphan")
             })
-        
+
         p = Parent()
         p.children['foo'] = Child('foo', 'value')
         p.children['bar'] = Child('bar', 'value')
@@ -1187,15 +1275,15 @@ class DictHelpersTest(ORMTest):
 
         collections.collection_adapter(p.children).append_with_event(
             Child('foo', 'newvalue'))
-        
+
         session.flush()
         session.clear()
-        
+
         p = session.query(Parent).get(pid)
-        
+
         self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
         self.assert_(p.children['foo'].id != cid)
-        
+
         self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
         session.flush()
         session.clear()
@@ -1205,7 +1293,7 @@ class DictHelpersTest(ORMTest):
 
         collections.collection_adapter(p.children).remove_with_event(
             p.children['foo'])
-        
+
         self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
         session.flush()
         session.clear()
@@ -1220,7 +1308,7 @@ class DictHelpersTest(ORMTest):
 
         p = session.query(Parent).get(pid)
         self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
-        
+
 
     def _test_composite_mapped(self, collection_class):
         mapper(Child, children)
@@ -1228,7 +1316,7 @@ class DictHelpersTest(ORMTest):
             'children': relation(Child, collection_class=collection_class,
                                  cascade="all, delete-orphan")
             })
-        
+
         p = Parent()
         p.children[('foo', '1')] = Child('foo', '1', 'value 1')
         p.children[('foo', '2')] = Child('foo', '2', 'value 2')
@@ -1238,7 +1326,7 @@ class DictHelpersTest(ORMTest):
         session.flush()
         pid = p.id
         session.clear()
-        
+
         p = session.query(Parent).get(pid)
 
         self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
@@ -1246,17 +1334,17 @@ class DictHelpersTest(ORMTest):
 
         collections.collection_adapter(p.children).append_with_event(
             Child('foo', '1', 'newvalue'))
-        
+
         session.flush()
         session.clear()
-        
+
         p = session.query(Parent).get(pid)
-        
+
         self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
         self.assert_(p.children[('foo', '1')].id != cid)
-        
+
         self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
-        
+
     def test_mapped_collection(self):
         collection_class = collections.mapped_collection(lambda c: c.a)
         self._test_scalar_mapped(collection_class)