From c567f04f797bc5c00368acd3f716cf08d125d995 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 17 Feb 2013 12:25:22 -0500 Subject: [PATCH] Factor argument replacement logic out of @return_future --- tornado/concurrent.py | 24 ++++++----------------- tornado/test/util_test.py | 22 ++++++++++++++++++++- tornado/util.py | 40 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/tornado/concurrent.py b/tornado/concurrent.py index e9057e84a..d73a59c9e 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -16,11 +16,10 @@ from __future__ import absolute_import, division, print_function, with_statement import functools -import inspect import sys from tornado.stack_context import ExceptionStackContext -from tornado.util import raise_exc_info +from tornado.util import raise_exc_info, ArgReplacer try: from concurrent import futures @@ -143,25 +142,14 @@ def return_future(f): Note that ``@return_future`` and ``@gen.engine`` can be applied to the same function, provided ``@return_future`` appears first. """ - try: - callback_pos = inspect.getargspec(f).args.index('callback') - except ValueError: - # Callback is not accepted as a positional parameter - callback_pos = None + replacer = ArgReplacer(f, 'callback') @functools.wraps(f) def wrapper(*args, **kwargs): future = Future() - if callback_pos is not None and len(args) > callback_pos: - # The callback argument is being passed positionally - if args[callback_pos] is not None: - future.add_done_callback(args[callback_pos]) - args = list(args) # *args is normally a tuple - args[callback_pos] = future.set_result - else: - # The callback argument is either omitted or passed by keyword. - if kwargs.get('callback') is not None: - future.add_done_callback(kwargs.pop('callback')) - kwargs['callback'] = future.set_result + callback, args, kwargs = replacer.replace(future.set_result, + args, kwargs) + if callback is not None: + future.add_done_callback(callback) def handle_error(typ, value, tb): future.set_exception(value) diff --git a/tornado/test/util_test.py b/tornado/test/util_test.py index 41ca21105..038602a8e 100644 --- a/tornado/test/util_test.py +++ b/tornado/test/util_test.py @@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, with_statement import sys from tornado.escape import utf8 -from tornado.util import raise_exc_info, Configurable, u, exec_in +from tornado.util import raise_exc_info, Configurable, u, exec_in, ArgReplacer from tornado.test.util import unittest try: @@ -141,3 +141,23 @@ class ExecInTest(unittest.TestCase): # ...but the template doesn't exec_in('print >> f, "world"', dict(f=f)) self.assertEqual(f.getvalue(), 'hello\nworld\n') + + +class ArgReplacerTest(unittest.TestCase): + def setUp(self): + def function(x, y, callback=None, z=None): + pass + self.replacer = ArgReplacer(function, 'callback') + + def test_omitted(self): + self.assertEqual(self.replacer.replace('new', (1, 2), dict()), + (None, (1, 2), dict(callback='new'))) + + def test_position(self): + self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()), + ('old', [1, 2, 'new', 3], dict())) + + def test_keyword(self): + self.assertEqual(self.replacer.replace('new', (1,), + dict(y=2, callback='old', z=3)), + ('old', (1,), dict(y=2, callback='new', z=3))) diff --git a/tornado/util.py b/tornado/util.py index deea41e2f..69de2c8e8 100644 --- a/tornado/util.py +++ b/tornado/util.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function, with_statement +import inspect import sys import zlib @@ -201,6 +202,45 @@ class Configurable(object): base.__impl_kwargs = saved[1] +class ArgReplacer(object): + """Replaces one value in an ``args, kwargs`` pair. + + Inspects the function signature to find an argument by name + whether it is passed by position or keyword. For use in decorators + and similar wrappers. + """ + def __init__(self, func, name): + """Create an ArgReplacer for the named argument to the given function. + """ + self.name = name + try: + self.arg_pos = inspect.getargspec(func).args.index(self.name) + except ValueError: + # Not a positional parameter + self.arg_pos = None + + def replace(self, new_value, args, kwargs): + """Replace the named argument in ``args, kwargs`` with ``new_value``. + + Returns ``(old_value, args, kwargs)``. The returned ``args`` and + ``kwargs`` objects may not be the same as the input objects, or + the input objects may be mutated. + + If the named argument was not found, ``new_value`` will be added + to ``kwargs`` and None will be returned as ``old_value``. + """ + if self.arg_pos is not None and len(args) > self.arg_pos: + # The arg to replace is passed positionally + old_value = args[self.arg_pos] + args = list(args) # *args is normally a tuple + args[self.arg_pos] = new_value + else: + # The arg to replace is either omitted or passed by keyword. + old_value = kwargs.get(self.name) + kwargs[self.name] = new_value + return old_value, args, kwargs + + def doctests(): import doctest return doctest.DocTestSuite() -- 2.47.2