From: Jason Kirtland Date: Tue, 3 Jul 2007 01:34:53 +0000 (+0000) Subject: - Coverage of list collections, and matching fixes in slice mutation X-Git-Tag: rel_0_4_6~135 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1cfd46d3b7baa21611256177119264cb1466c848;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Coverage of list collections, and matching fixes in slice mutation --- diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 35442748a3..0e5a787f11 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -665,26 +665,59 @@ def _list_decorators(): _tidy(remove) return remove + def insert(fn): + def insert(self, index, value): + __set(self, value) + fn(self, index, value) + _tidy(insert) + return insert + def __setitem__(fn): def __setitem__(self, index, value): if not isinstance(index, slice): + existing = self[index] + if existing is not None: + __del(self, existing) __set(self, value) fn(self, index, value) else: - rng = range(slice.start or 0, slice.stop or 0, slice.step or 1) - if len(value) != len(rng): - raise ValueError - for i in rng: - __set(self, value[i]) - fn(self, i, value[i]) + # slice assignment requires __delitem__, insert, __len__ + if index.stop is None: + stop = 0 + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + rng = range(index.start or 0, stop, step) + if step == 1: + for i in rng: + del self[index.start] + i = index.start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError + for i, item in zip(rng, value): + self.__setitem__(i, item) _tidy(__setitem__) return __setitem__ def __delitem__(fn): def __delitem__(self, index): - item = self[index] - __del(self, item) - fn(self, index) + if not isinstance(index, slice): + item = self[index] + __del(self, item) + fn(self, index) + else: + # slice deletion requires __getslice__ and a slice-groking + # __getitem__ for stepped deletion + # note: not breaking this into atomic dels + for item in self[index]: + __del(self, item) + fn(self, index) _tidy(__delitem__) return __delitem__ diff --git a/test/orm/collection.py b/test/orm/collection.py index 783193c961..c55c542f75 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -1,16 +1,279 @@ import testbase from sqlalchemy import * -from sqlalchemy.orm import create_session, mapper, relation +from sqlalchemy.orm import create_session, mapper, relation, \ + interfaces, attributes import sqlalchemy.orm.collections as collections from sqlalchemy.orm.collections import collection from sqlalchemy import util +from operator import and_ +class Canary(interfaces.AttributeExtension): + def __init__(self): + self.data = set() + self.added = set() + self.removed = set() + def append(self, obj, value, initiator): + assert value not in self.added + self.data.add(value) + self.added.add(value) + def remove(self, obj, value, initiator): + assert value not in self.removed + self.data.remove(value) + self.removed.add(value) + def set(self, obj, value, oldvalue, initiator): + if oldvalue is not None: + self.remove(obj, oldvalue, None) + self.append(obj, value, None) + +class Entity(object): + def __init__(self, a=None, b=None, c=None): + self.a = a + self.b = b + self.c = c + def __repr__(self): + return str((id(self), self.a, self.b, self.c)) + +manager = attributes.AttributeManager() + +_id = 1 +def entity_maker(): + global _id + _id += 1 + return Entity(_id) + class CollectionsTest(testbase.PersistTest): - # FIXME: ... - pass + def _test_adapter(self, collection_class, creator=entity_maker, + to_set=None): + class Foo(object): + pass + + canary = Canary() + manager.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=collection_class) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + if to_set is None: + to_set = lambda col: set(col) + + def assert_eq(): + self.assert_(to_set(direct) == set(canary.data)) + self.assert_(set(adapter) == set(canary.data)) + assert_ne = lambda: self.assert_(set(obj.attr) != set(canary.data)) + + e1, e2 = creator(), creator() + + adapter.append_with_event(e1) + assert_eq() + + adapter.append_without_event(e2) + assert_ne() + canary.data.add(e2) + assert_eq() + + adapter.remove_without_event(e2) + assert_ne() + canary.data.remove(e2) + assert_eq() + + adapter.remove_with_event(e1) + assert_eq() + + def _test_list(self, collection_class, creator=entity_maker): + class Foo(object): + pass + + canary = Canary() + manager.register_attribute(Foo, 'attr', True, extension=canary, + collection_class=collection_class) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = list() + + def assert_eq(): + self.assert_(set(direct) == set(canary.data)) + self.assert_(set(adapter) == set(canary.data)) + self.assert_(direct == control) + + # assume append() is available for list tests + e = creator() + direct.append(e) + control.append(e) + assert_eq() + + if hasattr(direct, 'pop'): + direct.pop() + control.pop() + assert_eq() + + if hasattr(direct, '__setitem__'): + e = creator() + direct.append(e) + control.append(e) + + e = creator() + direct[0] = e + control[0] = e + assert_eq() + + if reduce(and_, [hasattr(direct, a) for a in + ('__delitem', 'insert', '__len__')], True): + values = [creator(), creator(), creator(), creator()] + direct[slice(0,1)] = values + control[slice(0,1)] = values + assert_eq() + + values = [creator(), creator()] + direct[slice(0,-1,2)] = values + control[slice(0,-1,2)] = values + assert_eq() + + values = [creator()] + direct[slice(0,-1)] = values + control[slice(0,-1)] = values + assert_eq() + + if hasattr(direct, '__delitem__'): + e = creator() + direct.append(e) + control.append(e) + del direct[-1] + del control[-1] + assert_eq() -class DictsTest(testbase.ORMTest): + if hasattr(direct, '__getslice__'): + for e in [creator(), creator(), creator(), creator()]: + direct.append(e) + control.append(e) + + del direct[:-3] + del control[:-3] + assert_eq() + + del direct[0:1] + del control[0:1] + assert_eq() + + del direct[::2] + del control[::2] + assert_eq() + + if hasattr(direct, 'remove'): + e = creator() + direct.append(e) + control.append(e) + + direct.remove(e) + control.remove(e) + assert_eq() + + if hasattr(direct, '__setslice__'): + values = [creator(), creator()] + direct[0:1] = values + control[0:1] = values + assert_eq() + + values = [creator()] + direct[0:] = values + control[0:] = values + assert_eq() + + if hasattr(direct, '__delslice__'): + for i in range(1, 4): + e = creator() + direct.append(e) + control.append(e) + + del direct[-1:] + del control[-1:] + assert_eq() + + del direct[1:2] + del control[1:2] + assert_eq() + + del direct[:] + del control[:] + assert_eq() + + if hasattr(direct, 'extend'): + values = [creator(), creator(), creator()] + + direct.extend(values) + control.extend(values) + assert_eq() + + def test_list(self): + self._test_adapter(list) + self._test_list(list) + + def test_list_subclass(self): + class MyList(list): + pass + self._test_adapter(MyList) + self._test_list(MyList) + self.assert_(getattr(MyList, '_sa_instrumented') == id(MyList)) + + def test_list_duck(self): + class ListLike(object): + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + self.data.pop(index) + def extend(self): + assert False + def __iter__(self): + return iter(self.data) + + self._test_adapter(ListLike) + self._test_list(ListLike) + self.assert_(getattr(ListLike, '_sa_instrumented') == id(ListLike)) + + def test_list_emulates(self): + class ListIsh(object): + __emulates__ = list + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + self.data.pop(index) + def extend(self): + assert False + def __iter__(self): + return iter(self.data) + + self._test_adapter(ListIsh) + self._test_list(ListIsh) + self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh)) + + def test_set(self): + self._test_adapter(set) + + def test_dict(self): + def dictable_entity(a=None, b=None, c=None): + global _id + _id += 1 + return Entity(a or str(_id), b or 'value %s' % _id, c) + + self._test_adapter(collections.attribute_mapped_collection('a'), + dictable_entity, to_set=lambda c: set(c.values())) + +class DictHelpersTest(testbase.ORMTest): def define_tables(self, metadata): global parents, children, Parent, Child @@ -19,7 +282,8 @@ class DictsTest(testbase.ORMTest): Column('label', String)) children = Table('children', metadata, Column('id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('parents.id'), nullable=False), + Column('parent_id', Integer, ForeignKey('parents.id'), + nullable=False), Column('a', String), Column('b', String), Column('c', String)) @@ -51,7 +315,7 @@ class DictsTest(testbase.ORMTest): p = session.query(Parent).get(pid) - assert set(p.children.keys()) == set(['foo', 'bar']) + self.assert_(set(p.children.keys()) == set(['foo', 'bar'])) cid = p.children['foo'].id collections.collection_adapter(p.children).append_with_event( @@ -63,33 +327,33 @@ class DictsTest(testbase.ORMTest): p = session.query(Parent).get(pid) - assert set(p.children.keys()) == set(['foo', 'bar']) - assert p.children['foo'].id != cid + self.assert_(set(p.children.keys()) == set(['foo', 'bar'])) + self.assert_(p.children['foo'].id != cid) - assert(len(list(collections.collection_adapter(p.children))) == 2) + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) session.flush() session.clear() p = session.query(Parent).get(pid) - assert(len(list(collections.collection_adapter(p.children))) == 2) + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) collections.collection_adapter(p.children).remove_with_event( p.children['foo']) - assert(len(list(collections.collection_adapter(p.children))) == 1) + self.assert_(len(list(collections.collection_adapter(p.children))) == 1) session.flush() session.clear() p = session.query(Parent).get(pid) - assert(len(list(collections.collection_adapter(p.children))) == 1) + self.assert_(len(list(collections.collection_adapter(p.children))) == 1) del p.children['bar'] - assert(len(list(collections.collection_adapter(p.children))) == 0) + self.assert_(len(list(collections.collection_adapter(p.children))) == 0) session.flush() session.clear() p = session.query(Parent).get(pid) - assert(len(list(collections.collection_adapter(p.children))) == 0) + self.assert_(len(list(collections.collection_adapter(p.children))) == 0) def _test_composite_mapped(self, collection_class): @@ -111,7 +375,7 @@ class DictsTest(testbase.ORMTest): p = session.query(Parent).get(pid) - assert set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]) + self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) cid = p.children[('foo', '1')].id collections.collection_adapter(p.children).append_with_event( @@ -123,28 +387,24 @@ class DictsTest(testbase.ORMTest): p = session.query(Parent).get(pid) - assert set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]) - assert p.children[('foo', '1')].id != cid + self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')])) + self.assert_(p.children[('foo', '1')].id != cid) - assert(len(list(collections.collection_adapter(p.children))) == 2) + self.assert_(len(list(collections.collection_adapter(p.children))) == 2) def test_mapped_collection(self): - return collection_class = collections.mapped_collection(lambda c: c.a) self._test_scalar_mapped(collection_class) def test_mapped_collection2(self): - return collection_class = collections.mapped_collection(lambda c: (c.a, c.b)) self._test_composite_mapped(collection_class) def test_attr_mapped_collection(self): - return collection_class = collections.attribute_mapped_collection('a') self._test_scalar_mapped(collection_class) def test_column_mapped_collection(self): - return collection_class = collections.column_mapped_collection(children.c.a) self._test_scalar_mapped(collection_class) @@ -168,6 +428,6 @@ class DictsTest(testbase.ORMTest): util.OrderedDict.__init__(self) collection_class = lambda: Ordered2(lambda v: (v.a, v.b)) self._test_composite_mapped(collection_class) - + if __name__ == "__main__": testbase.main()