]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Allow exceptions thrown in the first (synchronous) phase of a gen.Task
authorBen Darnell <ben@bendarnell.com>
Mon, 2 Jan 2012 00:17:35 +0000 (16:17 -0800)
committerBen Darnell <ben@bendarnell.com>
Mon, 2 Jan 2012 00:17:35 +0000 (16:17 -0800)
to be caught by the generator.

tornado/gen.py
tornado/test/gen_test.py

index 6f6b49aa5b1880f457d8a31856c3424f130e1bca..59684f50fa9e314133d2ec0843a52a1174415e15 100644 (file)
@@ -255,6 +255,7 @@ class Runner(object):
         self.running = False
         self.finished = False
         self.exc_info = None
+        self.had_exception = False
 
     def register_callback(self, key):
         """Adds ``key`` to the list of callbacks."""
@@ -296,6 +297,7 @@ class Runner(object):
                         self.exc_info = sys.exc_info()
                 try:
                     if self.exc_info is not None:
+                        self.had_exception = True
                         exc_info = self.exc_info
                         self.exc_info = None
                         yielded = self.gen.throw(*exc_info)
@@ -303,7 +305,11 @@ class Runner(object):
                         yielded = self.gen.send(next)
                 except StopIteration:
                     self.finished = True
-                    if self.pending_callbacks:
+                    if self.pending_callbacks and not self.had_exception:
+                        # If we ran cleanly without waiting on all callbacks
+                        # raise an error (really more of a warning).  If we
+                        # had an exception then some callbacks may have been
+                        # orphaned, so skip the check in that case.
                         raise LeakedCallbackError(
                             "finished without waiting for callbacks %r" %
                             self.pending_callbacks)
@@ -315,7 +321,10 @@ class Runner(object):
                     yielded = Multi(yielded)
                 if isinstance(yielded, YieldPoint):
                     self.yield_point = yielded
-                    self.yield_point.start(self)
+                    try:
+                        self.yield_point.start(self)
+                    except Exception:
+                        self.exc_info = sys.exc_info()
                 else:
                     self.exc_info = (BadYieldError("yielded unknown object %r" % yielded),)
         finally:
index 4c6a4d591b00c2feb90800055dbda7dcd40ce241..fec486705cfce28485e03b3fa8f54167bc6bce2f 100644 (file)
@@ -57,6 +57,19 @@ class GenTest(AsyncTestCase):
             1/0
         self.assertRaises(ZeroDivisionError, self.run_gen, f)
 
+    def test_exception_in_task_phase1(self):
+        def fail_task(callback):
+            1/0
+
+        @gen.engine
+        def f():
+            try:
+                yield gen.Task(fail_task)
+                raise Exception("did not get expected exception")
+            except ZeroDivisionError:
+                self.stop()
+        self.run_gen(f)
+
     def test_with_arg(self):
         @gen.engine
         def f():