Add timeouts for idle keepalive connections in HTTPServer.
author    Ben Darnell <ben@bendarnell.com>
          Sun, 6 Apr 2014 13:17:59 +0000 (14:17 +0100)
committer Ben Darnell <ben@bendarnell.com>
          Sun, 6 Apr 2014 13:17:59 +0000 (14:17 +0100)
Add gen.with_timeout wrapper.

tornado/concurrent.py
tornado/gen.py
tornado/http1connection.py
tornado/httpserver.py
tornado/test/gen_test.py
tornado/test/httpserver_test.py

tornado/concurrent.py
index ab7b65c3ad4817b8e42ba073621f3c2a8bf3468a..b73e1df67f5879a376c0991f3821a3e1f18b0bc2 100644
@@ -258,10 +258,13 @@ def return_future(f):
 def chain_future(a, b):
     """Chain two futures together so that when one completes, so does the other.
 
-    The result (success or failure) of ``a`` will be copied to ``b``.
+    The result (success or failure) of ``a`` will be copied to ``b``, unless
+    ``b`` has already been completed or cancelled by the time ``a`` finishes.
     """
     def copy(future):
         assert future is a
+        if b.done():
+            return
         if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
                 and a.exc_info() is not None):
             b.set_exc_info(a.exc_info())
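
A minimal sketch of the behavior this hunk adds (not part of the commit; the names below are illustrative): with the new ``b.done()`` guard, chain_future leaves the target future alone if something else resolved it first, instead of tripping over a double set_result.

    from tornado.concurrent import Future, chain_future

    a, b = Future(), Future()
    chain_future(a, b)
    b.set_exception(Exception("timed out first"))  # e.g. a timeout won the race
    a.set_result("late result")                    # copy() now sees b.done() and returns
    assert b.exception() is not None               # b keeps its original outcome
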
tornado/gen.py
index 5631c5e9d87f12aa292c799f58896f13262ff9ed..28c031170068516c77c4cf23f5d00c99a885dda1 100644
@@ -87,7 +87,7 @@ import itertools
 import sys
 import types
 
-from tornado.concurrent import Future, TracebackFuture, is_future
+from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
 from tornado.ioloop import IOLoop
 from tornado import stack_context
 
@@ -112,6 +112,10 @@ class ReturnValueIgnoredError(Exception):
     pass
 
 
+class TimeoutError(Exception):
+    """Exception raised by ``with_timeout``."""
+
+
 def engine(func):
     """Callback-oriented decorator for asynchronous generators.
 
@@ -454,6 +458,34 @@ def maybe_future(x):
         return fut
 
 
+def with_timeout(timeout, future, io_loop=None):
+    """Wraps a `.Future` in a timeout.
+
+    Raises `TimeoutError` if the input future does not complete before
+    ``timeout``, which may be specified in any form allowed by
+    `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
+    relative to `.IOLoop.time`)
+
+    Currently only supports Futures, not other `YieldPoint` classes.
+    """
+    # TODO: allow yield points in addition to futures?
+    # Tricky to do with stack_context semantics.
+    #
+    # It would be more efficient to cancel the input future on timeout instead
+    # of creating a new one, but we can't know if we are the only one waiting
+    # on the input future, so cancelling it might disrupt other callers.
+    result = Future()
+    chain_future(future, result)
+    if io_loop is None:
+        io_loop = IOLoop.current()
+    timeout_handle = io_loop.add_timeout(
+        timeout,
+        lambda: result.set_exception(TimeoutError("Timeout")))
+    io_loop.add_future(future,
+                        lambda future: io_loop.remove_timeout(timeout_handle))
+    return result
+
+
 _null_future = Future()
 _null_future.set_result(None)
 
tornado/http1connection.py
index eb1f309069b62e8c7c8257885bcb40403ff6ea6d..e000ab8894e719cbe4c203416f444ad99d959871 100644
@@ -16,6 +16,7 @@
 
 from __future__ import absolute_import, division, print_function, with_statement
 
+import datetime
 import socket
 
 from tornado.concurrent import Future
@@ -36,7 +37,7 @@ class HTTP1Connection(object):
     """
     def __init__(self, stream, address, is_client,
                  no_keep_alive=False, protocol=None, chunk_size=None,
-                 max_header_size=None):
+                 max_header_size=None, header_timeout=None):
         self.is_client = is_client
         self.stream = stream
         self.address = address
@@ -60,6 +61,7 @@ class HTTP1Connection(object):
             self.protocol = "http"
         self._chunk_size = chunk_size or 65536
         self._max_header_size = max_header_size or 65536
+        self._header_timeout = header_timeout
         self._disconnect_on_finish = False
         self._clear_request_state()
         self.stream.set_close_callback(self._on_connection_close)
@@ -114,9 +116,20 @@ class HTTP1Connection(object):
         assert isinstance(delegate, httputil.HTTPMessageDelegate)
         self.message_delegate = delegate
         try:
-            header_data = yield self.stream.read_until_regex(
-                b"\r?\n\r?\n",
-                max_bytes=self._max_header_size)
+            header_future = self.stream.read_until_regex(
+                        b"\r?\n\r?\n",
+                        max_bytes=self._max_header_size)
+            if self._header_timeout is None:
+                header_data = yield header_future
+            else:
+                try:
+                    header_data = yield gen.with_timeout(
+                        datetime.timedelta(seconds=self._header_timeout),
+                        header_future,
+                        io_loop=self.stream.io_loop)
+                except gen.TimeoutError:
+                    self.close()
+                    raise gen.Return(False)
             self._reading = True
             self._finish_future = Future()
             start_line, headers = self._parse_headers(header_data)
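
The same pattern works outside HTTP1Connection. A sketch under the assumptions of this branch (IOStream read methods return futures when called without a callback; the function name is made up): wrap a pending read in ``gen.with_timeout`` and close the stream if the deadline fires, mirroring the header-read path above.

    import datetime

    from tornado import gen

    @gen.coroutine
    def read_line_or_close(stream, timeout_seconds=3600):
        try:
            line = yield gen.with_timeout(
                datetime.timedelta(seconds=timeout_seconds),
                stream.read_until(b"\r\n"))
        except gen.TimeoutError:
            stream.close()  # same recovery as the header timeout above
            raise gen.Return(None)
        raise gen.Return(line)
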
tornado/httpserver.py
index 6f9b9db2391c408af65a663a8d0374f7ccc9304c..488abb5a6626d5db90b722f7ab0002dfcc2ae3d7 100644
@@ -136,7 +136,8 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
     """
     def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
                  xheaders=False, ssl_options=None, protocol=None, gzip=False,
-                 chunk_size=None, max_header_size=None, **kwargs):
+                 chunk_size=None, max_header_size=None,
+                 idle_connection_timeout=None, **kwargs):
         self.request_callback = request_callback
         self.no_keep_alive = no_keep_alive
         self.xheaders = xheaders
@@ -144,15 +145,18 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
         self.gzip = gzip
         self.chunk_size = chunk_size
         self.max_header_size = max_header_size
+        self.idle_connection_timeout = idle_connection_timeout or 3600
         TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
                            **kwargs)
 
     def handle_stream(self, stream, address):
-        conn = HTTP1Connection(stream, address=address, is_client=False,
-                               no_keep_alive=self.no_keep_alive,
-                               protocol=self.protocol,
-                               chunk_size=self.chunk_size,
-                               max_header_size=self.max_header_size)
+        conn = HTTP1Connection(
+            stream, address=address, is_client=False,
+            no_keep_alive=self.no_keep_alive,
+            protocol=self.protocol,
+            chunk_size=self.chunk_size,
+            max_header_size=self.max_header_size,
+            header_timeout=self.idle_connection_timeout)
         conn.start_serving(self, gzip=self.gzip)
 
     def start_request(self, connection):
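
A usage sketch for the new server option (not part of the commit; the handler and port are placeholders): ``idle_connection_timeout`` is passed to HTTPServer and forwarded to HTTP1Connection as ``header_timeout``, so idle keep-alive connections are dropped after that many seconds instead of the 3600-second default.

    from tornado.httpserver import HTTPServer
    from tornado.ioloop import IOLoop
    from tornado.web import Application, RequestHandler

    class MainHandler(RequestHandler):
        def get(self):
            self.write("Hello world")

    if __name__ == "__main__":
        app = Application([("/", MainHandler)])
        # Close connections that sit idle for 10 minutes between requests.
        server = HTTPServer(app, idle_connection_timeout=600)
        server.listen(8888)
        IOLoop.current().start()
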
tornado/test/gen_test.py
index 7f4a09184870589ef061a92d21d80bb57a330e3b..3045e58bf3abf0dd44607e87d9134e3d9fa53144 100644
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
 import contextlib
+import datetime
 import functools
 import sys
 import textwrap
@@ -8,7 +9,7 @@ import time
 import platform
 import weakref
 
-from tornado.concurrent import return_future
+from tornado.concurrent import return_future, Future
 from tornado.escape import url_escape
 from tornado.httpclient import AsyncHTTPClient
 from tornado.ioloop import IOLoop
@@ -949,5 +950,40 @@ class GenWebTest(AsyncHTTPTestCase):
         response = self.fetch('/async_prepare_error')
         self.assertEqual(response.code, 403)
 
+
+class WithTimeoutTest(AsyncTestCase):
+    @gen_test
+    def test_timeout(self):
+        with self.assertRaises(gen.TimeoutError):
+            yield gen.with_timeout(datetime.timedelta(seconds=0.1),
+                                   Future())
+
+    @gen_test
+    def test_completes_before_timeout(self):
+        future = Future()
+        self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
+                                 lambda: future.set_result('asdf'))
+        result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
+                                         future)
+        self.assertEqual(result, 'asdf')
+
+    @gen_test
+    def test_fails_before_timeout(self):
+        future = Future()
+        self.io_loop.add_timeout(
+            datetime.timedelta(seconds=0.1),
+            lambda: future.set_exception(ZeroDivisionError))
+        with self.assertRaises(ZeroDivisionError):
+            yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
+
+    @gen_test
+    def test_already_resolved(self):
+        future = Future()
+        future.set_result('asdf')
+        result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
+                                        future)
+        self.assertEqual(result, 'asdf')
+
+
 if __name__ == '__main__':
     unittest.main()
tornado/test/httpserver_test.py
index 7a2d651da33cdc793a960d31a09f40feae8fe58b..d1cfbba00c4affb278fa442dea258e62969cf90c 100644
@@ -12,7 +12,7 @@ from tornado.log import gen_log
 from tornado.netutil import ssl_options_to_context
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
-from tornado.test.util import unittest
+from tornado.test.util import unittest, skipOnTravis
 from tornado.util import u, bytes_type
 from tornado.web import Application, RequestHandler, asynchronous
 from contextlib import closing
@@ -844,3 +844,50 @@ class MaxHeaderSizeTest(AsyncHTTPTestCase):
         with ExpectLog(gen_log, "Unsatisfiable read"):
             response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
         self.assertEqual(response.code, 599)
+
+
+@skipOnTravis
+class IdleTimeoutTest(AsyncHTTPTestCase):
+    def get_app(self):
+        return Application([('/', HelloWorldRequestHandler)])
+
+    def get_httpserver_options(self):
+        return dict(idle_connection_timeout=0.1)
+
+    def setUp(self):
+        super(IdleTimeoutTest, self).setUp()
+        self.streams = []
+
+    def tearDown(self):
+        super(IdleTimeoutTest, self).tearDown()
+        for stream in self.streams:
+            stream.close()
+
+    def connect(self):
+        stream = IOStream(socket.socket())
+        stream.connect(('localhost', self.get_http_port()), self.stop)
+        self.wait()
+        self.streams.append(stream)
+        return stream
+
+    def test_unused_connection(self):
+        stream = self.connect()
+        stream.set_close_callback(self.stop)
+        self.wait()
+
+    def test_idle_after_use(self):
+        stream = self.connect()
+        stream.set_close_callback(lambda: self.stop("closed"))
+
+        # Use the connection twice to make sure keep-alives are working
+        for i in range(2):
+            stream.write(b"GET / HTTP/1.1\r\n\r\n")
+            stream.read_until(b"\r\n\r\n", self.stop)
+            self.wait()
+            stream.read_bytes(11, self.stop)
+            data = self.wait()
+            self.assertEqual(data, b"Hello world")
+
+        # Now let the timeout trigger and close the connection.
+        data = self.wait()
+        self.assertEqual(data, "closed")