]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Coverage of list collections, and matching fixes in slice mutation
authorJason Kirtland <jek@discorporate.us>
Tue, 3 Jul 2007 01:34:53 +0000 (01:34 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 3 Jul 2007 01:34:53 +0000 (01:34 +0000)
lib/sqlalchemy/orm/collections.py
test/orm/collection.py

index 35442748a3511a7251a397d828df3dcd91cce947..0e5a787f117f968f4eb5137a6fd19b20417d4c68 100644 (file)
@@ -665,26 +665,59 @@ def _list_decorators():
         _tidy(remove)
         return remove
 
+    def insert(fn):
+        def insert(self, index, value):
+            __set(self, value)
+            fn(self, index, value)
+        _tidy(insert)
+        return insert
+
     def __setitem__(fn):
         def __setitem__(self, index, value):
             if not isinstance(index, slice):
+                existing = self[index]
+                if existing is not None:
+                    __del(self, existing)
                 __set(self, value)
                 fn(self, index, value)
             else:
-                rng = range(slice.start or 0, slice.stop or 0, slice.step or 1)
-                if len(value) != len(rng):
-                    raise ValueError
-                for i in rng:
-                    __set(self, value[i])
-                    fn(self, i, value[i])
+                # slice assignment requires __delitem__, insert, __len__
+                if index.stop is None:
+                    stop = 0
+                elif index.stop < 0:
+                    stop = len(self) + index.stop
+                else:
+                    stop = index.stop
+                step = index.step or 1
+                rng = range(index.start or 0, stop, step)
+                if step == 1:
+                    for i in rng:
+                        del self[index.start]
+                    i = index.start
+                    for item in value:
+                        self.insert(i, item)
+                        i += 1
+                else:
+                    if len(value) != len(rng):
+                        raise ValueError
+                    for i, item in zip(rng, value):
+                        self.__setitem__(i, item)
         _tidy(__setitem__)
         return __setitem__
 
     def __delitem__(fn):
         def __delitem__(self, index):
-            item = self[index]
-            __del(self, item)
-            fn(self, index)
+            if not isinstance(index, slice):
+                item = self[index]
+                __del(self, item)
+                fn(self, index)
+            else:
+                # slice deletion requires __getslice__ and a slice-groking
+                # __getitem__ for stepped deletion
+                # note: not breaking this into atomic dels
+                for item in self[index]:
+                    __del(self, item)
+                fn(self, index)
         _tidy(__delitem__)
         return __delitem__
 
index 783193c96166edfca32e844bd84d585dcb0be8dd..c55c542f75c2c55a5224a2416d0b1c82d19d5e07 100644 (file)
 import testbase
 from sqlalchemy import *
-from sqlalchemy.orm import create_session, mapper, relation
+from sqlalchemy.orm import create_session, mapper, relation, \
+    interfaces, attributes
 import sqlalchemy.orm.collections as collections
 from sqlalchemy.orm.collections import collection
 from sqlalchemy import util
+from operator import and_
 
 
+class Canary(interfaces.AttributeExtension):
+    def __init__(self):
+        self.data = set()
+        self.added = set()
+        self.removed = set()
+    def append(self, obj, value, initiator):
+        assert value not in self.added
+        self.data.add(value)
+        self.added.add(value)
+    def remove(self, obj, value, initiator):
+        assert value not in self.removed
+        self.data.remove(value)
+        self.removed.add(value)
+    def set(self, obj, value, oldvalue, initiator):
+        if oldvalue is not None:
+            self.remove(obj, oldvalue, None)
+        self.append(obj, value, None)
+
+class Entity(object):
+    def __init__(self, a=None, b=None, c=None):
+        self.a = a
+        self.b = b
+        self.c = c
+    def __repr__(self):
+        return str((id(self), self.a, self.b, self.c))
+
+manager = attributes.AttributeManager()
+
+_id = 1
+def entity_maker():
+    global _id
+    _id += 1
+    return Entity(_id)
+
 class CollectionsTest(testbase.PersistTest):
-    # FIXME: ...
-    pass
+    def _test_adapter(self, collection_class, creator=entity_maker,
+                      to_set=None):
+        class Foo(object):
+            pass
+
+        canary = Canary()
+        manager.register_attribute(Foo, 'attr', True, extension=canary,
+                                   typecallable=collection_class)
+
+        obj = Foo()
+        adapter = collections.collection_adapter(obj.attr)
+        direct = obj.attr
+        if to_set is None:
+            to_set = lambda col: set(col)
+
+        def assert_eq():
+            self.assert_(to_set(direct) == set(canary.data))
+            self.assert_(set(adapter) == set(canary.data))            
+        assert_ne = lambda: self.assert_(set(obj.attr) != set(canary.data))
+
+        e1, e2 = creator(), creator()
+
+        adapter.append_with_event(e1)
+        assert_eq()
+        
+        adapter.append_without_event(e2)
+        assert_ne()
+        canary.data.add(e2)
+        assert_eq()
+        
+        adapter.remove_without_event(e2)
+        assert_ne()
+        canary.data.remove(e2)
+        assert_eq()
+
+        adapter.remove_with_event(e1)
+        assert_eq()
+
+    def _test_list(self, collection_class, creator=entity_maker):
+        class Foo(object):
+            pass
+        
+        canary = Canary()
+        manager.register_attribute(Foo, 'attr', True, extension=canary,
+                                   collection_class=collection_class)
+
+        obj = Foo()
+        adapter = collections.collection_adapter(obj.attr)
+        direct = obj.attr
+        control = list()
+
+        def assert_eq():
+            self.assert_(set(direct) == set(canary.data))
+            self.assert_(set(adapter) == set(canary.data))
+            self.assert_(direct == control)
+        
+        # assume append() is available for list tests
+        e = creator()
+        direct.append(e)
+        control.append(e)
+        assert_eq()
+
+        if hasattr(direct, 'pop'):
+            direct.pop()
+            control.pop()
+            assert_eq()
+
+        if hasattr(direct, '__setitem__'):
+            e = creator()
+            direct.append(e)
+            control.append(e)
+            
+            e = creator()
+            direct[0] = e
+            control[0] = e
+            assert_eq()
+
+            if reduce(and_, [hasattr(direct, a) for a in
+                             ('__delitem', 'insert', '__len__')], True):
+                values = [creator(), creator(), creator(), creator()]
+                direct[slice(0,1)] = values
+                control[slice(0,1)] = values
+                assert_eq()
+
+                values = [creator(), creator()]
+                direct[slice(0,-1,2)] = values
+                control[slice(0,-1,2)] = values
+                assert_eq()
+
+                values = [creator()]
+                direct[slice(0,-1)] = values
+                control[slice(0,-1)] = values
+                assert_eq()
+
+        if hasattr(direct, '__delitem__'):
+            e = creator()
+            direct.append(e)
+            control.append(e)
+            del direct[-1]
+            del control[-1]
+            assert_eq()
 
-class DictsTest(testbase.ORMTest):
+            if hasattr(direct, '__getslice__'):
+                for e in [creator(), creator(), creator(), creator()]:
+                    direct.append(e)
+                    control.append(e)
+
+                del direct[:-3]
+                del control[:-3]
+                assert_eq()
+
+                del direct[0:1]
+                del control[0:1]
+                assert_eq()
+
+                del direct[::2]
+                del control[::2]
+                assert_eq()
+
+        if hasattr(direct, 'remove'):
+            e = creator()
+            direct.append(e)
+            control.append(e)
+            
+            direct.remove(e)
+            control.remove(e)
+            assert_eq()
+
+        if hasattr(direct, '__setslice__'):
+            values = [creator(), creator()]
+            direct[0:1] = values
+            control[0:1] = values
+            assert_eq()
+
+            values = [creator()]
+            direct[0:] = values
+            control[0:] = values
+            assert_eq()
+        
+        if hasattr(direct, '__delslice__'):
+            for i in range(1, 4):
+                e = creator()
+                direct.append(e)
+                control.append(e)
+
+            del direct[-1:]
+            del control[-1:] 
+            assert_eq()
+
+            del direct[1:2]
+            del control[1:2]
+            assert_eq()
+
+            del direct[:]
+            del control[:]
+            assert_eq()
+
+        if hasattr(direct, 'extend'):
+            values = [creator(), creator(), creator()]
+
+            direct.extend(values)
+            control.extend(values)
+            assert_eq()
+                    
+    def test_list(self):
+        self._test_adapter(list)
+        self._test_list(list)
+
+    def test_list_subclass(self):
+        class MyList(list):
+            pass
+        self._test_adapter(MyList)
+        self._test_list(MyList)
+        self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList))
+
+    def test_list_duck(self):
+        class ListLike(object):
+            def __init__(self):
+                self.data = list()
+            def append(self, item):
+                self.data.append(item)
+            def remove(self, item):
+                self.data.remove(item)
+            def insert(self, index, item):
+                self.data.insert(index, item)
+            def pop(self, index=-1):
+                self.data.pop(index)
+            def extend(self):
+                assert False
+            def __iter__(self):
+                return iter(self.data)
+            
+        self._test_adapter(ListLike)
+        self._test_list(ListLike)
+        self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike))
+
+    def test_list_emulates(self):
+        class ListIsh(object):
+            __emulates__ = list
+            def __init__(self):
+                self.data = list()
+            def append(self, item):
+                self.data.append(item)
+            def remove(self, item):
+                self.data.remove(item)
+            def insert(self, index, item):
+                self.data.insert(index, item)
+            def pop(self, index=-1):
+                self.data.pop(index)
+            def extend(self):
+                assert False
+            def __iter__(self):
+                return iter(self.data)
+            
+        self._test_adapter(ListIsh)
+        self._test_list(ListIsh)
+        self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh))
+
+    def test_set(self):
+        self._test_adapter(set)
+
+    def test_dict(self):
+        def dictable_entity(a=None, b=None, c=None):
+            global _id
+            _id += 1
+            return Entity(a or str(_id), b or 'value %s' % _id, c)
+        
+        self._test_adapter(collections.attribute_mapped_collection('a'),
+                           dictable_entity, to_set=lambda c: set(c.values()))
+
+class DictHelpersTest(testbase.ORMTest):
     def define_tables(self, metadata):
         global parents, children, Parent, Child
         
@@ -19,7 +282,8 @@ class DictsTest(testbase.ORMTest):
                         Column('label', String))
         children = Table('children', metadata,
                          Column('id', Integer, primary_key=True),
-                         Column('parent_id', Integer, ForeignKey('parents.id'), nullable=False),
+                         Column('parent_id', Integer, ForeignKey('parents.id'),
+                                nullable=False),
                          Column('a', String),
                          Column('b', String),
                          Column('c', String))
@@ -51,7 +315,7 @@ class DictsTest(testbase.ORMTest):
 
         p = session.query(Parent).get(pid)
 
-        assert set(p.children.keys()) == set(['foo', 'bar'])
+        self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
         cid = p.children['foo'].id
 
         collections.collection_adapter(p.children).append_with_event(
@@ -63,33 +327,33 @@ class DictsTest(testbase.ORMTest):
         
         p = session.query(Parent).get(pid)
         
-        assert set(p.children.keys()) == set(['foo', 'bar'])
-        assert p.children['foo'].id != cid
+        self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
+        self.assert_(p.children['foo'].id != cid)
         
-        assert(len(list(collections.collection_adapter(p.children))) == 2)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
         session.flush()
         session.clear()
 
         p = session.query(Parent).get(pid)
-        assert(len(list(collections.collection_adapter(p.children))) == 2)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
 
         collections.collection_adapter(p.children).remove_with_event(
             p.children['foo'])
         
-        assert(len(list(collections.collection_adapter(p.children))) == 1)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
         session.flush()
         session.clear()
 
         p = session.query(Parent).get(pid)
-        assert(len(list(collections.collection_adapter(p.children))) == 1)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
 
         del p.children['bar']
-        assert(len(list(collections.collection_adapter(p.children))) == 0)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
         session.flush()
         session.clear()
 
         p = session.query(Parent).get(pid)
-        assert(len(list(collections.collection_adapter(p.children))) == 0)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
         
 
     def _test_composite_mapped(self, collection_class):
@@ -111,7 +375,7 @@ class DictsTest(testbase.ORMTest):
         
         p = session.query(Parent).get(pid)
 
-        assert set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])
+        self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
         cid = p.children[('foo', '1')].id
 
         collections.collection_adapter(p.children).append_with_event(
@@ -123,28 +387,24 @@ class DictsTest(testbase.ORMTest):
         
         p = session.query(Parent).get(pid)
         
-        assert set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])
-        assert p.children[('foo', '1')].id != cid
+        self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
+        self.assert_(p.children[('foo', '1')].id != cid)
         
-        assert(len(list(collections.collection_adapter(p.children))) == 2)
+        self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
         
     def test_mapped_collection(self):
-        return
         collection_class = collections.mapped_collection(lambda c: c.a)
         self._test_scalar_mapped(collection_class)
 
     def test_mapped_collection2(self):
-        return
         collection_class = collections.mapped_collection(lambda c: (c.a, c.b))
         self._test_composite_mapped(collection_class)
 
     def test_attr_mapped_collection(self):
-        return
         collection_class = collections.attribute_mapped_collection('a')
         self._test_scalar_mapped(collection_class)
 
     def test_column_mapped_collection(self):
-        return
         collection_class = collections.column_mapped_collection(children.c.a)
         self._test_scalar_mapped(collection_class)
 
@@ -168,6 +428,6 @@ class DictsTest(testbase.ORMTest):
                 util.OrderedDict.__init__(self)
         collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
         self._test_composite_mapped(collection_class)
-    
+
 if __name__ == "__main__":
     testbase.main()