]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Introduce StackContext, a way to automatically manage exception
authorBen Darnell <bdarnell@beaker.local>
Fri, 23 Jul 2010 19:35:08 +0000 (12:35 -0700)
committerBen Darnell <bdarnell@beaker.local>
Fri, 23 Jul 2010 19:35:08 +0000 (12:35 -0700)
handling and other stack-related state for asynchronous callbacks.
This means that it is no longer necessary to wrap everything
in RequestHandler.async_callback.

tornado/httpclient.py
tornado/ioloop.py
tornado/stack_context.py [new file with mode: 0644]
tornado/test/stack_context_test.py [new file with mode: 0755]
tornado/web.py

index 94bd98c5799138f2eb2c2b37905384a08b19cbeb..6abcc747e8d6d09fa5233624c19c480f6f93c6ef 100644 (file)
@@ -16,6 +16,8 @@
 
 """Blocking and non-blocking HTTP client implementations using pycurl."""
 
+from __future__ import with_statement
+
 import calendar
 import collections
 import cStringIO
@@ -27,6 +29,7 @@ import httputil
 import ioloop
 import logging
 import pycurl
+import stack_context
 import sys
 import time
 import weakref
@@ -150,7 +153,7 @@ class AsyncHTTPClient(object):
         """
         if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
-        self._requests.append((request, callback))
+        self._requests.append((request, stack_context.wrap(callback)))
         self._process_queue()
         self._set_timeout(0)
 
@@ -202,16 +205,17 @@ class AsyncHTTPClient(object):
 
     def _handle_timeout(self):
         """Called by IOLoop when the requested timeout has passed."""
-        self._timeout = None
-        while True:
-            try:
-                ret, num_handles = self._multi.socket_action(
-                                        pycurl.SOCKET_TIMEOUT, 0)
-            except Exception, e:
-                ret = e[0]
-            if ret != pycurl.E_CALL_MULTI_PERFORM:
-                break
-        self._finish_pending_requests()
+        with stack_context.NullContext():
+            self._timeout = None
+            while True:
+                try:
+                    ret, num_handles = self._multi.socket_action(
+                                            pycurl.SOCKET_TIMEOUT, 0)
+                except Exception, e:
+                    ret = e[0]
+                if ret != pycurl.E_CALL_MULTI_PERFORM:
+                    break
+            self._finish_pending_requests()
 
         # In theory, we shouldn't have to do this because curl will
         # call _set_timeout whenever the timeout changes.  However,
@@ -245,30 +249,31 @@ class AsyncHTTPClient(object):
         self._process_queue()
 
     def _process_queue(self):
-        while True:
-            started = 0
-            while self._free_list and self._requests:
-                started += 1
-                curl = self._free_list.pop()
-                (request, callback) = self._requests.popleft()
-                curl.info = {
-                    "headers": httputil.HTTPHeaders(),
-                    "buffer": cStringIO.StringIO(),
-                    "request": request,
-                    "callback": callback,
-                    "start_time": time.time(),
-                }
-                # Disable IPv6 to mitigate the effects of this bug
-                # on curl versions <= 7.21.0
-                # http://sourceforge.net/tracker/?func=detail&aid=3017819&group_id=976&atid=100976
-                if pycurl.version_info()[2] <= 0x71500:  # 7.21.0
-                    curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
-                _curl_setup_request(curl, request, curl.info["buffer"],
-                                    curl.info["headers"])
-                self._multi.add_handle(curl)
-
-            if not started:
-                break
+        with stack_context.NullContext():
+            while True:
+                started = 0
+                while self._free_list and self._requests:
+                    started += 1
+                    curl = self._free_list.pop()
+                    (request, callback) = self._requests.popleft()
+                    curl.info = {
+                        "headers": httputil.HTTPHeaders(),
+                        "buffer": cStringIO.StringIO(),
+                        "request": request,
+                        "callback": callback,
+                        "start_time": time.time(),
+                    }
+                    # Disable IPv6 to mitigate the effects of this bug
+                    # on curl versions <= 7.21.0
+                    # http://sourceforge.net/tracker/?func=detail&aid=3017819&group_id=976&atid=100976
+                    if pycurl.version_info()[2] <= 0x71500:  # 7.21.0
+                        curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
+                    _curl_setup_request(curl, request, curl.info["buffer"],
+                                        curl.info["headers"])
+                    self._multi.add_handle(curl)
+
+                if not started:
+                    break
 
     def _finish(self, curl, curl_error=None, curl_message=None):
         info = curl.info
index c1345cb2f6b81766884886e61010945e10f038a6..97220314c9d218d0a5e4ff64566f9211b8b43a28 100644 (file)
@@ -21,6 +21,7 @@ import errno
 import os
 import logging
 import select
+import stack_context
 import time
 import traceback
 
@@ -145,7 +146,7 @@ class IOLoop(object):
 
     def add_handler(self, fd, handler, events):
         """Registers the given handler to receive the given events for fd."""
-        self._handlers[fd] = handler
+        self._handlers[fd] = stack_context.wrap(handler)
         self._impl.register(fd, events | self.ERROR)
 
     def update_handler(self, fd, events):
@@ -290,7 +291,7 @@ class IOLoop(object):
 
     def add_timeout(self, deadline, callback):
         """Calls the given callback at the time deadline from the I/O loop."""
-        timeout = _Timeout(deadline, callback)
+        timeout = _Timeout(deadline, stack_context.wrap(callback))
         bisect.insort(self._timeouts, timeout)
         return timeout
 
@@ -299,13 +300,9 @@ class IOLoop(object):
 
     def add_callback(self, callback):
         """Calls the given callback on the next I/O loop iteration."""
-        self._callbacks.add(callback)
+        self._callbacks.add(stack_context.wrap(callback))
         self._wake()
 
-    def remove_callback(self, callback):
-        """Removes the given callback from the next I/O loop iteration."""
-        self._callbacks.remove(callback)
-
     def _wake(self):
         try:
             self._waker_writer.write("x")
diff --git a/tornado/stack_context.py b/tornado/stack_context.py
new file mode 100644 (file)
index 0000000..43e511d
--- /dev/null
@@ -0,0 +1,141 @@
+#!/usr/bin/env python
+#
+# Copyright 2010 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+'''StackContext allows applications to maintain threadlocal-like state
+that follows execution as it moves to other execution contexts.
+
+The motivating examples are to eliminate the need for explicit
+async_callback wrappers (as in tornado.web.RequestHandler), and to
+allow some additional context to be kept for logging.
+
+This is slightly magic, but it's an extension of the idea that an exception
+handler is a kind of stack-local state and when that stack is suspended
+and resumed in a new context that state needs to be preserved.  StackContext
+shifts the burden of restoring that state from each call site (e.g.
+wrapping each AsyncHTTPClient callback in async_callback) to the mechanisms
+that transfer control from one context to another (e.g. AsyncHTTPClient
+itself, IOLoop, thread pools, etc).
+
+Example usage:
+  @contextlib.contextmanager
+  def die_on_error():
+    try:
+      yield
+    except:
+      logging.error("exception in asynchronous operation", exc_info=True)
+      sys.exit(1)
+
+  with StackContext(die_on_error):
+    # Any exception thrown here *or in callback and its desendents*
+    # will cause the process to exit instead of spinning endlessly
+    # in the ioloop.
+    http_client.fetch(url, callback)
+  ioloop.start()
+'''
+
+from __future__ import with_statement
+
+import contextlib
+import functools
+import itertools
+import logging
+import threading
+
+class _State(threading.local):
+  def __init__(self):
+    self.contexts = ()
+_state = _State()
+
+@contextlib.contextmanager
+def StackContext(context_factory):
+  '''Establishes the given context as a StackContext that will be transferred.
+
+  Note that the parameter is a callable that returns a context
+  manager, not the context itself.  That is, where for a
+  non-transferable context manager you would say
+    with my_context():
+  StackContext takes the function itself rather than its result:
+    with StackContext(my_context):
+  '''
+  old_contexts = _state.contexts
+  try:
+    _state.contexts = old_contexts + (context_factory,)
+    with context_factory():
+      yield
+  finally:
+    _state.contexts = old_contexts
+
+@contextlib.contextmanager
+def NullContext():
+  '''Resets the StackContext.
+
+  Useful when creating a shared resource on demand (e.g. an AsyncHTTPClient)
+  where the stack that caused the creating is not relevant to future
+  operations.
+  '''
+  old_contexts = _state.contexts
+  try:
+    _state.contexts = ()
+    yield
+  finally:
+    _state.contexts = old_contexts
+
+def wrap(fn, *args, **kwargs):
+  '''Returns a callable object that will resore the current StackContext
+  when executed.
+
+  Use this whenever saving a callback to be executed later in a
+  different execution context (either in a different thread or
+  asynchronously in the same thread).
+
+  As a convenience, also binds parameters to the given function
+  like functools.partial.
+  '''
+  # functools.wraps doesn't appear to work on functools.partial objects
+  #@functools.wraps(fn)
+  def wrapped(callback, contexts, *args, **kwargs):
+    # _state.contexts and contexts may share a common prefix.
+    # For each element of contexts not in that prefix, create a new
+    # StackContext object.
+    # TODO(bdarnell): do we want to be strict about the order,
+    # or is what we really want just set(contexts) - set(_state.contexts)?
+    # I think we do want to be strict about using identity comparison,
+    # so a set may not be quite right.  Conversely, it's not very stack-like
+    # to have new contexts pop up in the middle, so would we want to
+    # ensure there are no existing contexts not in the stack being restored?
+    # That feels right, but given the difficulty of handling errors at this
+    # level I'm not going to check for it now.
+    pairs = itertools.izip(itertools.chain(_state.contexts,
+                                           itertools.repeat(None)),
+                           contexts)
+    new_contexts = []
+    for old, new in itertools.dropwhile(lambda x: x[0] is x[1], pairs):
+      new_contexts.append(StackContext(new))
+    if new_contexts:
+      with contextlib.nested(*new_contexts):
+        callback(*args, **kwargs)
+    else:
+      callback(*args, **kwargs)
+  if args or kwargs:
+    callback = functools.partial(fn, *args, **kwargs)
+  else:
+    callback = fn
+  contexts = _state.contexts
+  if contexts:
+    return functools.partial(wrapped, callback, contexts, *args, **kwargs)
+  else:
+    return callback
+
diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py
new file mode 100755 (executable)
index 0000000..2f98c01
--- /dev/null
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+
+from tornado.httpclient import AsyncHTTPClient
+from tornado.httpserver import HTTPServer
+from tornado.ioloop import IOLoop
+from tornado.web import asynchronous, Application, RequestHandler
+import logging
+import unittest
+
+class TestRequestHandler(RequestHandler):
+  def __init__(self, app, request, io_loop):
+    super(TestRequestHandler, self).__init__(app, request)
+    self.io_loop = io_loop
+
+  @asynchronous
+  def get(self):
+    logging.info('in get()')
+    # call self.part2 without a self.async_callback wrapper.  Its
+    # exception should still get thrown
+    self.io_loop.add_callback(self.part2)
+
+  def part2(self):
+    logging.info('in part2()')
+    # Go through a third layer to make sure that contexts once restored
+    # are again passed on to future callbacks
+    self.io_loop.add_callback(self.part3)
+
+  def part3(self):
+    logging.info('in part3()')
+    raise Exception('test exception')
+
+  def get_error_html(self, status_code, **kwargs):
+    if 'exception' in kwargs and str(kwargs['exception']) == 'test exception':
+      return 'got expected exception'
+    else:
+      return 'unexpected failure'
+
+class StackContextTest(unittest.TestCase):
+  # Note that this test logs an error even when it passes.
+  # TODO(bdarnell): better logging setup for unittests
+  def test_stack_context(self):
+    self.io_loop = IOLoop()
+    app = Application([('/', TestRequestHandler, dict(io_loop=self.io_loop))])
+    server = HTTPServer(app, io_loop=self.io_loop)
+    server.listen(11000)
+    client = AsyncHTTPClient(io_loop=self.io_loop)
+    client.fetch('http://localhost:11000/', self.handle_response)
+    self.io_loop.start()
+    self.assertEquals(self.response.code, 500)
+    self.assertTrue('got expected exception' in self.response.body)
+
+  def handle_response(self, response):
+    self.response = response
+    self.io_loop.stop()
+
+if __name__ == '__main__':
+  unittest.main()
index f1e0871f9d42afcf98c654c0c2df5d404c298597..dc5a56e8446d507a4ea1004656ed8cabceb910a0 100644 (file)
@@ -43,9 +43,12 @@ See the Tornado walkthrough on GitHub for more details and a good
 getting started guide.
 """
 
+from __future__ import with_statement
+
 import base64
 import binascii
 import calendar
+import contextlib
 import Cookie
 import cStringIO
 import datetime
@@ -61,6 +64,7 @@ import logging
 import mimetypes
 import os.path
 import re
+import stack_context
 import stat
 import sys
 import template
@@ -754,10 +758,17 @@ class RequestHandler(object):
     def reverse_url(self, name, *args):
         return self.application.reverse_url(name, *args)
 
+    @contextlib.contextmanager
+    def _stack_context(self):
+        try:
+            yield
+        except Exception, e:
+            self._handle_request_exception(e)
+
     def _execute(self, transforms, *args, **kwargs):
         """Executes this request with the given output transforms."""
         self._transforms = transforms
-        try:
+        with stack_context.StackContext(self._stack_context):
             if self.request.method not in self.SUPPORTED_METHODS:
                 raise HTTPError(405)
             # If XSRF cookies are turned on, reject form submissions without
@@ -770,8 +781,6 @@ class RequestHandler(object):
                 getattr(self, self.request.method.lower())(*args, **kwargs)
                 if self._auto_finish and not self._finished:
                     self.finish()
-        except Exception, e:
-            self._handle_request_exception(e)
 
     def _generate_headers(self):
         lines = [self.request.version + " " + str(self._status_code) + " " +