]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Optimized StackContext implementation
authorSerge S. Koval <serge.koval+github@gmail.com>
Sat, 13 Apr 2013 20:09:50 +0000 (23:09 +0300)
committerSerge S. Koval <serge.koval+github@gmail.com>
Sat, 13 Apr 2013 20:09:50 +0000 (23:09 +0300)
tornado/gen.py
tornado/stack_context.py
tornado/test/gen_test.py
tornado/test/stack_context_test.py

index 885b604d33102d32bd409192b23daa5b1250ce8d..64287c53155736da7121c6f005c163aabbc6ae9e 100644 (file)
@@ -136,7 +136,7 @@ def engine(func):
             if runner is not None:
                 return runner.handle_exception(typ, value, tb)
             return False
-        with ExceptionStackContext(handle_exception) as deactivate:
+        with ExceptionStackContext(handle_exception):
             try:
                 result = func(*args, **kwargs)
             except (Return, StopIteration) as e:
@@ -149,7 +149,6 @@ def engine(func):
                                 "@gen.engine functions cannot return values: "
                                 "%r" % (value,))
                         assert value is None
-                        deactivate()
                     runner = Runner(result, final_callback)
                     runner.run()
                     return
@@ -157,7 +156,6 @@ def engine(func):
                 raise ReturnValueIgnoredError(
                     "@gen.engine functions cannot return values: %r" %
                     (result,))
-            deactivate()
             # no yield, so we're done
     return wrapper
 
@@ -210,24 +208,21 @@ def coroutine(func):
                 typ, value, tb = sys.exc_info()
             future.set_exc_info((typ, value, tb))
             return True
-        with ExceptionStackContext(handle_exception) as deactivate:
+        with ExceptionStackContext(handle_exception):
             try:
                 result = func(*args, **kwargs)
             except (Return, StopIteration) as e:
                 result = getattr(e, 'value', None)
             except Exception:
-                deactivate()
                 future.set_exc_info(sys.exc_info())
                 return future
             else:
                 if isinstance(result, types.GeneratorType):
                     def final_callback(value):
-                        deactivate()
                         future.set_result(value)
                     runner = Runner(result, final_callback)
                     runner.run()
                     return future
-            deactivate()
             future.set_result(result)
         return future
     return wrapper
index 38ab7517b99ce9c4c6e0e483464ae0605bfb9574..7c2dea98e8398c482f01ba00379b2fba311ab8ee 100644 (file)
@@ -69,9 +69,6 @@ Here are a few rules of thumb for when it's necessary:
 
 from __future__ import absolute_import, division, print_function, with_statement
 
-import contextlib
-import functools
-import operator
 import sys
 import threading
 
@@ -84,7 +81,7 @@ class StackContextInconsistentError(Exception):
 
 class _State(threading.local):
     def __init__(self):
-        self.contexts = ()
+        self.contexts = (tuple(), None)
 _state = _State()
 
 
@@ -108,45 +105,51 @@ class StackContext(object):
     context that are currently pending).  This is an advanced feature
     and not necessary in most applications.
     """
-    def __init__(self, context_factory, _active_cell=None):
+    def __init__(self, context_factory):
         self.context_factory = context_factory
-        self.active_cell = _active_cell or [True]
+        self.contexts = []
+
+    # StackContext protocol
+    def enter(self):
+        context = self.context_factory()
+        self.contexts.append(context)
+        context.__enter__()
+
+    def exit(self, type, value, traceback):
+        context = self.contexts.pop()
+        context.__exit__(type, value, traceback)
 
     # Note that some of this code is duplicated in ExceptionStackContext
     # below.  ExceptionStackContext is more common and doesn't need
     # the full generality of this class.
     def __enter__(self):
         self.old_contexts = _state.contexts
-        # _state.contexts is a tuple of (class, arg, active_cell) tuples
-        self.new_contexts = (self.old_contexts +
-                             ((StackContext, self.context_factory,
-                               self.active_cell),))
+        self.new_contexts = (self.old_contexts[0] + (self,), self)
         _state.contexts = self.new_contexts
+
         try:
-            self.context = self.context_factory()
-            self.context.__enter__()
-        except Exception:
+            self.enter()
+        except:
             _state.contexts = self.old_contexts
             raise
-        return lambda: operator.setitem(self.active_cell, 0, False)
 
     def __exit__(self, type, value, traceback):
         try:
-            return self.context.__exit__(type, value, traceback)
+            self.exit(type, value, traceback)
         finally:
             final_contexts = _state.contexts
             _state.contexts = self.old_contexts
+
             # Generator coroutines and with-statements with non-local
             # effects interact badly.  Check here for signs of
             # the stack getting out of sync.
             # Note that this check comes after restoring _state.context
             # so that if it fails things are left in a (relatively)
             # consistent state.
-            if final_contexts is not self.new_contexts:
+            if final_contexts != self.new_contexts:
                 raise StackContextInconsistentError(
                     'stack_context inconsistency (may be caused by yield '
                     'within a "with StackContext" block)')
-            self.old_contexts = self.new_contexts = None
 
 
 class ExceptionStackContext(object):
@@ -162,17 +165,17 @@ class ExceptionStackContext(object):
     If the exception handler returns true, the exception will be
     consumed and will not be propagated to other exception handlers.
     """
-    def __init__(self, exception_handler, _active_cell=None):
+    def __init__(self, exception_handler):
         self.exception_handler = exception_handler
-        self.active_cell = _active_cell or [True]
+
+    def exit(self, type, value, traceback):
+        if type is not None:
+            return self.exception_handler(type, value, traceback)
 
     def __enter__(self):
         self.old_contexts = _state.contexts
-        self.new_contexts = (self.old_contexts +
-                             ((ExceptionStackContext, self.exception_handler,
-                               self.active_cell),))
+        self.new_contexts = (self.old_contexts[0], self)
         _state.contexts = self.new_contexts
-        return lambda: operator.setitem(self.active_cell, 0, False)
 
     def __exit__(self, type, value, traceback):
         try:
@@ -181,11 +184,11 @@ class ExceptionStackContext(object):
         finally:
             final_contexts = _state.contexts
             _state.contexts = self.old_contexts
-            if final_contexts is not self.new_contexts:
+
+            if final_contexts != self.new_contexts:
                 raise StackContextInconsistentError(
                     'stack_context inconsistency (may be caused by yield '
                     'within a "with StackContext" block)')
-            self.old_contexts = self.new_contexts = None
 
 
 class NullContext(object):
@@ -197,16 +200,12 @@ class NullContext(object):
     """
     def __enter__(self):
         self.old_contexts = _state.contexts
-        _state.contexts = ()
+        _state.contexts = (tuple(), None)
 
     def __exit__(self, type, value, traceback):
         _state.contexts = self.old_contexts
 
 
-class _StackContextWrapper(functools.partial):
-    pass
-
-
 def wrap(fn):
     """Returns a callable object that will restore the current `StackContext`
     when executed.
@@ -215,64 +214,85 @@ def wrap(fn):
     different execution context (either in a different thread or
     asynchronously in the same thread).
     """
-    if fn is None or fn.__class__ is _StackContextWrapper:
+    # Check if function is already wrapped
+    if fn is None or hasattr(fn, '_wrapped'):
         return fn
-    # functools.wraps doesn't appear to work on functools.partial objects
-    #@functools.wraps(fn)
 
+    # Capture current stack head
+    contexts = _state.contexts
+
+    #@functools.wraps
     def wrapped(*args, **kwargs):
-        callback, contexts, args = args[0], args[1], args[2:]
-
-        if _state.contexts:
-            new_contexts = [NullContext()]
-        else:
-            new_contexts = []
-        if contexts:
-            new_contexts.extend(cls(arg, active_cell)
-                                for (cls, arg, active_cell) in contexts
-                                if active_cell[0])
-        if len(new_contexts) > 1:
-            with _nested(*new_contexts):
-                callback(*args, **kwargs)
-        elif new_contexts:
-            with new_contexts[0]:
-                callback(*args, **kwargs)
-        else:
-            callback(*args, **kwargs)
-    return _StackContextWrapper(wrapped, fn, _state.contexts)
-
-
-@contextlib.contextmanager
-def _nested(*managers):
-    """Support multiple context managers in a single with-statement.
-
-    Copied from the python 2.6 standard library.  It's no longer present
-    in python 3 because the with statement natively supports multiple
-    context managers, but that doesn't help if the list of context
-    managers is not known until runtime.
-    """
-    exits = []
-    vars = []
-    exc = (None, None, None)
-    try:
-        for mgr in managers:
-            exit = mgr.__exit__
-            enter = mgr.__enter__
-            vars.append(enter())
-            exits.append(exit)
-        yield vars
-    except:
-        exc = sys.exc_info()
-    finally:
-        while exits:
-            exit = exits.pop()
-            try:
-                if exit(*exc):
-                    exc = (None, None, None)
-            except:
-                exc = sys.exc_info()
-        if exc != (None, None, None):
-            # Don't rely on sys.exc_info() still containing
-            # the right information. Another exception may
-            # have been raised and caught by an exit method
-            raise_exc_info(exc)
+        try:
+            # Force local state - switch to new stack chain
+            current_state = _state.contexts
+            _state.contexts = contexts
+
+            # Current exception
+            exc = (None, None, None)
+            top = None
+
+            # Apply stack contexts
+            last_ctx = 0
+            stack = contexts[0]
+
+            # Apply state
+            for n in stack:
+                try:
+                    n.enter()
+                    last_ctx += 1
+                except:
+                    # Exception happened. Record exception info and store top-most handler
+                    exc = sys.exc_info()
+                    top = n.old_contexts[1]
+
+            # Execute callback if no exception happened while restoring state
+            if top is None:
+                try:
+                    fn(*args, **kwargs)
+                except:
+                    exc = sys.exc_info()
+                    top = contexts[1]
+
+            # If there was exception, try to handle it by going through the exception chain
+            if top is not None:
+                exc = _handle_exception(top, exc)
+            else:
+                # Otherwise take shorter path and run stack contexts in reverse order
+                for n in xrange(last_ctx - 1, -1, -1):
+                    c = stack[n]
+
+                    try:
+                        c.exit(*exc)
+                    except:
+                        exc = sys.exc_info()
+                        top = c.old_contexts[1]
+                        break
+                else:
+                    top = None
+
+                # If if exception happened while unrolling, take longer exception handler path
+                if top is not None:
+                    exc = _handle_exception(top, exc)
+
+            # If exception was not handled, raise it
+            if exc != (None, None, None):
+                raise_exc_info(exc)
+        finally:
+            _state.contexts = current_state
+
+    wrapped._wrapped = True
+    return wrapped
+
+
+def _handle_exception(tail, exc):
+    while tail is not None:
+        try:
+            if tail.exit(*exc):
+                exc = (None, None, None)
+        except:
+            exc = sys.exc_info()
+
+        tail = tail.old_contexts[1]
+
+    return exc
index b3dc004187a685063576b1347a3b8dc3e74f4f6d..d77297dbf850bc44e700cfad65cb492d9fa28fd5 100644 (file)
@@ -838,3 +838,6 @@ class GenWebTest(AsyncHTTPTestCase):
     def test_yield_exception_handler(self):
         response = self.fetch('/yield_exception')
         self.assertEqual(response.body, b'ok')
+
+if __name__ == '__main__':
+    unittest.main()
index 5c4af0a150987a1efe26c46ba35c884cecb76840..50711ee1b39bd6532762ec8b7a9db2ff30d614a1 100644 (file)
@@ -95,38 +95,6 @@ class StackContextTest(AsyncTestCase):
             library_function(final_callback)
         self.wait()
 
-    def test_deactivate(self):
-        deactivate_callbacks = []
-
-        def f1():
-            with StackContext(functools.partial(self.context, 'c1')) as c1:
-                deactivate_callbacks.append(c1)
-                self.io_loop.add_callback(f2)
-
-        def f2():
-            with StackContext(functools.partial(self.context, 'c2')) as c2:
-                deactivate_callbacks.append(c2)
-                self.io_loop.add_callback(f3)
-
-        def f3():
-            with StackContext(functools.partial(self.context, 'c3')) as c3:
-                deactivate_callbacks.append(c3)
-                self.io_loop.add_callback(f4)
-
-        def f4():
-            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
-            deactivate_callbacks[1]()
-            # deactivating a context doesn't remove it immediately,
-            # but it will be missing from the next iteration
-            self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
-            self.io_loop.add_callback(f5)
-
-        def f5():
-            self.assertEqual(self.active_contexts, ['c1', 'c3'])
-            self.stop()
-        self.io_loop.add_callback(f1)
-        self.wait()
-
     def test_isolation_nonempty(self):
         # f2 and f3 are a chain of operations started in context c1.
         # f2 is incidentally run under context c2, but that context should