return Entity(_id)
class CollectionsTest(testbase.PersistTest):
- def _test_adapter(self, collection_class, creator=entity_maker,
+ def _test_adapter(self, typecallable, creator=entity_maker,
to_set=None):
class Foo(object):
pass
canary = Canary()
manager.register_attribute(Foo, 'attr', True, extension=canary,
- typecallable=collection_class)
+ typecallable=typecallable)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
to_set = lambda col: set(col)
def assert_eq():
- self.assert_(to_set(direct) == set(canary.data))
- self.assert_(set(adapter) == set(canary.data))
- assert_ne = lambda: self.assert_(set(obj.attr) != set(canary.data))
+ self.assert_(to_set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ assert_ne = lambda: self.assert_(set(obj.attr) != canary.data)
e1, e2 = creator(), creator()
adapter.remove_with_event(e1)
assert_eq()
- def _test_list(self, collection_class, creator=entity_maker):
+ def _test_list(self, typecallable, creator=entity_maker):
class Foo(object):
pass
canary = Canary()
manager.register_attribute(Foo, 'attr', True, extension=canary,
- collection_class=collection_class)
+ typecallable=typecallable)
obj = Foo()
adapter = collections.collection_adapter(obj.attr)
control = list()
def assert_eq():
- self.assert_(set(direct) == set(canary.data))
- self.assert_(set(adapter) == set(canary.data))
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
self.assert_(direct == control)
# assume append() is available for list tests
def insert(self, index, item):
self.data.insert(index, item)
def pop(self, index=-1):
- self.data.pop(index)
+ return self.data.pop(index)
def extend(self):
assert False
def __iter__(self):
return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'ListLike(%s)' % repr(self.data)
self._test_adapter(ListLike)
self._test_list(ListLike)
def insert(self, index, item):
self.data.insert(index, item)
def pop(self, index=-1):
- self.data.pop(index)
+ return self.data.pop(index)
def extend(self):
assert False
def __iter__(self):
return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'ListIsh(%s)' % repr(self.data)
self._test_adapter(ListIsh)
self._test_list(ListIsh)
self.assert_(getattr(ListIsh, '_sa_instrumented') == id(ListIsh))
+ def _test_set(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = set()
+
+ def assert_eq():
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ def addall(*values):
+ for item in values:
+ direct.add(item)
+ control.add(item)
+ assert_eq()
+ def zap():
+ for item in list(direct):
+ direct.remove(item)
+ control.clear()
+
+ # assume add() is available for list tests
+ addall(creator())
+
+ if hasattr(direct, 'pop'):
+ direct.pop()
+ control.pop()
+ assert_eq()
+
+ if hasattr(direct, 'remove'):
+ e = creator()
+ addall(e)
+
+ direct.remove(e)
+ control.remove(e)
+ assert_eq()
+
+ e = creator()
+ try:
+ direct.remove(e)
+ except KeyError:
+ assert_eq()
+ self.assert_(e not in canary.removed)
+ else:
+ self.assert_(False)
+
+ if hasattr(direct, 'discard'):
+ e = creator()
+ addall(e)
+
+ direct.discard(e)
+ control.discard(e)
+ assert_eq()
+
+ e = creator()
+ direct.discard(e)
+ self.assert_(e not in canary.removed)
+ assert_eq()
+
+ if hasattr(direct, 'update'):
+ e = creator()
+ addall(e)
+
+ values = set([e, creator(), creator()])
+
+ direct.update(values)
+ control.update(values)
+ assert_eq()
+
+ if hasattr(direct, 'clear'):
+ addall(creator(), creator())
+ direct.clear()
+ control.clear()
+ assert_eq()
+
+ if hasattr(direct, 'difference_update'):
+ zap()
+ addall(creator(), creator())
+ values = set([creator()])
+
+ direct.difference_update(values)
+ control.difference_update(values)
+ assert_eq()
+ values.update(set([e, creator()]))
+ direct.difference_update(values)
+ control.difference_update(values)
+ assert_eq()
+
+ if hasattr(direct, 'intersection_update'):
+ zap()
+ e = creator()
+ addall(e, creator(), creator())
+ values = set(control)
+
+ direct.intersection_update(values)
+ control.intersection_update(values)
+ assert_eq()
+
+ values.update(set([e, creator()]))
+ direct.intersection_update(values)
+ control.intersection_update(values)
+ assert_eq()
+
+ if hasattr(direct, 'symmetric_difference_update'):
+ zap()
+ e = creator()
+ addall(e, creator(), creator())
+
+ values = set([e, creator()])
+ direct.symmetric_difference_update(values)
+ control.symmetric_difference_update(values)
+ assert_eq()
+
+ e = creator()
+ addall(e)
+ values = set([e])
+ direct.symmetric_difference_update(values)
+ control.symmetric_difference_update(values)
+ assert_eq()
+
+ values = set()
+ direct.symmetric_difference_update(values)
+ control.symmetric_difference_update(values)
+ assert_eq()
+
def test_set(self):
self._test_adapter(set)
+ self._test_set(set)
+
+ def test_set_subclass(self):
+ class MySet(set):
+ pass
+ self._test_adapter(MySet)
+ self._test_set(MySet)
+ self.assert_(getattr(MySet, '_sa_instrumented') == id(MySet))
+
+ def test_set_duck(self):
+ class SetLike(object):
+ def __init__(self):
+ self.data = set()
+ def add(self, item):
+ self.data.add(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def discard(self, item):
+ self.data.discard(item)
+ def pop(self):
+ return self.data.pop()
+ def update(self, other):
+ self.data.update(other)
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(SetLike)
+ self._test_set(SetLike)
+ self.assert_(getattr(SetLike, '_sa_instrumented') == id(SetLike))
+
+ def test_set_emulates(self):
+ class SetIsh(object):
+ __emulates__ = set
+ def __init__(self):
+ self.data = set()
+ def add(self, item):
+ self.data.add(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def discard(self, item):
+ self.data.discard(item)
+ def pop(self):
+ return self.data.pop()
+ def update(self, other):
+ self.data.update(other)
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+ self._test_adapter(SetIsh)
+ self._test_set(SetIsh)
+ self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
+
def test_dict(self):
def dictable_entity(a=None, b=None, c=None):
global _id