]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Fix error handling with the combination of @asynchronous and @gen.coroutine.
authorBen Darnell <ben@bendarnell.com>
Sun, 24 Mar 2013 01:25:14 +0000 (21:25 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 24 Mar 2013 01:25:14 +0000 (21:25 -0400)
Also make self.finish() optional for coroutines since we can finish
if the future resolves successfully.

tornado/test/gen_test.py
tornado/web.py

index 5da8cfcff54c1b7611b04e0e5d57a4f30f780258..a9db7bcb612e3ba01d909f2249262d4570fb2962 100644 (file)
@@ -647,6 +647,41 @@ class GenSequenceHandler(RequestHandler):
         self.finish("3")
 
 
+class GenCoroutineSequenceHandler(RequestHandler):
+    @asynchronous
+    @gen.coroutine
+    def get(self):
+        self.io_loop = self.request.connection.stream.io_loop
+        self.io_loop.add_callback((yield gen.Callback("k1")))
+        yield gen.Wait("k1")
+        self.write("1")
+        self.io_loop.add_callback((yield gen.Callback("k2")))
+        yield gen.Wait("k2")
+        self.write("2")
+        # reuse an old key
+        self.io_loop.add_callback((yield gen.Callback("k1")))
+        yield gen.Wait("k1")
+        self.finish("3")
+
+
+class GenCoroutineUnfinishedSequenceHandler(RequestHandler):
+    @asynchronous
+    @gen.coroutine
+    def get(self):
+        self.io_loop = self.request.connection.stream.io_loop
+        self.io_loop.add_callback((yield gen.Callback("k1")))
+        yield gen.Wait("k1")
+        self.write("1")
+        self.io_loop.add_callback((yield gen.Callback("k2")))
+        yield gen.Wait("k2")
+        self.write("2")
+        # reuse an old key
+        self.io_loop.add_callback((yield gen.Callback("k1")))
+        yield gen.Wait("k1")
+        # just write, don't finish
+        self.write("3")
+
+
 class GenTaskHandler(RequestHandler):
     @asynchronous
     @gen.engine
@@ -668,6 +703,16 @@ class GenExceptionHandler(RequestHandler):
         raise Exception("oops")
 
 
+class GenCoroutineExceptionHandler(RequestHandler):
+    @asynchronous
+    @gen.coroutine
+    def get(self):
+        # This test depends on the order of the two decorators.
+        io_loop = self.request.connection.stream.io_loop
+        yield gen.Task(io_loop.add_callback)
+        raise Exception("oops")
+
+
 class GenYieldExceptionHandler(RequestHandler):
     @asynchronous
     @gen.engine
@@ -688,8 +733,12 @@ class GenWebTest(AsyncHTTPTestCase):
     def get_app(self):
         return Application([
             ('/sequence', GenSequenceHandler),
+            ('/coroutine_sequence', GenCoroutineSequenceHandler),
+            ('/coroutine_unfinished_sequence',
+             GenCoroutineUnfinishedSequenceHandler),
             ('/task', GenTaskHandler),
             ('/exception', GenExceptionHandler),
+            ('/coroutine_exception', GenCoroutineExceptionHandler),
             ('/yield_exception', GenYieldExceptionHandler),
         ])
 
@@ -697,6 +746,14 @@ class GenWebTest(AsyncHTTPTestCase):
         response = self.fetch('/sequence')
         self.assertEqual(response.body, b"123")
 
+    def test_coroutine_sequence_handler(self):
+        response = self.fetch('/coroutine_sequence')
+        self.assertEqual(response.body, b"123")
+
+    def test_coroutine_unfinished_sequence_handler(self):
+        response = self.fetch('/coroutine_unfinished_sequence')
+        self.assertEqual(response.body, b"123")
+
     def test_task_handler(self):
         response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence')))
         self.assertEqual(response.body, b"got response: 123")
@@ -707,6 +764,12 @@ class GenWebTest(AsyncHTTPTestCase):
             response = self.fetch('/exception')
         self.assertEqual(500, response.code)
 
+    def test_coroutine_exception_handler(self):
+        # Make sure we get an error and not a timeout
+        with ExpectLog(app_log, "Uncaught exception GET /coroutine_exception"):
+            response = self.fetch('/coroutine_exception')
+        self.assertEqual(500, response.code)
+
     def test_yield_exception_handler(self):
         response = self.fetch('/yield_exception')
         self.assertEqual(response.body, b'ok')
index fa07094d45263e4160802eb13ff0796734dca020..3293b539e3d565567b06307d2fed8be0f93898b1 100644 (file)
@@ -73,6 +73,7 @@ import traceback
 import types
 import uuid
 
+from tornado.concurrent import Future
 from tornado import escape
 from tornado import httputil
 from tornado import locale
@@ -1165,6 +1166,8 @@ def asynchronous(method):
               self.finish()
 
     """
+    # Delay the IOLoop import because it's not available on app engine.
+    from tornado.ioloop import IOLoop
     @functools.wraps(method)
     def wrapper(self, *args, **kwargs):
         if self.application._wsgi:
@@ -1172,7 +1175,20 @@ def asynchronous(method):
         self._auto_finish = False
         with stack_context.ExceptionStackContext(
                 self._stack_context_handle_exception):
-            return method(self, *args, **kwargs)
+            result = method(self, *args, **kwargs)
+            if isinstance(result, Future):
+                # If @asynchronous is used with @gen.coroutine, (but
+                # not @gen.engine), we can automatically finish the
+                # request when the future resolves.  Additionally,
+                # the Future will swallow any exceptions so we need
+                # to throw them back out to the stack context to finish
+                # the request.
+                def future_complete(f):
+                    f.result()
+                    if not self._finished:
+                        self.finish()
+                IOLoop.current().add_future(result, future_complete)
+            return result
     return wrapper