def wrapped(*args, **kwargs):
callback, contexts, args = args[0], args[1], args[2:]
- if contexts is _state.contexts or not contexts:
+ if contexts is _state.contexts:
callback(*args, **kwargs)
return
if not _state.contexts:
callback(*args, **kwargs)
else:
callback(*args, **kwargs)
- if _state.contexts:
- return _StackContextWrapper(wrapped, fn, _state.contexts)
- else:
- return _StackContextWrapper(fn)
+ return _StackContextWrapper(wrapped, fn, _state.contexts)
@contextlib.contextmanager
from __future__ import absolute_import, division, with_statement
from tornado.log import app_log
-from tornado.stack_context import StackContext, wrap
+from tornado.stack_context import StackContext, wrap, NullContext
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog
from tornado.test.util import unittest
from tornado.util import b
self.io_loop.add_callback(f1)
self.wait()
+ def test_isolation_nonempty(self):
+ # f2 and f3 are a chain of operations started in context c1.
+ # f2 is incidentally run under context c2, but that context should
+ # not be passed along to f3.
+ def f1():
+ with StackContext(functools.partial(self.context, 'c1')):
+ wrapped = wrap(f2)
+ with StackContext(functools.partial(self.context, 'c2')):
+ wrapped()
+
+ def f2():
+ self.assertIn('c1', self.active_contexts)
+ self.io_loop.add_callback(f3)
+
+ def f3():
+ self.assertIn('c1', self.active_contexts)
+ self.assertNotIn('c2', self.active_contexts)
+ self.stop()
+
+ self.io_loop.add_callback(f1)
+ self.wait()
+
+ def test_isolation_empty(self):
+ # Similar to test_isolation_nonempty, but here the f2/f3 chain
+ # is started without any context. Behavior should be equivalent
+ # to the nonempty case (although historically it was not)
+ def f1():
+ with NullContext():
+ wrapped = wrap(f2)
+ with StackContext(functools.partial(self.context, 'c2')):
+ wrapped()
+
+ def f2():
+ self.io_loop.add_callback(f3)
+
+ def f3():
+ self.assertNotIn('c2', self.active_contexts)
+ self.stop()
+
+ self.io_loop.add_callback(f1)
+ self.wait()
+
if __name__ == '__main__':
unittest.main()