]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Rename future_wrap to return_future.
authorBen Darnell <ben@bendarnell.com>
Sun, 27 Jan 2013 22:03:15 +0000 (17:03 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 27 Jan 2013 22:03:15 +0000 (17:03 -0500)
Allow exceptions from the initial synchronous phase to pass through
to the caller.  Add docs and tests.

tornado/concurrent.py
tornado/test/concurrent_test.py

index e822d535392a1438d1f16ab35f3bcf90cdaf6369..e9057e84ab79fd88f55c54162af453975f9f274c 100644 (file)
 from __future__ import absolute_import, division, print_function, with_statement
 
 import functools
+import inspect
+import sys
 
 from tornado.stack_context import ExceptionStackContext
+from tornado.util import raise_exc_info
 
 try:
     from concurrent import futures
@@ -110,21 +113,71 @@ def run_on_executor(fn):
         return future
     return wrapper
 
-# TODO: this needs a better name
 
-
-def future_wrap(f):
+def return_future(f):
+    """Decorator to make a function that returns via callback return a `Future`.
+
+    The wrapped function should take a ``callback`` keyword argument
+    and invoke it with one argument when it has finished.  To signal failure,
+    the function can simply raise an exception (which will be
+    captured by the `stack_context` and passed along to the `Future`).
+
+    From the caller's perspective, the callback argument is optional.
+    If one is given, it will be invoked when the function is complete
+    with the `Future` as an argument.  If no callback is given, the caller
+    should use the `Future` to wait for the function to complete
+    (perhaps by yielding it in a `gen.engine` function, or passing it
+    to `IOLoop.add_future`).
+
+    Usage::
+        @return_future
+        def future_func(arg1, arg2, callback):
+            # Do stuff (possibly asynchronous)
+            callback(result)
+
+        @gen.engine
+        def caller(callback):
+            yield future_func(arg1, arg2)
+            callback()
+
+    Note that ``@return_future`` and ``@gen.engine`` can be applied to the
+    same function, provided ``@return_future`` appears first.
+    """
+    try:
+        callback_pos = inspect.getargspec(f).args.index('callback')
+    except ValueError:
+        # Callback is not accepted as a positional parameter
+        callback_pos = None
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         future = Future()
-        if kwargs.get('callback') is not None:
-            future.add_done_callback(kwargs.pop('callback'))
-        kwargs['callback'] = future.set_result
+        if callback_pos is not None and len(args) > callback_pos:
+            # The callback argument is being passed positionally
+            if args[callback_pos] is not None:
+                future.add_done_callback(args[callback_pos])
+            args = list(args)  # *args is normally a tuple
+            args[callback_pos] = future.set_result
+        else:
+            # The callback argument is either omitted or passed by keyword.
+            if kwargs.get('callback') is not None:
+                future.add_done_callback(kwargs.pop('callback'))
+            kwargs['callback'] = future.set_result
 
         def handle_error(typ, value, tb):
             future.set_exception(value)
             return True
+        exc_info = None
         with ExceptionStackContext(handle_error):
-            f(*args, **kwargs)
+            try:
+                result = f(*args, **kwargs)
+            except:
+                exc_info = sys.exc_info()
+            assert result is None, ("@return_future should not be used with "
+                                    "functions that return values")
+        if exc_info is not None:
+            # If the initial synchronous part of f() raised an exception,
+            # go ahead and raise it to the caller directly without waiting
+            # for them to inspect the Future.
+            raise_exc_info(exc_info)
         return future
     return wrapper
index 994150cb3d5f78c10c1299f56fc0a6d391ed03bc..c0666065746c406abdf6d8007c6536d11d8dacf3 100644 (file)
@@ -19,7 +19,7 @@ import logging
 import re
 import socket
 
-from tornado.concurrent import Future, future_wrap
+from tornado.concurrent import Future, return_future
 from tornado.escape import utf8, to_unicode
 from tornado import gen
 from tornado.iostream import IOStream
@@ -27,6 +27,79 @@ from tornado.tcpserver import TCPServer
 from tornado.testing import AsyncTestCase, LogTrapTestCase, get_unused_port
 
 
+class ReturnFutureTest(AsyncTestCase):
+    @return_future
+    def sync_future(self, callback):
+        callback(42)
+
+    @return_future
+    def async_future(self, callback):
+        self.io_loop.add_callback(callback, 42)
+
+    @return_future
+    def immediate_failure(self, callback):
+        1 / 0
+
+    @return_future
+    def delayed_failure(self, callback):
+        self.io_loop.add_callback(lambda: 1 / 0)
+
+    def test_immediate_failure(self):
+        with self.assertRaises(ZeroDivisionError):
+            self.immediate_failure(callback=self.stop)
+
+    def test_callback_kw(self):
+        future = self.sync_future(callback=self.stop)
+        future2 = self.wait()
+        self.assertIs(future, future2)
+        self.assertEqual(future.result(), 42)
+
+    def test_callback_positional(self):
+        # When the callback is passed in positionally, future_wrap shouldn't
+        # add another callback in the kwargs.
+        future = self.sync_future(self.stop)
+        future2 = self.wait()
+        self.assertIs(future, future2)
+        self.assertEqual(future.result(), 42)
+
+    def test_no_callback(self):
+        future = self.sync_future()
+        self.assertEqual(future.result(), 42)
+
+    def test_none_callback_kw(self):
+        # explicitly pass None as callback
+        future = self.sync_future(callback=None)
+        self.assertEqual(future.result(), 42)
+
+    def test_none_callback_pos(self):
+        future = self.sync_future(None)
+        self.assertEqual(future.result(), 42)
+
+    def test_async_future(self):
+        future = self.async_future()
+        self.assertFalse(future.done())
+        self.io_loop.add_future(future, self.stop)
+        future2 = self.wait()
+        self.assertIs(future, future2)
+        self.assertEqual(future.result(), 42)
+
+    def test_delayed_failure(self):
+        future = self.delayed_failure()
+        self.io_loop.add_future(future, self.stop)
+        future2 = self.wait()
+        self.assertIs(future, future2)
+        with self.assertRaises(ZeroDivisionError):
+            future.result()
+
+    def test_kw_only_callback(self):
+        @return_future
+        def f(**kwargs):
+            kwargs['callback'](42)
+        future = f()
+        self.assertEqual(future.result(), 42)
+
+# The following series of classes demonstrate and test various styles
+# of use, with and without generators and futures.
 class CapServer(TCPServer):
     def handle_stream(self, stream, address):
         logging.info("handle_stream")
@@ -88,7 +161,7 @@ class ManualCapClient(BaseCapClient):
 
 
 class DecoratorCapClient(BaseCapClient):
-    @future_wrap
+    @return_future
     def capitalize(self, request_data, callback):
         logging.info("capitalize")
         self.request_data = request_data
@@ -109,7 +182,7 @@ class DecoratorCapClient(BaseCapClient):
 
 
 class GeneratorCapClient(BaseCapClient):
-    @future_wrap
+    @return_future
     @gen.engine
     def capitalize(self, request_data, callback):
         logging.info('capitalize')