def _instrument_membership_mutator(method, before, argument, after):
"""Route method args and/or return value through the collection adapter."""
- if type(argument) is int:
- def wrapper(*args, **kw):
- if before and len(args) < argument:
- raise exceptions.ArgumentError(
- 'Missing argument %i' % argument)
- initiator = kw.pop('_sa_initiator', None)
- if initiator is False:
- executor = None
+ # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
+ if before:
+ fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0]))
+ if type(argument) is int:
+ pos_arg = argument
+ named_arg = len(fn_args) > argument and fn_args[argument] or None
+ else:
+ if argument in fn_args:
+ pos_arg = fn_args.index(argument)
else:
- executor = getattr(args[0], '_sa_adapter', None)
-
- if before and executor:
- getattr(executor, before)(args[argument], initiator)
+ pos_arg = None
+ named_arg = argument
+ del fn_args
- if not after or not executor:
- return method(*args, **kw)
+ def wrapper(*args, **kw):
+ if before:
+ if pos_arg is None:
+ if named_arg not in kw:
+ raise exceptions.ArgumentError(
+ "Missing argument %s" % argument)
+ value = kw[named_arg]
else:
- res = method(*args, **kw)
- if res is not None:
- getattr(executor, after)(res, initiator)
- return res
- else:
- def wrapper(*args, **kw):
- if before:
- vals = inspect.getargvalues(inspect.currentframe())
- if argument in kw:
- value = kw[argument]
+ if len(args) > pos_arg:
+ value = args[pos_arg]
+ elif named_arg in kw:
+ value = kw[named_arg]
else:
- positional = inspect.getargspec(method)[0]
- pos = positional.index(argument)
- if pos == -1:
- raise exceptions.ArgumentError('Missing argument %s' %
- argument)
- else:
- value = args[pos]
+ raise exceptions.ArgumentError(
+ "Missing argument %s" % argument)
- initiator = kw.pop('_sa_initiator', None)
- if initiator is False:
- executor = None
- else:
- executor = getattr(args[0], '_sa_adapter', None)
-
- if before and executor:
- getattr(executor, before)(value, initiator)
+ initiator = kw.pop('_sa_initiator', None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = getattr(args[0], '_sa_adapter', None)
+
+ if before and executor:
+ getattr(executor, before)(value, initiator)
- if not after or not executor:
- return method(*args, **kw)
- else:
- res = method(*args, **kw)
- if res is not None:
- getattr(executor, after)(res, initiator)
- return res
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
try:
wrapper._sa_instrumented = True
wrapper.__name__ = method.__name__
assert_eq()
if reduce(and_, [hasattr(direct, a) for a in
- ('__delitem', 'insert', '__len__')], True):
+ ('__delitem__', 'insert', '__len__')], True):
values = [creator(), creator(), creator(), creator()]
direct[slice(0,1)] = values
control[slice(0,1)] = values
self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
id(MyCollection2))
+ def test_recipes(self):
+ class Custom(object):
+ def __init__(self):
+ self.data = []
+ @collection.appender
+ @collection.adds('entity')
+ def put(self, entity):
+ self.data.append(entity)
+
+ @collection.remover
+ @collection.removes(1)
+ def remove(self, entity):
+ self.data.remove(entity)
+
+ @collection.adds(1)
+ def push(self, *args):
+ self.data.append(args[0])
+
+ @collection.removes('entity')
+ def yank(self, entity, arg):
+ self.data.remove(entity)
+
+ @collection.replaces(2)
+ def replace(self, arg, entity, **kw):
+ self.data.insert(0, entity)
+ return self.data.pop()
+
+ @collection.removes_return()
+ def pop(self, key):
+ return self.data.pop()
+
+ @collection.iterator
+ def __iter__(self):
+ return iter(self.data)
+
+ class Foo(object):
+ pass
+ canary = Canary()
+ manager.register_attribute(Foo, 'attr', True, extension=canary,
+ typecallable=Custom)
+
+ obj = Foo()
+ adapter = collections.collection_adapter(obj.attr)
+ direct = obj.attr
+ control = list()
+ def assert_eq():
+ self.assert_(set(direct) == canary.data)
+ self.assert_(set(adapter) == canary.data)
+ self.assert_(list(direct) == control)
+ creator = entity_maker
+
+ e1 = creator()
+ direct.put(e1)
+ control.append(e1)
+ assert_eq()
+
+ e2 = creator()
+ direct.put(entity=e2)
+ control.append(e2)
+ assert_eq()
+
+ direct.remove(e2)
+ control.remove(e2)
+ assert_eq()
+
+ direct.remove(entity=e1)
+ control.remove(e1)
+ assert_eq()
+
+ e3 = creator()
+ direct.push(e3)
+ control.append(e3)
+ assert_eq()
+
+ direct.yank(e3, 'blah')
+ control.remove(e3)
+ assert_eq()
+
+ e4, e5, e6, e7 = creator(), creator(), creator(), creator()
+ direct.put(e4)
+ direct.put(e5)
+ control.append(e4)
+ control.append(e5)
+
+ dr1 = direct.replace('foo', e6, bar='baz')
+ control.insert(0, e6)
+ cr1 = control.pop()
+ assert_eq()
+ self.assert_(dr1 is cr1)
+
+ dr2 = direct.replace(arg=1, entity=e7)
+ control.insert(0, e7)
+ cr2 = control.pop()
+ assert_eq()
+ self.assert_(dr2 is cr2)
+
+ dr3 = direct.pop('blah')
+ cr3 = control.pop()
+ assert_eq()
+ self.assert_(dr3 is cr3)
+
def test_lifecycle(self):
class Foo(object):
pass