]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
stack_context deactivation support
authorSerge S. Koval <serge.koval+github@gmail.com>
Mon, 20 May 2013 11:34:19 +0000 (14:34 +0300)
committerSerge S. Koval <serge.koval+github@gmail.com>
Mon, 20 May 2013 11:34:19 +0000 (14:34 +0300)
tornado/gen.py
tornado/stack_context.py
tornado/test/gen_test.py
tornado/test/stack_context_test.py

index 591ea713bf19ced7fb6ec4f869bf6d9730923cf8..62ee0a30fd67b1ece9c241eccce2b1a641e6f08e 100644 (file)
@@ -136,7 +136,7 @@ def engine(func):
             if runner is not None:
                 return runner.handle_exception(typ, value, tb)
             return False
-        with ExceptionStackContext(handle_exception):
+        with ExceptionStackContext(handle_exception) as deactivate:
             try:
                 result = func(*args, **kwargs)
             except (Return, StopIteration) as e:
@@ -149,6 +149,7 @@ def engine(func):
                                 "@gen.engine functions cannot return values: "
                                 "%r" % (value,))
                         assert value is None
+                        deactivate()
                     runner = Runner(result, final_callback)
                     runner.run()
                     return
@@ -156,6 +157,7 @@ def engine(func):
                 raise ReturnValueIgnoredError(
                     "@gen.engine functions cannot return values: %r" %
                     (result,))
+            deactivate()
             # no yield, so we're done
     return wrapper
 
@@ -208,21 +210,24 @@ def coroutine(func):
                 typ, value, tb = sys.exc_info()
             future.set_exc_info((typ, value, tb))
             return True
-        with ExceptionStackContext(handle_exception):
+        with ExceptionStackContext(handle_exception) as deactivate:
             try:
                 result = func(*args, **kwargs)
             except (Return, StopIteration) as e:
                 result = getattr(e, 'value', None)
             except Exception:
+                deactivate()
                 future.set_exc_info(sys.exc_info())
                 return future
             else:
                 if isinstance(result, types.GeneratorType):
                     def final_callback(value):
+                        deactivate()
                         future.set_result(value)
                     runner = Runner(result, final_callback)
                     runner.run()
                     return future
+            deactivate()
             future.set_result(result)
         return future
     return wrapper
index 642db1b35fd6e38430880f625f5cc5bd978cda1e..14478c1d9423cb83005dc733e4454540c4a7bc6e 100644 (file)
@@ -71,6 +71,9 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import sys
 import threading
+import operator
+
+from collections import deque
 
 from tornado.util import raise_exc_info
 
@@ -105,9 +108,13 @@ class StackContext(object):
     context that are currently pending).  This is an advanced feature
     and not necessary in most applications.
     """
-    def __init__(self, context_factory):
+    def __init__(self, context_factory, _active=True):
         self.context_factory = context_factory
         self.contexts = []
+        self.active = _active
+
+    def _deactivate(self):
+        self.active = False
 
     # StackContext protocol
     def enter(self):
@@ -133,6 +140,8 @@ class StackContext(object):
             _state.contexts = self.old_contexts
             raise
 
+        return self._deactivate
+
     def __exit__(self, type, value, traceback):
         try:
             self.exit(type, value, traceback)
@@ -168,8 +177,12 @@ class ExceptionStackContext(object):
     If the exception handler returns true, the exception will be
     consumed and will not be propagated to other exception handlers.
     """
-    def __init__(self, exception_handler):
+    def __init__(self, exception_handler, _active=True):
         self.exception_handler = exception_handler
+        self.active = _active
+
+    def _deactivate(self):
+        self.active = False
 
     def exit(self, type, value, traceback):
         if type is not None:
@@ -180,6 +193,8 @@ class ExceptionStackContext(object):
         self.new_contexts = (self.old_contexts[0], self)
         _state.contexts = self.new_contexts
 
+        return self._deactivate
+
     def __exit__(self, type, value, traceback):
         try:
             if type is not None:
@@ -212,6 +227,31 @@ class NullContext(object):
         _state.contexts = self.old_contexts
 
 
+def _remove_deactivated(contexts):
+    """Remove deactivated handlers from the chain"""
+    # Clean ctx handlers
+    stack_contexts = tuple([h for h in contexts[0] if h.active])
+
+    # Find new head
+    head = contexts[1]
+    while head is not None and not head.active:
+        head = head.old_contexts[1]
+
+    # Process chain
+    ctx = head
+    while ctx is not None:
+        parent = ctx.old_contexts[1]
+
+        while parent is not None and not parent.active:
+            parent = parent.old_contexts[1]
+
+            ctx.old_contexts = parent.old_contexts
+
+        ctx = parent
+
+    return (stack_contexts, head)
+
+
 def wrap(fn):
     """Returns a callable object that will restore the current `StackContext`
     when executed.
@@ -225,13 +265,18 @@ def wrap(fn):
         return fn
 
     # Capture current stack head
-    contexts = _state.contexts
+    # TODO: Any other better way to store contexts and update them in wrapped function?
+    cap_contexts = [_state.contexts]
 
-    #@functools.wraps
     def wrapped(*args, **kwargs):
         try:
-            # Force local state - switch to new stack chain
+            # Capture old state
             current_state = _state.contexts
+
+            # Remove deactivated items
+            cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
+
+            # Force new state
             _state.contexts = contexts
 
             # Current exception
index f51077efb85d1c0ebe082f8d747c05a81760fd56..cbfdff2d78e7a1c6eb78c0e29bb46d64bcd6a4a7 100644 (file)
@@ -346,6 +346,16 @@ class GenEngineTest(AsyncTestCase):
     def test_stack_context_leak(self):
         # regression test: repeated invocations of a gen-based
         # function should not result in accumulated stack_contexts
+        def _stack_depth():
+            head = stack_context._state.contexts[1]
+            length = 0
+
+            while head is not None:
+                length += 1
+                head = head.old_contexts[1]
+
+            return length
+
         @gen.engine
         def inner(callback):
             yield gen.Task(self.io_loop.add_callback)
@@ -355,10 +365,11 @@ class GenEngineTest(AsyncTestCase):
         def outer():
             for i in range(10):
                 yield gen.Task(inner)
-            stack_increase = len(stack_context._state.contexts) - initial_stack_depth
+
+            stack_increase = _stack_depth() - initial_stack_depth
             self.assertTrue(stack_increase <= 2)
             self.stop()
-        initial_stack_depth = len(stack_context._state.contexts)
+        initial_stack_depth = _stack_depth()
         self.run_gen(outer)
 
     def test_stack_context_leak_exception(self):
index d85ea50719866d700c37f29ceaa9eb050b3db4f5..976ef4000de4269945ef588e17d6bd5b9ba956a3 100644 (file)
@@ -3,7 +3,8 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 from tornado import gen
 from tornado.log import app_log
-from tornado.stack_context import StackContext, wrap, NullContext, StackContextInconsistentError, ExceptionStackContext, run_with_stack_context
+from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
+                                   ExceptionStackContext, run_with_stack_context, _state)
 from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest
 from tornado.web import asynchronous, Application, RequestHandler
@@ -95,6 +96,38 @@ class StackContextTest(AsyncTestCase):
             library_function(final_callback)
         self.wait()
 
+    def test_deactivate(self):
+        deactivate_callbacks = []
+
+        def f1():
+            with StackContext(functools.partial(self.context, 'c1')) as c1:
+                deactivate_callbacks.append(c1)
+                self.io_loop.add_callback(f2)
+
+        def f2():
+            with StackContext(functools.partial(self.context, 'c2')) as c2:
+                deactivate_callbacks.append(c2)
+                self.io_loop.add_callback(f3)
+
+        def f3():
+            with StackContext(functools.partial(self.context, 'c3')) as c3:
+                deactivate_callbacks.append(c3)
+                self.io_loop.add_callback(f4)
+
+        def f4():
+            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
+            deactivate_callbacks[1]()
+            # deactivating a context doesn't remove it immediately,
+            # but it will be missing from the next iteration
+            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
+            self.io_loop.add_callback(f5)
+
+        def f5():
+            self.assertEqual(self.active_contexts, ['c1', 'c3'])
+            self.stop()
+        self.io_loop.add_callback(f1)
+        self.wait()
+
     def test_isolation_nonempty(self):
         # f2 and f3 are a chain of operations started in context c1.
         # f2 is incidentally run under context c2, but that context should