]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added test coverage for freeform collection decorators
authorJason Kirtland <jek@discorporate.us>
Sun, 5 Aug 2007 19:21:32 +0000 (19:21 +0000)
committerJason Kirtland <jek@discorporate.us>
Sun, 5 Aug 2007 19:21:32 +0000 (19:21 +0000)
Decorators with positional arg specs can be called with named args too...

lib/sqlalchemy/orm/collections.py
test/orm/collection.py

index 7ade882f5cde139255de37a71530512970e02585..194cf0938885c91e68696d9f975bfeca859b74d8 100644 (file)
@@ -696,58 +696,52 @@ def _instrument_class(cls):
 def _instrument_membership_mutator(method, before, argument, after):
     """Route method args and/or return value through the collection adapter."""
 
-    if type(argument) is int:
-        def wrapper(*args, **kw):
-            if before and len(args) < argument:
-                raise exceptions.ArgumentError(
-                    'Missing argument %i' % argument)
-            initiator = kw.pop('_sa_initiator', None)
-            if initiator is False:
-                executor = None
+    # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
+    if before:
+        fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0]))
+        if type(argument) is int:
+            pos_arg = argument
+            named_arg = len(fn_args) > argument and fn_args[argument] or None
+        else:
+            if argument in fn_args:
+                pos_arg = fn_args.index(argument)
             else:
-                executor = getattr(args[0], '_sa_adapter', None)
-            
-            if before and executor:
-                getattr(executor, before)(args[argument], initiator)
+                pos_arg = None
+            named_arg = argument
+        del fn_args
 
-            if not after or not executor:
-                return method(*args, **kw)
+    def wrapper(*args, **kw):
+        if before:
+            if pos_arg is None:
+                if named_arg not in kw:
+                    raise exceptions.ArgumentError(
+                        "Missing argument %s" % argument)
+                value = kw[named_arg]
             else:
-                res = method(*args, **kw)
-                if res is not None:
-                    getattr(executor, after)(res, initiator)
-                return res
-    else:
-        def wrapper(*args, **kw):
-            if before:
-                vals = inspect.getargvalues(inspect.currentframe())
-                if argument in kw:
-                    value = kw[argument]
+                if len(args) > pos_arg:
+                    value = args[pos_arg]
+                elif named_arg in kw:
+                    value = kw[named_arg]
                 else:
-                    positional = inspect.getargspec(method)[0]
-                    pos = positional.index(argument)
-                    if pos == -1:
-                        raise exceptions.ArgumentError('Missing argument %s' %
-                                                       argument)
-                    else:
-                        value = args[pos]
+                    raise exceptions.ArgumentError(
+                        "Missing argument %s" % argument)
 
-            initiator = kw.pop('_sa_initiator', None)
-            if initiator is False:
-                executor = None
-            else:
-                executor = getattr(args[0], '_sa_adapter', None)
-
-            if before and executor:
-                getattr(executor, before)(value, initiator)
+        initiator = kw.pop('_sa_initiator', None)
+        if initiator is False:
+            executor = None
+        else:
+            executor = getattr(args[0], '_sa_adapter', None)
+            
+        if before and executor:
+            getattr(executor, before)(value, initiator)
 
-            if not after or not executor:
-                return method(*args, **kw)
-            else:
-                res = method(*args, **kw)
-                if res is not None:
-                    getattr(executor, after)(res, initiator)
-                return res
+        if not after or not executor:
+            return method(*args, **kw)
+        else:
+            res = method(*args, **kw)
+            if res is not None:
+                getattr(executor, after)(res, initiator)
+            return res
     try:
         wrapper._sa_instrumented = True
         wrapper.__name__ = method.__name__
index 432528a83eb2ceba472a0fafe4b3411f568848b8..0cc8cf7e06ae5802636da3b020bea1e6d8b78011 100644 (file)
@@ -127,7 +127,7 @@ class CollectionsTest(PersistTest):
             assert_eq()
 
             if reduce(and_, [hasattr(direct, a) for a in
-                             ('__delitem', 'insert', '__len__')], True):
+                             ('__delitem__', 'insert', '__len__')], True):
                 values = [creator(), creator(), creator(), creator()]
                 direct[slice(0,1)] = values
                 control[slice(0,1)] = values
@@ -963,6 +963,107 @@ class CollectionsTest(PersistTest):
         self.assert_(getattr(MyCollection2, '_sa_instrumented') ==
                      id(MyCollection2))
 
+    def test_recipes(self):
+        class Custom(object):
+            def __init__(self):
+                self.data = []
+            @collection.appender
+            @collection.adds('entity')
+            def put(self, entity):
+                self.data.append(entity)
+
+            @collection.remover
+            @collection.removes(1)
+            def remove(self, entity):
+                self.data.remove(entity)
+
+            @collection.adds(1)
+            def push(self, *args):
+                self.data.append(args[0])
+
+            @collection.removes('entity')
+            def yank(self, entity, arg):
+                self.data.remove(entity)
+
+            @collection.replaces(2)
+            def replace(self, arg, entity, **kw):
+                self.data.insert(0, entity)
+                return self.data.pop()
+
+            @collection.removes_return()
+            def pop(self, key):
+                return self.data.pop()
+            
+            @collection.iterator
+            def __iter__(self):
+                return iter(self.data)
+
+        class Foo(object):
+            pass
+        canary = Canary()
+        manager.register_attribute(Foo, 'attr', True, extension=canary,
+                                   typecallable=Custom)
+
+        obj = Foo()
+        adapter = collections.collection_adapter(obj.attr)
+        direct = obj.attr
+        control = list()
+        def assert_eq():
+            self.assert_(set(direct) == canary.data)
+            self.assert_(set(adapter) == canary.data)
+            self.assert_(list(direct) == control)
+        creator = entity_maker
+
+        e1 = creator()
+        direct.put(e1)
+        control.append(e1)
+        assert_eq()
+
+        e2 = creator()
+        direct.put(entity=e2)
+        control.append(e2)
+        assert_eq()
+
+        direct.remove(e2)
+        control.remove(e2)
+        assert_eq()
+
+        direct.remove(entity=e1)
+        control.remove(e1)
+        assert_eq()
+
+        e3 = creator()
+        direct.push(e3)
+        control.append(e3)
+        assert_eq()
+
+        direct.yank(e3, 'blah')
+        control.remove(e3)
+        assert_eq()
+
+        e4, e5, e6, e7 = creator(), creator(), creator(), creator()
+        direct.put(e4)
+        direct.put(e5)
+        control.append(e4)
+        control.append(e5)
+
+        dr1 = direct.replace('foo', e6, bar='baz')
+        control.insert(0, e6)
+        cr1 = control.pop()
+        assert_eq()
+        self.assert_(dr1 is cr1)
+
+        dr2 = direct.replace(arg=1, entity=e7)
+        control.insert(0, e7)
+        cr2 = control.pop()
+        assert_eq()
+        self.assert_(dr2 is cr2)
+
+        dr3 = direct.pop('blah')
+        cr3 = control.pop()
+        assert_eq()
+        self.assert_(dr3 is cr3)
+
     def test_lifecycle(self):
         class Foo(object):
             pass