]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Add coverage for set collections, added missing clear() decorator
authorJason Kirtland <jek@discorporate.us>
Tue, 3 Jul 2007 02:41:12 +0000 (02:41 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 3 Jul 2007 02:41:12 +0000 (02:41 +0000)
- Try not to be such an idiot when testing lists

lib/sqlalchemy/orm/collections.py
test/orm/collection.py

index 0e5a787f117f968f4eb5137a6fd19b20417d4c68..b9022ef369d341273c65c85bff07d3f78bdbb1c6 100644 (file)
@@ -894,6 +894,13 @@ def _set_decorators():
         _tidy(pop)
         return pop
 
+    def clear(fn):
+        def clear(self):
+            for item in list(self):
+                self.remove(item)
+        _tidy(clear)
+        return clear
+
     def update(fn):
         def update(self, value):
             for item in value:
index c55c542f75c2c55a5224a2416d0b1c82d19d5e07..be81e7a4d689375c62efc39289c73a634f181e5d 100644 (file)
@@ -43,14 +43,14 @@ def entity_maker():
     return Entity(_id)
 
 class CollectionsTest(testbase.PersistTest):
-    def _test_adapter(self, collection_class, creator=entity_maker,
+    def _test_adapter(self, typecallable, creator=entity_maker,
                       to_set=None):
         class Foo(object):
             pass
 
         canary = Canary()
         manager.register_attribute(Foo, 'attr', True, extension=canary,
-                                   typecallable=collection_class)
+                                   typecallable=typecallable)
 
         obj = Foo()
         adapter = collections.collection_adapter(obj.attr)
@@ -59,9 +59,9 @@ class CollectionsTest(testbase.PersistTest):
             to_set = lambda col: set(col)
 
         def assert_eq():
-            self.assert_(to_set(direct) == set(canary.data))
-            self.assert_(set(adapter) == set(canary.data))            
-        assert_ne = lambda: self.assert_(set(obj.attr) != set(canary.data))
+            self.assert_(to_set(direct) == canary.data)
+            self.assert_(set(adapter) == canary.data)
+        assert_ne = lambda: self.assert_(set(obj.attr) != canary.data)
 
         e1, e2 = creator(), creator()
 
@@ -81,13 +81,13 @@ class CollectionsTest(testbase.PersistTest):
         adapter.remove_with_event(e1)
         assert_eq()
 
-    def _test_list(self, collection_class, creator=entity_maker):
+    def _test_list(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
         
         canary = Canary()
         manager.register_attribute(Foo, 'attr', True, extension=canary,
-                                   collection_class=collection_class)
+                                   typecallable=typecallable)
 
         obj = Foo()
         adapter = collections.collection_adapter(obj.attr)
@@ -95,8 +95,8 @@ class CollectionsTest(testbase.PersistTest):
         control = list()
 
         def assert_eq():
-            self.assert_(set(direct) == set(canary.data))
-            self.assert_(set(adapter) == set(canary.data))
+            self.assert_(set(direct) == canary.data)
+            self.assert_(set(adapter) == canary.data)
             self.assert_(direct == control)
         
         # assume append() is available for list tests
@@ -229,11 +229,15 @@ class CollectionsTest(testbase.PersistTest):
             def insert(self, index, item):
                 self.data.insert(index, item)
             def pop(self, index=-1):
-                self.data.pop(index)
+                return self.data.pop(index)
             def extend(self):
                 assert False
             def __iter__(self):
                 return iter(self.data)
+            def __eq__(self, other):
+                return self.data == other
+            def __repr__(self):
+                return 'ListLike(%s)' % repr(self.data)
             
         self._test_adapter(ListLike)
         self._test_list(ListLike)
@@ -251,19 +255,210 @@ class CollectionsTest(testbase.PersistTest):
             def insert(self, index, item):
                 self.data.insert(index, item)
             def pop(self, index=-1):
-                self.data.pop(index)
+                return self.data.pop(index)
             def extend(self):
                 assert False
             def __iter__(self):
                 return iter(self.data)
+            def __eq__(self, other):
+                return self.data == other
+            def __repr__(self):
+                return 'ListIsh(%s)' % repr(self.data)
             
         self._test_adapter(ListIsh)
         self._test_list(ListIsh)
         self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh))
 
+    def _test_set(self, typecallable, creator=entity_maker):
+        class Foo(object):
+            pass
+        
+        canary = Canary()
+        manager.register_attribute(Foo, 'attr', True, extension=canary,
+                                   typecallable=typecallable)
+
+        obj = Foo()
+        adapter = collections.collection_adapter(obj.attr)
+        direct = obj.attr
+        control = set()
+
+        def assert_eq():
+            self.assert_(set(direct) == canary.data)
+            self.assert_(set(adapter) == canary.data)
+            self.assert_(direct == control)
+
+        def addall(*values):
+            for item in values:
+                direct.add(item)
+                control.add(item)
+            assert_eq()
+        def zap():
+            for item in list(direct):
+                direct.remove(item)
+            control.clear()
+        
+        # assume add() is available for list tests
+        addall(creator())
+
+        if hasattr(direct, 'pop'):
+            direct.pop()
+            control.pop()
+            assert_eq()
+
+        if hasattr(direct, 'remove'):
+            e = creator()
+            addall(e)
+
+            direct.remove(e)
+            control.remove(e)
+            assert_eq()
+
+            e = creator()
+            try:
+                direct.remove(e)
+            except KeyError:
+                assert_eq()
+                self.assert_(e not in canary.removed)
+            else:
+                self.assert_(False)
+
+        if hasattr(direct, 'discard'):
+            e = creator()
+            addall(e)
+
+            direct.discard(e)
+            control.discard(e)
+            assert_eq()
+
+            e = creator()
+            direct.discard(e)
+            self.assert_(e not in canary.removed)
+            assert_eq()
+            
+        if hasattr(direct, 'update'):
+            e = creator()
+            addall(e)
+            
+            values = set([e, creator(), creator()])
+
+            direct.update(values)
+            control.update(values)
+            assert_eq()
+
+        if hasattr(direct, 'clear'):
+            addall(creator(), creator())
+            direct.clear()
+            control.clear()
+            assert_eq()
+
+        if hasattr(direct, 'difference_update'):
+            zap()
+            addall(creator(), creator())
+            values = set([creator()])
+
+            direct.difference_update(values)
+            control.difference_update(values)
+            assert_eq()
+            values.update(set([e, creator()]))
+            direct.difference_update(values)
+            control.difference_update(values)
+            assert_eq()
+
+        if hasattr(direct, 'intersection_update'):
+            zap()
+            e = creator()
+            addall(e, creator(), creator())
+            values = set(control)
+
+            direct.intersection_update(values)
+            control.intersection_update(values)
+            assert_eq()
+
+            values.update(set([e, creator()]))
+            direct.intersection_update(values)
+            control.intersection_update(values)
+            assert_eq()
+
+        if hasattr(direct, 'symmetric_difference_update'):
+            zap()
+            e = creator()
+            addall(e, creator(), creator())
+
+            values = set([e, creator()])
+            direct.symmetric_difference_update(values)
+            control.symmetric_difference_update(values)
+            assert_eq()
+
+            e = creator()
+            addall(e)
+            values = set([e])
+            direct.symmetric_difference_update(values)
+            control.symmetric_difference_update(values)
+            assert_eq()
+
+            values = set()
+            direct.symmetric_difference_update(values)
+            control.symmetric_difference_update(values)
+            assert_eq()
+
     def test_set(self):
         self._test_adapter(set)
+        self._test_set(set)
+
+    def test_set_subclass(self):
+        class MySet(set):
+            pass
+        self._test_adapter(MySet)
+        self._test_set(MySet) 
+        self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet))
+
+    def test_set_duck(self):
+        class SetLike(object):
+            def __init__(self):
+                self.data = set()
+            def add(self, item):
+                self.data.add(item)
+            def remove(self, item):
+                self.data.remove(item)
+            def discard(self, item):
+                self.data.discard(item)
+            def pop(self):
+                return self.data.pop()
+            def update(self, other):
+                self.data.update(other)
+            def __iter__(self):
+                return iter(self.data)
+            def __eq__(self, other):
+                return self.data == other
+
+        self._test_adapter(SetLike)
+        self._test_set(SetLike)
+        self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike))
+
+    def test_set_emulates(self):
+        class SetIsh(object):
+            __emulates__ = set
+            def __init__(self):
+                self.data = set()
+            def add(self, item):
+                self.data.add(item)
+            def remove(self, item):
+                self.data.remove(item)
+            def discard(self, item):
+                self.data.discard(item)
+            def pop(self):
+                return self.data.pop()
+            def update(self, other):
+                self.data.update(other)
+            def __iter__(self):
+                return iter(self.data)
+            def __eq__(self, other):
+                return self.data == other
 
+        self._test_adapter(SetIsh)
+        self._test_set(SetIsh)
+        self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
+        
     def test_dict(self):
         def dictable_entity(a=None, b=None, c=None):
             global _id