]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add an assertion for inconsistent StackContexts when used with generators.
authorBen Darnell <ben@bendarnell.com>
Sat, 16 Feb 2013 00:45:53 +0000 (19:45 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 16 Feb 2013 00:45:53 +0000 (19:45 -0500)
tornado/stack_context.py
tornado/test/stack_context_test.py

index 832fe99f5280ec12705770682587ec621290036b..2a055d09d95fee096362d99cd24e0afdc69151ed 100644 (file)
@@ -77,6 +77,10 @@ import threading
 from tornado.util import raise_exc_info
 
 
+class StackContextInconsistentError(Exception):
+    pass
+
+
 class _State(threading.local):
     def __init__(self):
         self.contexts = ()
@@ -113,8 +117,10 @@ class StackContext(object):
     def __enter__(self):
         self.old_contexts = _state.contexts
         # _state.contexts is a tuple of (class, arg, active_cell) tuples
-        _state.contexts = (self.old_contexts +
-                           ((StackContext, self.context_factory, self.active_cell),))
+        self.new_contexts = (self.old_contexts +
+                             ((StackContext, self.context_factory,
+                               self.active_cell),))
+        _state.contexts = self.new_contexts
         try:
             self.context = self.context_factory()
             self.context.__enter__()
@@ -127,7 +133,19 @@ class StackContext(object):
         try:
             return self.context.__exit__(type, value, traceback)
         finally:
+            final_contexts = _state.contexts
             _state.contexts = self.old_contexts
+            # Generator coroutines and with-statements with non-local
+            # effects interact badly.  Check here for signs of
+            # the stack getting out of sync.
+            # Note that this check comes after restoring _state.context
+            # so that if it fails things are left in a (relatively)
+            # consistent state.
+            if final_contexts is not self.new_contexts:
+                raise StackContextInconsistentError(
+                    'stack_context inconsistency (may be caused by yield '
+                    'within a "with StackContext" block)')
+            self.old_contexts = self.new_contexts = None
 
 
 class ExceptionStackContext(object):
@@ -149,9 +167,10 @@ class ExceptionStackContext(object):
 
     def __enter__(self):
         self.old_contexts = _state.contexts
-        _state.contexts = (self.old_contexts +
-                           ((ExceptionStackContext, self.exception_handler,
-                             self.active_cell),))
+        self.new_contexts = (self.old_contexts +
+                             ((ExceptionStackContext, self.exception_handler,
+                               self.active_cell),))
+        _state.contexts = self.new_contexts
         return lambda: operator.setitem(self.active_cell, 0, False)
 
     def __exit__(self, type, value, traceback):
@@ -159,8 +178,13 @@ class ExceptionStackContext(object):
             if type is not None:
                 return self.exception_handler(type, value, traceback)
         finally:
+            final_contexts = _state.contexts
             _state.contexts = self.old_contexts
-            self.old_contexts = None
+            if final_contexts is not self.new_contexts:
+                raise StackContextInconsistentError(
+                    'stack_context inconsistency (may be caused by yield '
+                    'within a "with StackContext" block)')
+            self.old_contexts = self.new_contexts = None
 
 
 class NullContext(object):
index 7265a0590c78a01868f03df4efca71fa2cda15b5..5c4af0a150987a1efe26c46ba35c884cecb76840 100644 (file)
@@ -1,9 +1,10 @@
 #!/usr/bin/env python
 from __future__ import absolute_import, division, print_function, with_statement
 
+from tornado import gen
 from tornado.log import app_log
-from tornado.stack_context import StackContext, wrap, NullContext
-from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog
+from tornado.stack_context import StackContext, wrap, NullContext, StackContextInconsistentError, ExceptionStackContext
+from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest
 from tornado.web import asynchronous, Application, RequestHandler
 import contextlib
@@ -168,6 +169,46 @@ class StackContextTest(AsyncTestCase):
         self.io_loop.add_callback(f1)
         self.wait()
 
+    def test_yield_in_with(self):
+        @gen.engine
+        def f():
+            with StackContext(functools.partial(self.context, 'c1')):
+                # This yield is a problem: the generator will be suspended
+                # and the StackContext's __exit__ is not called yet, so
+                # the context will be left on _state.contexts for anything
+                # that runs before the yield resolves.
+                yield gen.Task(self.io_loop.add_callback)
+
+        with self.assertRaises(StackContextInconsistentError):
+            f()
+            self.wait()
+
+    @gen_test
+    def test_yield_outside_with(self):
+        # This pattern avoids the problem in the previous test.
+        cb = yield gen.Callback('k1')
+        with StackContext(functools.partial(self.context, 'c1')):
+            self.io_loop.add_callback(cb)
+        yield gen.Wait('k1')
+
+    def test_yield_in_with_exception_stack_context(self):
+        # As above, but with ExceptionStackContext instead of StackContext.
+        @gen.engine
+        def f():
+            with ExceptionStackContext(lambda t, v, tb: False):
+                yield gen.Task(self.io_loop.add_callback)
+
+        with self.assertRaises(StackContextInconsistentError):
+            f()
+            self.wait()
+
+    @gen_test
+    def test_yield_outside_with_exception_stack_context(self):
+        cb = yield gen.Callback('k1')
+        with ExceptionStackContext(lambda t, v, tb: False):
+            self.io_loop.add_callback(cb)
+        yield gen.Wait('k1')
+
 
 if __name__ == '__main__':
     unittest.main()