]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Prevent leak of StackContexts in repeated gen.engine functions.
authorBen Darnell <ben@bendarnell.com>
Mon, 21 May 2012 05:08:59 +0000 (22:08 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 21 May 2012 05:08:59 +0000 (22:08 -0700)
Internally, StackContexts now return a deactivation callback,
which can be used to prevent that StackContext from propagating
further.  This is used in gen.engine because the decorator doesn't know
which arguments are callbacks that need to be wrapped outside of its
ExceptionStackContext.  This is deliberately undocumented for now.

Closes #507.

tornado/gen.py
tornado/stack_context.py
tornado/test/gen_test.py
tornado/test/stack_context_test.py

index 752c3f24913ebb4d7ab48b313e03ec1db6dd9ca8..506697d7b9a426c81bda2cb6456fef40947760ca 100644 (file)
@@ -113,13 +113,14 @@ def engine(func):
             if runner is not None:
                 return runner.handle_exception(typ, value, tb)
             return False
-        with ExceptionStackContext(handle_exception):
+        with ExceptionStackContext(handle_exception) as deactivate:
             gen = func(*args, **kwargs)
             if isinstance(gen, types.GeneratorType):
-                runner = Runner(gen)
+                runner = Runner(gen, deactivate)
                 runner.run()
                 return
             assert gen is None, gen
+            deactivate()
             # no yield, so we're done
     return wrapper
 
@@ -285,8 +286,9 @@ class Runner(object):
 
     Maintains information about pending callbacks and their results.
     """
-    def __init__(self, gen):
+    def __init__(self, gen, deactivate_stack_context):
         self.gen = gen
+        self.deactivate_stack_context = deactivate_stack_context
         self.yield_point = _NullYieldPoint()
         self.pending_callbacks = set()
         self.results = {}
@@ -351,6 +353,7 @@ class Runner(object):
                         raise LeakedCallbackError(
                             "finished without waiting for callbacks %r" %
                             self.pending_callbacks)
+                    self.deactivate_stack_context()
                     return
                 except Exception:
                     self.finished = True
index df1869992a84b637ba54455dd243961bc4b29cf4..3e0bea85736a41b164ccf1b46b6449403a95d0a5 100644 (file)
@@ -71,6 +71,7 @@ from __future__ import absolute_import, division, with_statement
 import contextlib
 import functools
 import itertools
+import operator
 import sys
 import threading
 
@@ -95,23 +96,25 @@ class StackContext(object):
 
       with StackContext(my_context):
     '''
-    def __init__(self, context_factory):
+    def __init__(self, context_factory, _active_cell=None):
         self.context_factory = context_factory
+        self.active_cell = _active_cell or [True]
 
     # Note that some of this code is duplicated in ExceptionStackContext
     # below.  ExceptionStackContext is more common and doesn't need
     # the full generality of this class.
     def __enter__(self):
         self.old_contexts = _state.contexts
-        # _state.contexts is a tuple of (class, arg) pairs
+        # _state.contexts is a tuple of (class, arg, active_cell) tuples
         _state.contexts = (self.old_contexts +
-                           ((StackContext, self.context_factory),))
+                           ((StackContext, self.context_factory, self.active_cell),))
         try:
             self.context = self.context_factory()
             self.context.__enter__()
         except Exception:
             _state.contexts = self.old_contexts
             raise
+        return lambda: operator.setitem(self.active_cell, 0, False)
 
     def __exit__(self, type, value, traceback):
         try:
@@ -133,13 +136,16 @@ class ExceptionStackContext(object):
     If the exception handler returns true, the exception will be
     consumed and will not be propagated to other exception handlers.
     '''
-    def __init__(self, exception_handler):
+    def __init__(self, exception_handler, _active_cell=None):
         self.exception_handler = exception_handler
+        self.active_cell = _active_cell or [True]
 
     def __enter__(self):
         self.old_contexts = _state.contexts
         _state.contexts = (self.old_contexts +
-                           ((ExceptionStackContext, self.exception_handler),))
+                           ((ExceptionStackContext, self.exception_handler,
+                             self.active_cell),))
+        return lambda: operator.setitem(self.active_cell, 0, False)
 
     def __exit__(self, type, value, traceback):
         try:
@@ -186,7 +192,9 @@ def wrap(fn):
             callback(*args, **kwargs)
             return
         if not _state.contexts:
-            new_contexts = [cls(arg) for (cls, arg) in contexts]
+            new_contexts = [cls(arg, active_cell)
+                            for (cls, arg, active_cell) in contexts
+                            if active_cell[0]]
         # If we're moving down the stack, _state.contexts is a prefix
         # of contexts.  For each element of contexts not in that prefix,
         # create a new StackContext object.
@@ -198,10 +206,13 @@ def wrap(fn):
                 for a, b in itertools.izip(_state.contexts, contexts))):
             # contexts have been removed or changed, so start over
             new_contexts = ([NullContext()] +
-                            [cls(arg) for (cls, arg) in contexts])
+                            [cls(arg, active_cell)
+                             for (cls, arg, active_cell) in contexts
+                             if active_cell[0]])
         else:
-            new_contexts = [cls(arg)
-                            for (cls, arg) in contexts[len(_state.contexts):]]
+            new_contexts = [cls(arg, active_cell)
+                            for (cls, arg, active_cell) in contexts[len(_state.contexts):]
+                            if active_cell[0]]
         if len(new_contexts) > 1:
             with _nested(*new_contexts):
                 callback(*args, **kwargs)
index 86d7d0d608f30b009ac0ec8b347d2ee0f7fe0954..198190cbb2df6f2e05c47d5303996f5dbabd2b39 100644 (file)
@@ -249,6 +249,24 @@ class GenTest(AsyncTestCase):
             self.stop()
         self.run_gen(f)
 
+    def test_stack_context_leak(self):
+        # regression test: repeated invocations of a gen-based
+        # function should not result in accumulated stack_contexts
+        from tornado import stack_context
+        @gen.engine
+        def inner(callback):
+            yield gen.Task(self.io_loop.add_callback)
+            callback()
+        @gen.engine
+        def outer():
+            for i in xrange(10):
+                yield gen.Task(inner)
+            stack_increase = len(stack_context._state.contexts) - initial_stack_depth
+            self.assertTrue(stack_increase <= 2)
+            self.stop()
+        initial_stack_depth = len(stack_context._state.contexts)
+        self.run_gen(outer)
+
 
 class GenSequenceHandler(RequestHandler):
     @asynchronous
index e682f6a538501add080ed9ac58001916b6131086..f35728759069e7f1c7bad77291b959d72a2a8801 100644 (file)
@@ -93,5 +93,32 @@ class StackContextTest(AsyncTestCase, LogTrapTestCase):
             library_function(final_callback)
         self.wait()
 
+    def test_deactivate(self):
+        deactivate_callbacks = []
+        def f1():
+            with StackContext(functools.partial(self.context, 'c1')) as c1:
+                deactivate_callbacks.append(c1)
+                self.io_loop.add_callback(f2)
+        def f2():
+            with StackContext(functools.partial(self.context, 'c2')) as c2:
+                deactivate_callbacks.append(c2)
+                self.io_loop.add_callback(f3)
+        def f3():
+            with StackContext(functools.partial(self.context, 'c3')) as c3:
+                deactivate_callbacks.append(c3)
+                self.io_loop.add_callback(f4)
+        def f4():
+            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
+            deactivate_callbacks[1]()
+            # deactivating a context doesn't remove it immediately,
+            # but it will be missing from the next iteration
+            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
+            self.io_loop.add_callback(f5)
+        def f5():
+            self.assertEqual(self.active_contexts, ['c1', 'c3'])
+            self.stop()
+        self.io_loop.add_callback(f1)
+        self.wait()
+
 if __name__ == '__main__':
     unittest.main()