op = method._sa_instrument_after
assert op in ('fire_append_event', 'fire_remove_event')
after = op
- if before or after:
+ if before:
methods[name] = before[0], before[1], after
+ elif after:
+ methods[name] = None, None, after
# apply ABC auto-decoration to methods that need it
for method, decorator in decorators.items():
def __setitem__(self, key, value, _sa_initiator=None):
if key in self:
__del(self, self[key], _sa_initiator)
- __set(self, value)
+ __set(self, value, _sa_initiator)
fn(self, key, value)
_tidy(__setitem__)
return __setitem__
def setdefault(fn):
def setdefault(self, key, default=None):
- if key not in self and default is not None:
- __set(self, default)
- return fn(self, key, default)
+ if key not in self:
+ self.__setitem__(key, default)
+ return default
+ else:
+ return self.__getitem__(key)
_tidy(setdefault)
return setdefault
def update(fn):
def update(self, other):
for key in other.keys():
- self[key] = other[key]
+ if not self.has_key(key) or self[key] is not other[key]:
+ self[key] = other[key]
_tidy(update)
return update
else:
if __other is not Unspecified:
if hasattr(__other, 'keys'):
for key in __other.keys():
- self[key] = __other[key]
+ if key not in self or self[key] is not __other[key]:
+ self[key] = __other[key]
else:
for key, value in __other:
- self[key] = value
+ if key not in self or self[key] is not value:
+ self[key] = value
for key in kw:
- self[key] = kw[key]
+ if key not in self or self[key] is not kw[key]:
+ self[key] = kw[key]
_tidy(update)
return update
'iterator': '__iter__',
'_decorators': _set_decorators(), },
# < 0.4 compatible naming (almost), deprecated- use decorators instead.
- dict: { 'appender': 'append',
+ dict: { 'appender': 'set',
'remover': 'remove',
'iterator': 'itervalues',
'_decorators': _dict_decorators(), },
"""A basic dictionary-based collection class.
Extends dict with the minimal bag semantics that collection classes require.
- "append" and "remove" are implemented in terms of a keying function: any
+ "set" and "remove" are implemented in terms of a keying function: any
callable that takes an object and returns an object for use as a dictionary
key.
"""
def __init__(self, keyfunc):
self.keyfunc = keyfunc
- def append(self, value, _sa_initiator=None):
+ def set(self, value, _sa_initiator=None):
key = self.keyfunc(value)
self.__setitem__(key, value, _sa_initiator)
- append = collection.internally_instrumented(append)
- append = collection.appender(append)
+ set = collection.internally_instrumented(set)
+ set = collection.appender(set)
def remove(self, value, _sa_initiator=None):
key = self.keyfunc(value)
else:
self.update(d, **kwargs)
- def keys(self):
- return list(self._list)
-
def clear(self):
self._list = []
dict.clear(self)
- def update(self, d=None, **kwargs):
- # d can be a dict or sequence of keys/values
- if d:
- if hasattr(d, 'iteritems'):
- seq = d.iteritems()
+ def update(self, ____sequence=None, **kwargs):
+ if ____sequence is not None:
+ if hasattr(____sequence, 'keys'):
+ for key in ____sequence.keys():
+ self.__setitem__(key, ____sequence[key])
else:
- seq = d
- for key, value in seq:
- self.__setitem__(key, value)
+ for key, value in ____sequence:
+ self[key] = value
if kwargs:
self.update(kwargs)
else:
return self.__getitem__(key)
- def values(self):
- return [self[key] for key in self._list]
-
def __iter__(self):
return iter(self._list)
+ def values(self):
+ return [self[key] for key in self._list]
+
def itervalues(self):
- return iter([self[key] for key in self._list])
+ return iter(self.values())
+
+ def keys(self):
+ return list(self._list)
def iterkeys(self):
- return self.__iter__()
+ return iter(self.keys())
- def iteritems(self):
- return iter([(key, self[key]) for key in self.keys()])
+ def items(self):
+ return [(key, self[key]) for key in self.keys()]
- def __delitem__(self, key):
- try:
- del self._list[self._list.index(key)]
- except ValueError:
- raise KeyError(key)
- dict.__delitem__(self, key)
+ def iteritems(self):
+ return iter(self.items())
def __setitem__(self, key, object):
if not self.has_key(key):
self._list.append(key)
dict.__setitem__(self, key, object)
+ def __delitem__(self, key):
+ dict.__delitem__(self, key)
+ self._list.remove(key)
+
+ def pop(self, key):
+ value = dict.pop(self, key)
+ self._list.remove(key)
+ return value
+
+ def popitem(self):
+ item = dict.popitem(self)
+ self._list.remove(item[0])
+ return item
+
class ThreadLocal(object):
"""An object in which attribute access occurs only within the context of the current thread."""
import testbase
from sqlalchemy import *
+import sqlalchemy.exceptions as exceptions
from sqlalchemy.orm import create_session, mapper, relation, \
interfaces, attributes
import sqlalchemy.orm.collections as collections
global _id
_id += 1
return Entity(_id)
+def dictable_entity(a=None, b=None, c=None):
+ global _id
+ _id += 1
+ return Entity(a or str(_id), b or 'value %s' % _id, c)
+
class CollectionsTest(testbase.PersistTest):
def _test_adapter(self, typecallable, creator=entity_maker,
def assert_eq():
self.assert_(to_set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
- assert_ne = lambda: self.assert_(set(obj.attr) != canary.data)
+ assert_ne = lambda: self.assert_(to_set(direct) != canary.data)
e1, e2 = creator(), creator()
def _test_set(self, typecallable, creator=entity_maker):
class Foo(object):
pass
-
+
canary = Canary()
manager.register_attribute(Foo, 'attr', True, extension=canary,
typecallable=typecallable)
self._test_adapter(SetIsh)
self._test_set(SetIsh)
self.assert_(getattr(SetIsh, '_sa_instrumented') == id(SetIsh))
+
+ def _test_dict(self, typecallable, creator=dictable_entity):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = dict()
+
+ def assert_eq():
+ self.assert_(set(direct.values()) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ def addall(*values):
+ for item in values:
+ direct.set(item)
+ control[item.a] = item
+ assert_eq()
+ def zap():
+ for item in list(adapter):
+ direct.remove(item)
+ control.clear()
+ # assume an 'set' method is available for tests
+ addall(creator())
+
+ if hasattr(direct, '__setitem__'):
+ e = creator()
+ direct[e.a] = e
+ control[e.a] = e
+ assert_eq()
+
+ e = creator(e.a, e.b)
+ direct[e.a] = e
+ control[e.a] = e
+ assert_eq()
+
+ if hasattr(direct, '__delitem__'):
+ e = creator()
+ addall(e)
+
+ del direct[e.a]
+ del control[e.a]
+ assert_eq()
+
+ e = creator()
+ try:
+ del direct[e.a]
+ except KeyError:
+ self.assert_(e not in canary.removed)
+
+ if hasattr(direct, 'clear'):
+ addall(creator(), creator(), creator())
+
+ direct.clear()
+ control.clear()
+ assert_eq()
+
+ direct.clear()
+ control.clear()
+ assert_eq()
+
+ if hasattr(direct, 'pop'):
+ e = creator()
+ addall(e)
+
+ direct.pop(e.a)
+ control.pop(e.a)
+ assert_eq()
+
+ e = creator()
+ try:
+ direct.pop(e.a)
+ except KeyError:
+ self.assert_(e not in canary.removed)
+
+ if hasattr(direct, 'popitem'):
+ zap()
+ e = creator()
+ addall(e)
+
+ direct.popitem()
+ control.popitem()
+ assert_eq()
+
+ if hasattr(direct, 'setdefault'):
+ e = creator()
+
+ val_a = direct.setdefault(e.a, e)
+ val_b = control.setdefault(e.a, e)
+ assert_eq()
+ self.assert_(val_a is val_b)
+
+ val_a = direct.setdefault(e.a, e)
+ val_b = control.setdefault(e.a, e)
+ assert_eq()
+ self.assert_(val_a is val_b)
+
+ if hasattr(direct, 'update'):
+ e = creator()
+ d = dict([(ee.a, ee) for ee in [e, creator(), creator()]])
+ addall(e, creator())
+
+ direct.update(d)
+ control.update(d)
+ assert_eq()
+
+ kw = dict([(ee.a, ee) for ee in [e, creator()]])
+ direct.update(**kw)
+ control.update(**kw)
+ assert_eq()
+
def test_dict(self):
- def dictable_entity(a=None, b=None, c=None):
- global _id
- _id += 1
- return Entity(a or str(_id), b or 'value %s' % _id, c)
+ try:
+ self._test_adapter(dict, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self.assert_(False)
+ except exceptions.ArgumentError, e:
+ self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
+
+ try:
+ self._test_dict(dict)
+ self.assert_(False)
+ except exceptions.ArgumentError, e:
+ self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
+
+ def test_dict_subclass(self):
+ class MyDict(dict):
+ @collection.appender
+ @collection.internally_instrumented
+ def set(self, item, _sa_initiator=None):
+ self.__setitem__(item.a, item, _sa_initiator=_sa_initiator)
+ @collection.remover
+ @collection.internally_instrumented
+ def _remove(self, item, _sa_initiator=None):
+ self.__delitem__(item.a, _sa_initiator=_sa_initiator)
+
+ self._test_adapter(MyDict, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self._test_dict(MyDict)
+ self.assert_(getattr(MyDict, '_sa_instrumented') == id(MyDict))
+
+ def test_dict_subclass2(self):
+ class MyEasyDict(collections.MappedCollection):
+ def __init__(self):
+ super(MyEasyDict, self).__init__(lambda e: e.a)
+
+ self._test_adapter(MyEasyDict, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self._test_dict(MyEasyDict)
+ self.assert_(getattr(MyEasyDict, '_sa_instrumented') == id(MyEasyDict))
+
+ def test_dict_subclass3(self):
+ class MyOrdered(util.OrderedDict, collections.MappedCollection):
+ def __init__(self):
+ collections.MappedCollection.__init__(self, lambda e: e.a)
+ util.OrderedDict.__init__(self)
+
+ self._test_adapter(MyOrdered, dictable_entity,
+ to_set=lambda c: set(c.values()))
+ self._test_dict(MyOrdered)
+ self.assert_(getattr(MyOrdered, '_sa_instrumented') == id(MyOrdered))
+
+
+ def test_dict_duck(self):
+ class DictLike(object):
+ def __init__(self):
+ self.data = dict()
+
+ @collection.appender
+ @collection.replaces(1)
+ def set(self, item):
+ current = self.data.get(item.a, None)
+ self.data[item.a] = item
+ return current
+ @collection.remover
+ def _remove(self, item):
+ del self.data[item.a]
+ def __setitem__(self, key, value):
+ self.data[key] = value
+ def __getitem__(self, key):
+ return self.data[key]
+ def __delitem__(self, key):
+ del self.data[key]
+ def values(self):
+ return self.data.values()
+ def __contains__(self, key):
+ return key in self.data
+ @collection.iterator
+ def itervalues(self):
+ return self.data.itervalues()
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'DictLike(%s)' % repr(self.data)
+
+ self._test_adapter(DictLike, dictable_entity,
+ to_set=lambda c: set(c.itervalues()))
+ self._test_dict(DictLike)
+ self.assert_(getattr(DictLike, '_sa_instrumented') == id(DictLike))
+
+ def test_dict_emulates(self):
+ class DictIsh(object):
+ __emulates__ = dict
+ def __init__(self):
+ self.data = dict()
+
+ @collection.appender
+ @collection.replaces(1)
+ def set(self, item):
+ current = self.data.get(item.a, None)
+ self.data[item.a] = item
+ return current
+ @collection.remover
+ def _remove(self, item):
+ del self.data[item.a]
+ def __setitem__(self, key, value):
+ self.data[key] = value
+ def __getitem__(self, key):
+ return self.data[key]
+ def __delitem__(self, key):
+ del self.data[key]
+ def values(self):
+ return self.data.values()
+ def __contains__(self, key):
+ return key in self.data
+ @collection.iterator
+ def itervalues(self):
+ return self.data.itervalues()
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'DictIsh(%s)' % repr(self.data)
+
+ self._test_adapter(DictIsh, dictable_entity,
+ to_set=lambda c: set(c.itervalues()))
+ self._test_dict(DictIsh)
+ self.assert_(getattr(DictIsh, '_sa_instrumented') == id(DictIsh))
+
+ def _test_object(self, typecallable, creator=entity_maker):
+ class Foo(object):
+ pass
+
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=typecallable)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = set()
+
+ def assert_eq():
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(direct == control)
+
+ # There is no API for object collections. We'll make one up
+ # for the purposes of the test.
+ e = creator()
+ direct.push(e)
+ control.add(e)
+ assert_eq()
+
+ direct.zark(e)
+ control.remove(e)
+ assert_eq()
- self._test_adapter(collections.attribute_mapped_collection('a'),
- dictable_entity, to_set=lambda c: set(c.values()))
+ e = creator()
+ direct.maybe_zark(e)
+ control.discard(e)
+ assert_eq()
+
+ e = creator()
+ direct.push(e)
+ control.add(e)
+ assert_eq()
+
+ e = creator()
+ direct.maybe_zark(e)
+ control.discard(e)
+ assert_eq()
+
+ def test_object_duck(self):
+ class MyCollection(object):
+ def __init__(self):
+ self.data = set()
+ @collection.appender
+ def push(self, item):
+ self.data.add(item)
+ @collection.remover
+ def zark(self, item):
+ self.data.remove(item)
+ @collection.removes_return()
+ def maybe_zark(self, item):
+ if item in self.data:
+ self.data.remove(item)
+ return item
+ @collection.iterator
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(MyCollection)
+ self._test_object(MyCollection)
+ self.assert_(getattr(MyCollection, '_sa_instrumented') ==
+ id(MyCollection))
+
+ def test_object_emulates(self):
+ class MyCollection2(object):
+ __emulates__ = None
+ def __init__(self):
+ self.data = set()
+ # looks like a list
+ def append(self, item):
+ assert False
+ @collection.appender
+ def push(self, item):
+ self.data.add(item)
+ @collection.remover
+ def zark(self, item):
+ self.data.remove(item)
+ @collection.removes_return()
+ def maybe_zark(self, item):
+ if item in self.data:
+ self.data.remove(item)
+ return item
+ @collection.iterator
+ def __iter__(self):
+ return iter(self.data)
+ def __eq__(self, other):
+ return self.data == other
+
+ self._test_adapter(MyCollection2)
+ self._test_object(MyCollection2)
+ self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
+ id(MyCollection2))
+
class DictHelpersTest(testbase.ORMTest):
def define_tables(self, metadata):