]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Improve exception handling for gen module.
authorBen Darnell <ben@bendarnell.com>
Mon, 5 Sep 2011 20:23:57 +0000 (13:23 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 5 Sep 2011 20:23:57 +0000 (13:23 -0700)
tornado/gen.py
tornado/test/gen_test.py

index 52a0b8b23501499fb6a58ac1b943f14fed56619e..6571e5ed81d660ee7cc230f08a0a67da95d9c7ff 100644 (file)
@@ -49,6 +49,7 @@ operations have started.
 """
 
 import functools
+import sys
 import types
 
 class KeyReuseError(Exception): pass
@@ -215,8 +216,9 @@ class Runner(object):
         self.yield_point = _NullYieldPoint()
         self.pending_callbacks = set()
         self.results = {}
-        self.waiting = None
         self.running = False
+        self.finished = False
+        self.exc_info = None
 
     def register_callback(self, key):
         """Adds ``key`` to the list of callbacks."""
@@ -244,26 +246,40 @@ class Runner(object):
         """Starts or resumes the generator, running until it reaches a
         yield point that is not ready.
         """
-        if self.running:
+        if self.running or self.finished:
             return
         try:
             self.running = True
             while True:
-                if not self.yield_point.is_ready():
-                    return
-                next = self.yield_point.get_result()
+                if self.exc_info is None:
+                    try:
+                        if not self.yield_point.is_ready():
+                            return
+                        next = self.yield_point.get_result()
+                    except Exception:
+                        self.exc_info = sys.exc_info()
                 try:
-                    yielded = self.gen.send(next)
+                    if self.exc_info is not None:
+                        exc_info = self.exc_info
+                        self.exc_info = None
+                        yielded = self.gen.throw(*exc_info)
+                    else:
+                        yielded = self.gen.send(next)
                 except StopIteration:
+                    self.finished = True
                     if self.pending_callbacks:
                         raise LeakedCallbackError(
                             "finished without waiting for callbacks %r" %
                             self.pending_callbacks)
                     return
-                if not isinstance(yielded, YieldPoint):
-                    raise BadYieldError("yielded unknown object %r" % yielded)
-                self.yield_point = yielded
-                self.yield_point.start(self)
+                except Exception:
+                    self.finished = True
+                    raise
+                if isinstance(yielded, YieldPoint):
+                    self.yield_point = yielded
+                    self.yield_point.start(self)
+                else:
+                    self.exc_info = (BadYieldError("yielded unknown object %r" % yielded),)
         finally:
             self.running = False
 
index 7e6641b146a8f359bb23805a8a01855f57f6d30a..49cf978a20ecd01c962f0b905521528cf9876fca 100644 (file)
@@ -125,6 +125,41 @@ class GenTest(AsyncTestCase):
             self.stop()
         self.run_gen(f)
 
+    def test_exception_in_yield(self):
+        @gen.engine
+        def f():
+            try:
+                yield gen.Wait("k1")
+                raise "did not get expected exception"
+            except gen.UnknownKeyError:
+                pass
+            self.stop()
+        self.run_gen(f)
+
+    def test_resume_after_exception_in_yield(self):
+        @gen.engine
+        def f():
+            try:
+                yield gen.Wait("k1")
+                raise "did not get expected exception"
+            except gen.UnknownKeyError:
+                pass
+            (yield gen.Callback("k2"))("v2")
+            self.assertEqual((yield gen.Wait("k2")), "v2")
+            self.stop()
+        self.run_gen(f)
+
+    def test_orphaned_callback(self):
+        @gen.engine
+        def f():
+            self.orphaned_callback = yield gen.Callback(1)
+        try:
+            self.run_gen(f)
+            raise "did not get expected exception"
+        except gen.LeakedCallbackError:
+            pass
+        self.orphaned_callback()
+
 
 class GenSequenceHandler(RequestHandler):
     @asynchronous