From: Jason Kirtland Date: Tue, 3 Jul 2007 02:41:12 +0000 (+0000) Subject: - Add coverage for set collections, added missing clear() decorator X-Git-Tag: rel_0_4_6~134 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b716af2c4247e0ce3aa0a35fb7c519b87832d585;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Add coverage for set collections, added missing clear() decorator - Try not to be such an idiot when testing lists --- diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 0e5a787f11..b9022ef369 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -894,6 +894,13 @@ def _set_decorators(): _tidy(pop) return pop + def clear(fn): + def clear(self): + for item in list(self): + self.remove(item) + _tidy(clear) + return clear + def update(fn): def update(self, value): for item in value: diff --git a/test/orm/collection.py b/test/orm/collection.py index c55c542f75..be81e7a4d6 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -43,14 +43,14 @@ def entity_maker(): return Entity(_id) class CollectionsTest(testbase.PersistTest): - def _test_adapter(self, collection_class, creator=entity_maker, + def _test_adapter(self, typecallable, creator=entity_maker, to_set=None): class Foo(object): pass canary = Canary() manager.register_attribute(Foo, 'attr', True, extension=canary, - typecallable=collection_class) + typecallable=typecallable) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -59,9 +59,9 @@ class CollectionsTest(testbase.PersistTest): to_set = lambda col: set(col) def assert_eq(): - self.assert_(to_set(direct) == set(canary.data)) - self.assert_(set(adapter) == set(canary.data)) - assert_ne = lambda: self.assert_(set(obj.attr) != set(canary.data)) + self.assert_(to_set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + assert_ne = lambda: self.assert_(set(obj.attr) != canary.data) e1, e2 = creator(), creator() @@ -81,13 +81,13 @@ class CollectionsTest(testbase.PersistTest): adapter.remove_with_event(e1) assert_eq() - def _test_list(self, collection_class, creator=entity_maker): + def _test_list(self, typecallable, creator=entity_maker): class Foo(object): pass canary = Canary() manager.register_attribute(Foo, 'attr', True, extension=canary, - collection_class=collection_class) + typecallable=typecallable) obj = Foo() adapter = collections.collection_adapter(obj.attr) @@ -95,8 +95,8 @@ class CollectionsTest(testbase.PersistTest): control = list() def assert_eq(): - self.assert_(set(direct) == set(canary.data)) - self.assert_(set(adapter) == set(canary.data)) + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) self.assert_(direct == control) # assume append() is available for list tests @@ -229,11 +229,15 @@ class CollectionsTest(testbase.PersistTest): def insert(self, index, item): self.data.insert(index, item) def pop(self, index=-1): - self.data.pop(index) + return self.data.pop(index) def extend(self): assert False def __iter__(self): return iter(self.data) + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListLike(%s)' % repr(self.data) self._test_adapter(ListLike) self._test_list(ListLike) @@ -251,19 +255,210 @@ class CollectionsTest(testbase.PersistTest): def insert(self, index, item): self.data.insert(index, item) def pop(self, index=-1): - self.data.pop(index) + return self.data.pop(index) def extend(self): assert False def __iter__(self): return iter(self.data) + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListIsh(%s)' % repr(self.data) self._test_adapter(ListIsh) self._test_list(ListIsh) self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh)) + def _test_set(self, typecallable, creator=entity_maker): + class Foo(object): + pass + + canary = Canary() + manager.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=typecallable) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = set() + + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(direct == control) + + def addall(*values): + for item in values: + direct.add(item) + control.add(item) + assert_eq() + def zap(): + for item in list(direct): + direct.remove(item) + control.clear() + + # assume add() is available for list tests + addall(creator()) + + if hasattr(direct, 'pop'): + direct.pop() + control.pop() + assert_eq() + + if hasattr(direct, 'remove'): + e = creator() + addall(e) + + direct.remove(e) + control.remove(e) + assert_eq() + + e = creator() + try: + direct.remove(e) + except KeyError: + assert_eq() + self.assert_(e not in canary.removed) + else: + self.assert_(False) + + if hasattr(direct, 'discard'): + e = creator() + addall(e) + + direct.discard(e) + control.discard(e) + assert_eq() + + e = creator() + direct.discard(e) + self.assert_(e not in canary.removed) + assert_eq() + + if hasattr(direct, 'update'): + e = creator() + addall(e) + + values = set([e, creator(), creator()]) + + direct.update(values) + control.update(values) + assert_eq() + + if hasattr(direct, 'clear'): + addall(creator(), creator()) + direct.clear() + control.clear() + assert_eq() + + if hasattr(direct, 'difference_update'): + zap() + addall(creator(), creator()) + values = set([creator()]) + + direct.difference_update(values) + control.difference_update(values) + assert_eq() + values.update(set([e, creator()])) + direct.difference_update(values) + control.difference_update(values) + assert_eq() + + if hasattr(direct, 'intersection_update'): + zap() + e = creator() + addall(e, creator(), creator()) + values = set(control) + + direct.intersection_update(values) + control.intersection_update(values) + assert_eq() + + values.update(set([e, creator()])) + direct.intersection_update(values) + control.intersection_update(values) + assert_eq() + + if hasattr(direct, 'symmetric_difference_update'): + zap() + e = creator() + addall(e, creator(), creator()) + + values = set([e, creator()]) + direct.symmetric_difference_update(values) + control.symmetric_difference_update(values) + assert_eq() + + e = creator() + addall(e) + values = set([e]) + direct.symmetric_difference_update(values) + control.symmetric_difference_update(values) + assert_eq() + + values = set() + direct.symmetric_difference_update(values) + control.symmetric_difference_update(values) + assert_eq() + def test_set(self): self._test_adapter(set) + self._test_set(set) + + def test_set_subclass(self): + class MySet(set): + pass + self._test_adapter(MySet) + self._test_set(MySet) + self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet)) + + def test_set_duck(self): + class SetLike(object): + def __init__(self): + self.data = set() + def add(self, item): + self.data.add(item) + def remove(self, item): + self.data.remove(item) + def discard(self, item): + self.data.discard(item) + def pop(self): + return self.data.pop() + def update(self, other): + self.data.update(other) + def __iter__(self): + return iter(self.data) + def __eq__(self, other): + return self.data == other + + self._test_adapter(SetLike) + self._test_set(SetLike) + self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike)) + + def test_set_emulates(self): + class SetIsh(object): + __emulates__ = set + def __init__(self): + self.data = set() + def add(self, item): + self.data.add(item) + def remove(self, item): + self.data.remove(item) + def discard(self, item): + self.data.discard(item) + def pop(self): + return self.data.pop() + def update(self, other): + self.data.update(other) + def __iter__(self): + return iter(self.data) + def __eq__(self, other): + return self.data == other + self._test_adapter(SetIsh) + self._test_set(SetIsh) + self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh)) + def test_dict(self): def dictable_entity(a=None, b=None, c=None): global _id