from __future__ import absolute_import, division, print_function, with_statement
import functools
+import inspect
+import sys
from tornado.stack_context import ExceptionStackContext
+from tornado.util import raise_exc_info
try:
from concurrent import futures
return future
return wrapper
-# TODO: this needs a better name
-
-def future_wrap(f):
+def return_future(f):
+ """Decorator to make a function that returns via callback return a `Future`.
+
+ The wrapped function should take a ``callback`` keyword argument
+ and invoke it with one argument when it has finished. To signal failure,
+ the function can simply raise an exception (which will be
+ captured by the `stack_context` and passed along to the `Future`).
+
+ From the caller's perspective, the callback argument is optional.
+ If one is given, it will be invoked when the function is complete
+ with the `Future` as an argument. If no callback is given, the caller
+ should use the `Future` to wait for the function to complete
+ (perhaps by yielding it in a `gen.engine` function, or passing it
+ to `IOLoop.add_future`).
+
+ Usage::
+ @return_future
+ def future_func(arg1, arg2, callback):
+ # Do stuff (possibly asynchronous)
+ callback(result)
+
+ @gen.engine
+ def caller(callback):
+ yield future_func(arg1, arg2)
+ callback()
+
+ Note that ``@return_future`` and ``@gen.engine`` can be applied to the
+ same function, provided ``@return_future`` appears first.
+ """
+ try:
+ callback_pos = inspect.getargspec(f).args.index('callback')
+ except ValueError:
+ # Callback is not accepted as a positional parameter
+ callback_pos = None
@functools.wraps(f)
def wrapper(*args, **kwargs):
future = Future()
- if kwargs.get('callback') is not None:
- future.add_done_callback(kwargs.pop('callback'))
- kwargs['callback'] = future.set_result
+ if callback_pos is not None and len(args) > callback_pos:
+ # The callback argument is being passed positionally
+ if args[callback_pos] is not None:
+ future.add_done_callback(args[callback_pos])
+ args = list(args) # *args is normally a tuple
+ args[callback_pos] = future.set_result
+ else:
+ # The callback argument is either omitted or passed by keyword.
+ if kwargs.get('callback') is not None:
+ future.add_done_callback(kwargs.pop('callback'))
+ kwargs['callback'] = future.set_result
def handle_error(typ, value, tb):
future.set_exception(value)
return True
+ exc_info = None
with ExceptionStackContext(handle_error):
- f(*args, **kwargs)
+ try:
+ result = f(*args, **kwargs)
+ except:
+ exc_info = sys.exc_info()
+ assert result is None, ("@return_future should not be used with "
+ "functions that return values")
+ if exc_info is not None:
+ # If the initial synchronous part of f() raised an exception,
+ # go ahead and raise it to the caller directly without waiting
+ # for them to inspect the Future.
+ raise_exc_info(exc_info)
return future
return wrapper
import re
import socket
-from tornado.concurrent import Future, future_wrap
+from tornado.concurrent import Future, return_future
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.testing import AsyncTestCase, LogTrapTestCase, get_unused_port
+class ReturnFutureTest(AsyncTestCase):
+ @return_future
+ def sync_future(self, callback):
+ callback(42)
+
+ @return_future
+ def async_future(self, callback):
+ self.io_loop.add_callback(callback, 42)
+
+ @return_future
+ def immediate_failure(self, callback):
+ 1 / 0
+
+ @return_future
+ def delayed_failure(self, callback):
+ self.io_loop.add_callback(lambda: 1 / 0)
+
+ def test_immediate_failure(self):
+ with self.assertRaises(ZeroDivisionError):
+ self.immediate_failure(callback=self.stop)
+
+ def test_callback_kw(self):
+ future = self.sync_future(callback=self.stop)
+ future2 = self.wait()
+ self.assertIs(future, future2)
+ self.assertEqual(future.result(), 42)
+
+ def test_callback_positional(self):
+ # When the callback is passed in positionally, future_wrap shouldn't
+ # add another callback in the kwargs.
+ future = self.sync_future(self.stop)
+ future2 = self.wait()
+ self.assertIs(future, future2)
+ self.assertEqual(future.result(), 42)
+
+ def test_no_callback(self):
+ future = self.sync_future()
+ self.assertEqual(future.result(), 42)
+
+ def test_none_callback_kw(self):
+ # explicitly pass None as callback
+ future = self.sync_future(callback=None)
+ self.assertEqual(future.result(), 42)
+
+ def test_none_callback_pos(self):
+ future = self.sync_future(None)
+ self.assertEqual(future.result(), 42)
+
+ def test_async_future(self):
+ future = self.async_future()
+ self.assertFalse(future.done())
+ self.io_loop.add_future(future, self.stop)
+ future2 = self.wait()
+ self.assertIs(future, future2)
+ self.assertEqual(future.result(), 42)
+
+ def test_delayed_failure(self):
+ future = self.delayed_failure()
+ self.io_loop.add_future(future, self.stop)
+ future2 = self.wait()
+ self.assertIs(future, future2)
+ with self.assertRaises(ZeroDivisionError):
+ future.result()
+
+ def test_kw_only_callback(self):
+ @return_future
+ def f(**kwargs):
+ kwargs['callback'](42)
+ future = f()
+ self.assertEqual(future.result(), 42)
+
+# The following series of classes demonstrate and test various styles
+# of use, with and without generators and futures.
class CapServer(TCPServer):
def handle_stream(self, stream, address):
logging.info("handle_stream")
class DecoratorCapClient(BaseCapClient):
- @future_wrap
+ @return_future
def capitalize(self, request_data, callback):
logging.info("capitalize")
self.request_data = request_data
class GeneratorCapClient(BaseCapClient):
- @future_wrap
+ @return_future
@gen.engine
def capitalize(self, request_data, callback):
logging.info('capitalize')