]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Capture the stack context of gen callbacks as they are generated.
authorBen Darnell <ben@bendarnell.com>
Fri, 8 Mar 2013 21:18:22 +0000 (16:18 -0500)
committerBen Darnell <ben@bendarnell.com>
Fri, 8 Mar 2013 21:18:22 +0000 (16:18 -0500)
This guards against functions that add their own stack context
without wrapping their callbacks.

tornado/gen.py
tornado/test/gen_test.py

index ec5648ae1d032518fbb388966b89632506f5e9c1..b819f40750628943472d258a1af1de2f1f11840c 100644 (file)
@@ -72,7 +72,7 @@ import types
 
 from tornado.concurrent import Future, TracebackFuture
 from tornado.ioloop import IOLoop
-from tornado.stack_context import ExceptionStackContext
+from tornado.stack_context import ExceptionStackContext, wrap
 
 
 class KeyReuseError(Exception):
@@ -496,7 +496,7 @@ class Runner(object):
             else:
                 result = None
             self.set_result(key, result)
-        return inner
+        return wrap(inner)
 
     def handle_exception(self, typ, value, tb):
         if not self.running and not self.finished:
index bfe92d7395a13d6c8a0fcda0da86ab196be52a42..7c0fecb12c8d8bce7eaae1b3b0e70bbb45230ffd 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
+import contextlib
 import functools
 import sys
 import textwrap
@@ -9,6 +10,7 @@ from tornado.concurrent import return_future
 from tornado.escape import url_escape
 from tornado.httpclient import AsyncHTTPClient
 from tornado.log import app_log
+from tornado import stack_context
 from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest
 from tornado.web import Application, RequestHandler, asynchronous
@@ -20,6 +22,20 @@ skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available
 
 
 class GenEngineTest(AsyncTestCase):
+    def setUp(self):
+        super(GenEngineTest, self).setUp()
+        self.named_contexts = []
+
+    def named_context(self, name):
+        @contextlib.contextmanager
+        def context():
+            self.named_contexts.append(name)
+            try:
+                yield
+            finally:
+                self.assertEqual(self.named_contexts.pop(), name)
+        return context
+
     def run_gen(self, f):
         f()
         return self.wait()
@@ -286,8 +302,6 @@ class GenEngineTest(AsyncTestCase):
     def test_stack_context_leak(self):
         # regression test: repeated invocations of a gen-based
         # function should not result in accumulated stack_contexts
-        from tornado import stack_context
-
         @gen.engine
         def inner(callback):
             yield gen.Task(self.io_loop.add_callback)
@@ -305,8 +319,6 @@ class GenEngineTest(AsyncTestCase):
 
     def test_stack_context_leak_exception(self):
         # same as previous, but with a function that exits with an exception
-        from tornado import stack_context
-
         @gen.engine
         def inner(callback):
             yield gen.Task(self.io_loop.add_callback)
@@ -325,6 +337,32 @@ class GenEngineTest(AsyncTestCase):
         initial_stack_depth = len(stack_context._state.contexts)
         self.run_gen(outer)
 
+    def function_with_stack_context(self, callback):
+        # Technically this function should stack_context.wrap its callback
+        # upon entry.  However, it is very common for this step to be
+        # omitted.
+        def step2():
+            self.assertEqual(self.named_contexts, ['a'])
+            self.io_loop.add_callback(callback)
+
+        with stack_context.StackContext(self.named_context('a')):
+            self.io_loop.add_callback(step2)
+
+    @gen_test
+    def test_wait_transfer_stack_context(self):
+        # Wait should not pick up contexts from where callback was invoked,
+        # even if that function improperly fails to wrap its callback.
+        cb = yield gen.Callback('k1')
+        self.function_with_stack_context(cb)
+        self.assertEqual(self.named_contexts, [])
+        yield gen.Wait('k1')
+        self.assertEqual(self.named_contexts, [])
+
+    @gen_test
+    def test_task_transfer_stack_context(self):
+        yield gen.Task(self.function_with_stack_context)
+        self.assertEqual(self.named_contexts, [])
+
     def test_raise_after_stop(self):
         # This pattern will be used in the following tests so make sure
         # the exception propagates as expected.