]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Use a StackContext to allow exceptions thrown from asynchronous functions
authorBen Darnell <ben@bendarnell.com>
Mon, 2 Jan 2012 03:20:05 +0000 (19:20 -0800)
committerBen Darnell <ben@bendarnell.com>
Mon, 2 Jan 2012 03:20:05 +0000 (19:20 -0800)
called by a generator to be caught normally.

Closes #405.
Closes #411.

tornado/gen.py
tornado/test/gen_test.py

index 59684f50fa9e314133d2ec0843a52a1174415e15..74c3bf9d18e3b2556f001c24daf7c25b6033c147 100644 (file)
@@ -62,12 +62,15 @@ it was called with one argument, the result is that argument.  If it was
 called with more than one argument or any keyword arguments, the result
 is an `Arguments` object, which is a named tuple ``(args, kwargs)``.
 """
+from __future__ import with_statement
 
 import functools
 import operator
 import sys
 import types
 
+from tornado.stack_context import ExceptionStackContext
+
 class KeyReuseError(Exception): pass
 class UnknownKeyError(Exception): pass
 class LeakedCallbackError(Exception): pass
@@ -86,12 +89,23 @@ def engine(func):
     """
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        gen = func(*args, **kwargs)
-        if isinstance(gen, types.GeneratorType):
-            Runner(gen).run()
-            return
-        assert gen is None, gen
-        # no yield, so we're done
+        runner = None
+        def handle_exception(typ, value, tb):
+            # if the function throws an exception before its first "yield"
+            # (or is not a generator at all), the Runner won't exist yet.
+            # However, in that case we haven't reached anything asynchronous
+            # yet, so we can just let the exception propagate.
+            if runner is not None:
+                return runner.handle_exception(typ, value, tb)
+            return False
+        with ExceptionStackContext(handle_exception):
+            gen = func(*args, **kwargs)
+            if isinstance(gen, types.GeneratorType):
+                runner = Runner(gen)
+                runner.run()
+                return
+            assert gen is None, gen
+            # no yield, so we're done
     return wrapper
 
 class YieldPoint(object):
@@ -341,6 +355,14 @@ class Runner(object):
             self.set_result(key, result)
         return inner
 
+    def handle_exception(self, typ, value, tb):
+        if not self.running and not self.finished:
+            self.exc_info = (typ, value, tb)
+            self.run()
+            return True
+        else:
+            return False
+
 # in python 2.6+ this could be a collections.namedtuple
 class Arguments(tuple):
     """The result of a yield expression whose callback had more than one
index fec486705cfce28485e03b3fa8f54167bc6bce2f..15c30ab6d71397e99a8c110729f0307280dfd493 100644 (file)
@@ -70,6 +70,19 @@ class GenTest(AsyncTestCase):
                 self.stop()
         self.run_gen(f)
 
+    def test_exception_in_task_phase2(self):
+        def fail_task(callback):
+            self.io_loop.add_callback(lambda: 1/0)
+
+        @gen.engine
+        def f():
+            try:
+                yield gen.Task(fail_task)
+                raise Exception("did not get expected exception")
+            except ZeroDivisionError:
+                self.stop()
+        self.run_gen(f)
+
     def test_with_arg(self):
         @gen.engine
         def f():