]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Factor argument replacement logic out of @return_future
authorBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 17:25:22 +0000 (12:25 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 17:25:22 +0000 (12:25 -0500)
tornado/concurrent.py
tornado/test/util_test.py
tornado/util.py

index e9057e84ab79fd88f55c54162af453975f9f274c..d73a59c9e9af174f8add34b7ddf6f0f3db07c119 100644 (file)
 from __future__ import absolute_import, division, print_function, with_statement
 
 import functools
-import inspect
 import sys
 
 from tornado.stack_context import ExceptionStackContext
-from tornado.util import raise_exc_info
+from tornado.util import raise_exc_info, ArgReplacer
 
 try:
     from concurrent import futures
@@ -143,25 +142,14 @@ def return_future(f):
     Note that ``@return_future`` and ``@gen.engine`` can be applied to the
     same function, provided ``@return_future`` appears first.
     """
-    try:
-        callback_pos = inspect.getargspec(f).args.index('callback')
-    except ValueError:
-        # Callback is not accepted as a positional parameter
-        callback_pos = None
+    replacer = ArgReplacer(f, 'callback')
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         future = Future()
-        if callback_pos is not None and len(args) > callback_pos:
-            # The callback argument is being passed positionally
-            if args[callback_pos] is not None:
-                future.add_done_callback(args[callback_pos])
-            args = list(args)  # *args is normally a tuple
-            args[callback_pos] = future.set_result
-        else:
-            # The callback argument is either omitted or passed by keyword.
-            if kwargs.get('callback') is not None:
-                future.add_done_callback(kwargs.pop('callback'))
-            kwargs['callback'] = future.set_result
+        callback, args, kwargs = replacer.replace(future.set_result,
+                                                  args, kwargs)
+        if callback is not None:
+            future.add_done_callback(callback)
 
         def handle_error(typ, value, tb):
             future.set_exception(value)
index 41ca2110521f7b60fc284029653cc09428fa82ec..038602a8e2a0f8ddb844878bec44796f6ce268b6 100644 (file)
@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 import sys
 
 from tornado.escape import utf8
-from tornado.util import raise_exc_info, Configurable, u, exec_in
+from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer
 from tornado.test.util import unittest
 
 try:
@@ -141,3 +141,23 @@ class ExecInTest(unittest.TestCase):
         # ...but the template doesn't
         exec_in('print >> f, "world"', dict(f=f))
         self.assertEqual(f.getvalue(), 'hello\nworld\n')
+
+
+class ArgReplacerTest(unittest.TestCase):
+    def setUp(self):
+        def function(x, y, callback=None, z=None):
+            pass
+        self.replacer = ArgReplacer(function, 'callback')
+
+    def test_omitted(self):
+        self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
+                         (None, (1, 2), dict(callback='new')))
+
+    def test_position(self):
+        self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
+                         ('old', [1, 2, 'new', 3], dict()))
+
+    def test_keyword(self):
+        self.assertEqual(self.replacer.replace('new', (1,),
+                                               dict(y=2, callback='old', z=3)),
+                         ('old', (1,), dict(y=2, callback='new', z=3)))
index deea41e2f45b186643f908ea0616f3da4a866028..69de2c8e8ee42b0df47ca36ef5d78f6f6d6844a7 100644 (file)
@@ -2,6 +2,7 @@
 
 from __future__ import absolute_import, division, print_function, with_statement
 
+import inspect
 import sys
 import zlib
 
@@ -201,6 +202,45 @@ class Configurable(object):
         base.__impl_kwargs = saved[1]
 
 
+class ArgReplacer(object):
+    """Replaces one value in an ``args, kwargs`` pair.
+
+    Inspects the function signature to find an argument by name
+    whether it is passed by position or keyword.  For use in decorators
+    and similar wrappers.
+    """
+    def __init__(self, func, name):
+        """Create an ArgReplacer for the named argument to the given function.
+        """
+        self.name = name
+        try:
+            self.arg_pos = inspect.getargspec(func).args.index(self.name)
+        except ValueError:
+            # Not a positional parameter
+            self.arg_pos = None
+
+    def replace(self, new_value, args, kwargs):
+        """Replace the named argument in ``args, kwargs`` with ``new_value``.
+
+        Returns ``(old_value, args, kwargs)``.  The returned ``args`` and
+        ``kwargs`` objects may not be the same as the input objects, or
+        the input objects may be mutated.
+
+        If the named argument was not found, ``new_value`` will be added
+        to ``kwargs`` and None will be returned as ``old_value``.
+        """
+        if self.arg_pos is not None and len(args) > self.arg_pos:
+            # The arg to replace is passed positionally
+            old_value = args[self.arg_pos]
+            args = list(args)  # *args is normally a tuple
+            args[self.arg_pos] = new_value
+        else:
+            # The arg to replace is either omitted or passed by keyword.
+            old_value = kwargs.get(self.name)
+            kwargs[self.name] = new_value
+        return old_value, args, kwargs
+
+
 def doctests():
     import doctest
     return doctest.DocTestSuite()