From: Jason Kirtland Date: Sun, 5 Aug 2007 19:21:32 +0000 (+0000) Subject: Added test coverage for freeform collection decorators X-Git-Tag: rel_0_4beta1~68 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a4efa31dc029350ddb70637d8c5326df2e02a815;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added test coverage for freeform collection decorators Decorators with positional arg specs can be called with named args too... --- diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 7ade882f5c..194cf09388 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -696,58 +696,52 @@ def _instrument_class(cls): def _instrument_membership_mutator(method, before, argument, after): """Route method args and/or return value through the collection adapter.""" - if type(argument) is int: - def wrapper(*args, **kw): - if before and len(args) < argument: - raise exceptions.ArgumentError( - 'Missing argument %i' % argument) - initiator = kw.pop('_sa_initiator', None) - if initiator is False: - executor = None + # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' + if before: + fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0])) + if type(argument) is int: + pos_arg = argument + named_arg = len(fn_args) > argument and fn_args[argument] or None + else: + if argument in fn_args: + pos_arg = fn_args.index(argument) else: - executor = getattr(args[0], '_sa_adapter', None) - - if before and executor: - getattr(executor, before)(args[argument], initiator) + pos_arg = None + named_arg = argument + del fn_args - if not after or not executor: - return method(*args, **kw) + def wrapper(*args, **kw): + if before: + if pos_arg is None: + if named_arg not in kw: + raise exceptions.ArgumentError( + "Missing argument %s" % argument) + value = kw[named_arg] else: - res = method(*args, **kw) - if res is not None: - getattr(executor, after)(res, initiator) - return res - else: - def wrapper(*args, **kw): - if before: - vals = inspect.getargvalues(inspect.currentframe()) - if argument in kw: - value = kw[argument] + if len(args) > pos_arg: + value = args[pos_arg] + elif named_arg in kw: + value = kw[named_arg] else: - positional = inspect.getargspec(method)[0] - pos = positional.index(argument) - if pos == -1: - raise exceptions.ArgumentError('Missing argument %s' % - argument) - else: - value = args[pos] + raise exceptions.ArgumentError( + "Missing argument %s" % argument) - initiator = kw.pop('_sa_initiator', None) - if initiator is False: - executor = None - else: - executor = getattr(args[0], '_sa_adapter', None) - - if before and executor: - getattr(executor, before)(value, initiator) + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) + + if before and executor: + getattr(executor, before)(value, initiator) - if not after or not executor: - return method(*args, **kw) - else: - res = method(*args, **kw) - if res is not None: - getattr(executor, after)(res, initiator) - return res + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res try: wrapper._sa_instrumented = True wrapper.__name__ = method.__name__ diff --git a/test/orm/collection.py b/test/orm/collection.py index 432528a83e..0cc8cf7e06 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -127,7 +127,7 @@ class CollectionsTest(PersistTest): assert_eq() if reduce(and_, [hasattr(direct, a) for a in - ('__delitem', 'insert', '__len__')], True): + ('__delitem__', 'insert', '__len__')], True): values = [creator(), creator(), creator(), creator()] direct[slice(0,1)] = values control[slice(0,1)] = values @@ -963,6 +963,107 @@ class CollectionsTest(PersistTest): self.assert_(getattr(MyCollection2, '_sa_instrumented') == id(MyCollection2)) + def test_recipes(self): + class Custom(object): + def __init__(self): + self.data = [] + @collection.appender + @collection.adds('entity') + def put(self, entity): + self.data.append(entity) + + @collection.remover + @collection.removes(1) + def remove(self, entity): + self.data.remove(entity) + + @collection.adds(1) + def push(self, *args): + self.data.append(args[0]) + + @collection.removes('entity') + def yank(self, entity, arg): + self.data.remove(entity) + + @collection.replaces(2) + def replace(self, arg, entity, **kw): + self.data.insert(0, entity) + return self.data.pop() + + @collection.removes_return() + def pop(self, key): + return self.data.pop() + + @collection.iterator + def __iter__(self): + return iter(self.data) + + class Foo(object): + pass + canary = Canary() + manager.register_attribute(Foo, 'attr', True, extension=canary, + typecallable=Custom) + + obj = Foo() + adapter = collections.collection_adapter(obj.attr) + direct = obj.attr + control = list() + def assert_eq(): + self.assert_(set(direct) == canary.data) + self.assert_(set(adapter) == canary.data) + self.assert_(list(direct) == control) + creator = entity_maker + + e1 = creator() + direct.put(e1) + control.append(e1) + assert_eq() + + e2 = creator() + direct.put(entity=e2) + control.append(e2) + assert_eq() + + direct.remove(e2) + control.remove(e2) + assert_eq() + + direct.remove(entity=e1) + control.remove(e1) + assert_eq() + + e3 = creator() + direct.push(e3) + control.append(e3) + assert_eq() + + direct.yank(e3, 'blah') + control.remove(e3) + assert_eq() + + e4, e5, e6, e7 = creator(), creator(), creator(), creator() + direct.put(e4) + direct.put(e5) + control.append(e4) + control.append(e5) + + dr1 = direct.replace('foo', e6, bar='baz') + control.insert(0, e6) + cr1 = control.pop() + assert_eq() + self.assert_(dr1 is cr1) + + dr2 = direct.replace(arg=1, entity=e7) + control.insert(0, e7) + cr2 = control.pop() + assert_eq() + self.assert_(dr2 is cr2) + + dr3 = direct.pop('blah') + cr3 = control.pop() + assert_eq() + self.assert_(dr3 is cr3) + def test_lifecycle(self): class Foo(object): pass