]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add @gen.coroutine and gen.Return for future-based (and py33-style) coroutines.
authorBen Darnell <ben@bendarnell.com>
Sat, 2 Mar 2013 18:29:48 +0000 (13:29 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 2 Mar 2013 18:29:48 +0000 (13:29 -0500)
tornado/gen.py
tornado/test/gen_test.py
tornado/testing.py

index 8d9d8935cbe99e974c03bfcfc60b80f868d8abfa..6192fe498344f343da0e0fc9d2c0346a1d73e89f 100644 (file)
@@ -91,6 +91,10 @@ class BadYieldError(Exception):
     pass
 
 
+class ReturnValueIgnoredError(Exception):
+    pass
+
+
 def engine(func):
     """Decorator for asynchronous generators.
 
@@ -117,17 +121,93 @@ def engine(func):
                 return runner.handle_exception(typ, value, tb)
             return False
         with ExceptionStackContext(handle_exception) as deactivate:
-            gen = func(*args, **kwargs)
-            if isinstance(gen, types.GeneratorType):
-                runner = Runner(gen, deactivate)
-                runner.run()
-                return
-            assert gen is None, gen
+            try:
+                result = func(*args, **kwargs)
+            except (Return, StopIteration) as e:
+                result = getattr(e, 'value', None)
+            else:
+                if isinstance(result, types.GeneratorType):
+                    def final_callback(value):
+                        if value is not None:
+                            raise ReturnValueIgnoredError(
+                                "@gen.engine functions cannot return values: "
+                                "%r" % result)
+                        assert value is None
+                        deactivate()
+                    runner = Runner(result, final_callback)
+                    runner.run()
+                    return
+            if result is not None:
+                raise ReturnValueIgnoredError(
+                    "@gen.engine functions cannot return values: %r" % result)
             deactivate()
             # no yield, so we're done
     return wrapper
 
 
+def coroutine(func):
+    """Future-oriented decorator for asynchronous generators.
+
+    Similar to ``@gen.engine``, but the decorated function does not receive
+    a ``callback`` parameter.  Instead, it may "return" by raising the
+    special exception `gen.Return(value)`.  In Python 3.3+, it is also
+    possible for the function to simply use the ``return`` statement.
+    (prior to Python 3.3 generators were not allowed to also return values.
+
+    Functions with this decorator return a `Future`.  Additionally,
+    they may be called with a ``callback`` keyword argument, which will
+    be invoked with the future when it resolves.
+
+    From the caller's perspective, ``@gen.coroutine`` is similar to
+    the combination of ``@return_future`` and ``@gen.engine``.
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        runner = None
+        future = Future()
+
+        if 'callback' in kwargs:
+            IOLoop.current().add_future(future, kwargs.pop('callback'))
+
+        def handle_exception(typ, value, tb):
+            try:
+                if runner is not None and runner.handle_exception(typ, value, tb):
+                    return True
+            except Exception as e:
+                # can't just say "Exception as value" - exceptions are cleared
+                # from local namespace after except clause finishes.
+                value = e
+            future.set_exception(value)
+            return True
+        with ExceptionStackContext(handle_exception) as deactivate:
+            try:
+                result = func(*args, **kwargs)
+            except (Return, StopIteration) as e:
+                result = getattr(e, 'value', None)
+            except Exception as e:
+                deactivate()
+                future.set_exception(e)
+                return future
+            else:
+                if isinstance(result, types.GeneratorType):
+                    def final_callback(value):
+                        deactivate()
+                        future.set_result(value)
+                    runner = Runner(result, final_callback)
+                    runner.run()
+                    return future
+            deactivate()
+            future.set_result(result)
+        return future
+    return wrapper
+
+
+class Return(Exception):
+    def __init__(self, value=None):
+        super(Return, self).__init__()
+        self.value = value
+
+
 class YieldPoint(object):
     """Base class for objects that may be yielded from the generator."""
     def start(self, runner):
@@ -374,7 +454,7 @@ class Runner(object):
                         yielded = self.gen.throw(*exc_info)
                     else:
                         yielded = self.gen.send(next)
-                except StopIteration:
+                except (StopIteration, Return) as e:
                     self.finished = True
                     if self.pending_callbacks and not self.had_exception:
                         # If we ran cleanly without waiting on all callbacks
@@ -384,7 +464,7 @@ class Runner(object):
                         raise LeakedCallbackError(
                             "finished without waiting for callbacks %r" %
                             self.pending_callbacks)
-                    self.final_callback()
+                    self.final_callback(getattr(e, 'value', None))
                     self.final_callback = None
                     return
                 except Exception:
index 7826bb6a8205f6bc58a70afc19579e361ee07bd7..687fd6758641b97f441c1afeba1bb64e968fe348 100644 (file)
@@ -1,6 +1,8 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
 import functools
+import sys
+import textwrap
 import time
 
 from tornado.concurrent import return_future
@@ -8,16 +10,19 @@ from tornado.escape import url_escape
 from tornado.httpclient import AsyncHTTPClient
 from tornado.log import app_log
 from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
+from tornado.test.util import unittest
 from tornado.web import Application, RequestHandler, asynchronous
 
 from tornado import gen
 
 
+skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
 
-class GenTest(AsyncTestCase):
+
+class GenEngineTest(AsyncTestCase):
     def run_gen(self, f):
         f()
-        self.wait()
+        return self.wait()
 
     def delay_callback(self, iterations, callback, arg):
         """Runs callback(arg) after a number of IOLoop iterations."""
@@ -320,6 +325,274 @@ class GenTest(AsyncTestCase):
         initial_stack_depth = len(stack_context._state.contexts)
         self.run_gen(outer)
 
+    def test_raise_after_stop(self):
+        # This pattern will be used in the following tests so make sure
+        # the exception propagates as expected.
+        @gen.engine
+        def f():
+            self.stop()
+            1 / 0
+
+        with self.assertRaises(ZeroDivisionError):
+            self.run_gen(f)
+
+    def test_sync_raise_return(self):
+        # gen.Return is allowed in @gen.engine, but it may not be used
+        # to return a value.
+        @gen.engine
+        def f():
+            self.stop(42)
+            raise gen.Return()
+
+        result = self.run_gen(f)
+        self.assertEqual(result, 42)
+
+    def test_async_raise_return(self):
+        @gen.engine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            self.stop(42)
+            raise gen.Return()
+
+        result = self.run_gen(f)
+        self.assertEqual(result, 42)
+
+    def test_sync_raise_return_value(self):
+        @gen.engine
+        def f():
+            raise gen.Return(42)
+
+        with self.assertRaises(gen.ReturnValueIgnoredError):
+            self.run_gen(f)
+
+    def test_async_raise_return_value(self):
+        @gen.engine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            raise gen.Return(42)
+
+        with self.assertRaises(gen.ReturnValueIgnoredError):
+            self.run_gen(f)
+
+    def test_return_value(self):
+        # It is an error to apply @gen.engine to a function that returns
+        # a value.
+        @gen.engine
+        def f():
+            return 42
+
+        with self.assertRaises(gen.ReturnValueIgnoredError):
+            self.run_gen(f)
+
+
+class GenCoroutineTest(AsyncTestCase):
+    def setUp(self):
+        # Stray StopIteration exceptions can lead to tests exiting prematurely,
+        # so we need explicit checks here to make sure the tests run all
+        # the way through.
+        self.finished = False
+        super(GenCoroutineTest, self).setUp()
+
+    def tearDown(self):
+        super(GenCoroutineTest, self).tearDown()
+        assert self.finished
+
+    @gen_test
+    def test_sync_gen_return(self):
+        @gen.coroutine
+        def f():
+            raise gen.Return(42)
+        result = yield f()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+    @gen_test
+    def test_async_gen_return(self):
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            raise gen.Return(42)
+        result = yield f()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+    @gen_test
+    def test_sync_return(self):
+        @gen.coroutine
+        def f():
+            return 42
+        result = yield f()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+    @skipBefore33
+    @gen_test
+    def test_async_return(self):
+        # It is a compile-time error to return a value in a generator
+        # before Python 3.3, so we must test this with exec.
+        # Flatten the real global and local namespace into our fake globals:
+        # it's all global from the perspective of f().
+        global_namespace = dict(globals(), **locals())
+        local_namespace = {}
+        exec(textwrap.dedent("""
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            return 42
+        """), global_namespace, local_namespace)
+        result = yield local_namespace['f']()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+    @skipBefore33
+    @gen_test
+    def test_async_early_return(self):
+        # A yield statement exists but is not executed, which means
+        # this function "returns" via an exception.  This exception
+        # doesn't happen before the exception handling is set up.
+        global_namespace = dict(globals(), **locals())
+        local_namespace = {}
+        exec(textwrap.dedent("""
+        @gen.coroutine
+        def f():
+            if True:
+                return 42
+            yield gen.Task(self.io_loop.add_callback)
+        """), global_namespace, local_namespace)
+        result = yield local_namespace['f']()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+    @gen_test
+    def test_sync_return_no_value(self):
+        @gen.coroutine
+        def f():
+            return
+        result = yield f()
+        self.assertEqual(result, None)
+        self.finished = True
+
+    @gen_test
+    def test_async_return_no_value(self):
+        # Without a return value we don't need python 3.3.
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            return
+        result = yield f()
+        self.assertEqual(result, None)
+        self.finished = True
+
+    @gen_test
+    def test_sync_raise(self):
+        @gen.coroutine
+        def f():
+            1 / 0
+        # The exception is raised when the future is yielded
+        # (or equivalently when its result method is called),
+        # not when the function itself is called).
+        future = f()
+        with self.assertRaises(ZeroDivisionError):
+            yield future
+        self.finished = True
+
+    @gen_test
+    def test_async_raise(self):
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            1 / 0
+        future = f()
+        with self.assertRaises(ZeroDivisionError):
+            yield future
+        self.finished = True
+
+    @gen_test
+    def test_pass_callback(self):
+        @gen.coroutine
+        def f():
+            raise gen.Return(42)
+        # The callback version passes a future to the callback without
+        # resolving it so exception information is available to the caller.
+        future = yield gen.Task(f)
+        self.assertEqual(future.result(), 42)
+        self.finished = True
+
+    @gen_test
+    def test_replace_yieldpoint_exception(self):
+        # Test exception handling: a coroutine can catch one exception
+        # raised by a yield point and raise a different one.
+        @gen.coroutine
+        def f1():
+            1 / 0
+
+        @gen.coroutine
+        def f2():
+            try:
+                yield f1()
+            except ZeroDivisionError:
+                raise KeyError()
+
+        future = f2()
+        with self.assertRaises(KeyError):
+            yield future
+        self.finished = True
+
+    @gen_test
+    def test_swallow_yieldpoint_exception(self):
+        # Test exception handling: a coroutine can catch an exception
+        # raised by a yield point and not raise a different one.
+        @gen.coroutine
+        def f1():
+            1 / 0
+
+        @gen.coroutine
+        def f2():
+            try:
+                yield f1()
+            except ZeroDivisionError:
+                raise gen.Return(42)
+
+        result = yield f2()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+    @gen_test
+    def test_replace_context_exception(self):
+        # Test exception handling: exceptions thrown into the stack context
+        # can be caught and replaced.
+        @gen.coroutine
+        def f2():
+            self.io_loop.add_callback(lambda: 1/ 0)
+            try:
+                yield gen.Task(self.io_loop.add_timeout,
+                               self.io_loop.time() + 10)
+            except ZeroDivisionError:
+                raise KeyError()
+
+        future = f2()
+        with self.assertRaises(KeyError):
+            yield future
+        self.finished = True
+
+    @gen_test
+    def test_swallow_context_exception(self):
+        # Test exception handling: exceptions thrown into the stack context
+        # can be caught and ignored.
+        @gen.coroutine
+        def f2():
+            self.io_loop.add_callback(lambda: 1/ 0)
+            try:
+                yield gen.Task(self.io_loop.add_timeout,
+                               self.io_loop.time() + 10)
+            except ZeroDivisionError:
+                raise gen.Return(42)
+
+        result = yield f2()
+        self.assertEqual(result, 42)
+        self.finished = True
+
+
 
 class GenSequenceHandler(RequestHandler):
     @asynchronous
index 8e635831c8ba727a33ad36aba0811f98f1916880..51716657f03c129676d8f9142b4f8987f116c7f8 100644 (file)
@@ -388,7 +388,7 @@ def gen_test(f):
         if result is None:
             return
         assert isinstance(result, types.GeneratorType)
-        runner = gen.Runner(result, self.stop)
+        runner = gen.Runner(result, lambda value: self.stop())
         runner.run()
         self.wait()
     return wrapper