]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add max_body_size and body_timeout limits to http1connection.
authorBen Darnell <ben@bendarnell.com>
Sun, 20 Apr 2014 05:01:20 +0000 (01:01 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 20 Apr 2014 05:01:20 +0000 (01:01 -0400)
These limits can be set on a per-request basis during headers_received.

tornado/http1connection.py
tornado/httpserver.py
tornado/test/httpserver_test.py

index 787003c86cb4ec6cdfb068f180a411842bb9b8ee..44480f604f781cfdb4f45b73735da02b6ebb26da 100644 (file)
@@ -37,7 +37,8 @@ class HTTP1Connection(object):
     """
     def __init__(self, stream, address, is_client,
                  no_keep_alive=False, protocol=None, chunk_size=None,
-                 max_header_size=None, header_timeout=None):
+                 max_header_size=None, header_timeout=None,
+                 max_body_size=None, body_timeout=None):
         self.is_client = is_client
         self.stream = stream
         self.address = address
@@ -62,6 +63,9 @@ class HTTP1Connection(object):
         self._chunk_size = chunk_size or 65536
         self._max_header_size = max_header_size or 65536
         self._header_timeout = header_timeout
+        self._default_max_body_size = (max_body_size or
+                                       self.stream.max_buffer_size)
+        self._default_body_timeout = body_timeout
         self._disconnect_on_finish = False
         self._clear_request_state()
         self.stream.set_close_callback(self._on_connection_close)
@@ -115,6 +119,8 @@ class HTTP1Connection(object):
     def _read_message(self, delegate, method=None):
         assert isinstance(delegate, httputil.HTTPMessageDelegate)
         self.message_delegate = delegate
+        self._max_body_size = self._default_max_body_size
+        self._body_timeout = self._default_body_timeout
         try:
             header_future = self.stream.read_until_regex(
                         b"\r?\n\r?\n",
@@ -169,7 +175,18 @@ class HTTP1Connection(object):
             if not skip_body:
                 body_future = self._read_body(headers)
                 if body_future is not None:
-                    yield body_future
+                    if self._body_timeout is None:
+                        yield body_future
+                    else:
+                        try:
+                            yield gen.with_timeout(
+                                self.stream.io_loop.time() + self._body_timeout,
+                                body_future, self.stream.io_loop)
+                        except gen.TimeoutError:
+                            gen_log.info("Timeout reading body from %r",
+                                         self.address)
+                            self.stream.close()
+                            raise gen.Return(False)
             self._reading = False
             self.message_delegate.finish()
             yield self._finish_future
@@ -226,6 +243,12 @@ class HTTP1Connection(object):
         self.stream = None
         return stream
 
+    def set_body_timeout(self, timeout):
+        self._body_timeout = timeout
+
+    def set_max_body_size(self, max_body_size):
+        self._max_body_size = max_body_size
+
     def write_headers(self, start_line, headers, chunk=None, callback=None,
                       has_body=True):
         if self.is_client:
@@ -378,7 +401,7 @@ class HTTP1Connection(object):
         content_length = headers.get("Content-Length")
         if content_length:
             content_length = int(content_length)
-            if content_length > self.stream.max_buffer_size:
+            if content_length > self._max_body_size:
                 raise httputil.HTTPInputException("Content-Length too long")
             return self._read_fixed_body(content_length)
         if headers.get("Transfer-Encoding") == "chunked":
@@ -398,11 +421,15 @@ class HTTP1Connection(object):
     @gen.coroutine
     def _read_chunked_body(self):
         # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
+        total_size = 0
         while True:
             chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
             chunk_len = int(chunk_len.strip(), 16)
             if chunk_len == 0:
                 return
+            total_size += chunk_len
+            if total_size > self._max_body_size:
+                raise httputil.HTTPInputException("chunked body too large")
             bytes_to_read = chunk_len
             while bytes_to_read:
                 chunk = yield self.stream.read_bytes(
index 488abb5a6626d5db90b722f7ab0002dfcc2ae3d7..b8339d961eb9d9b94ff02466d7c26b68cf2ac7a1 100644 (file)
@@ -137,7 +137,8 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
     def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
                  xheaders=False, ssl_options=None, protocol=None, gzip=False,
                  chunk_size=None, max_header_size=None,
-                 idle_connection_timeout=None, **kwargs):
+                 idle_connection_timeout=None, body_timeout=None,
+                 max_body_size=None, **kwargs):
         self.request_callback = request_callback
         self.no_keep_alive = no_keep_alive
         self.xheaders = xheaders
@@ -146,6 +147,8 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
         self.chunk_size = chunk_size
         self.max_header_size = max_header_size
         self.idle_connection_timeout = idle_connection_timeout or 3600
+        self.body_timeout = body_timeout
+        self.max_body_size = max_body_size
         TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
                            **kwargs)
 
@@ -156,7 +159,9 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
             protocol=self.protocol,
             chunk_size=self.chunk_size,
             max_header_size=self.max_header_size,
-            header_timeout=self.idle_connection_timeout)
+            header_timeout=self.idle_connection_timeout,
+            max_body_size=self.max_body_size,
+            body_timeout=self.body_timeout)
         conn.start_serving(self, gzip=self.gzip)
 
     def start_request(self, connection):
index d1cfbba00c4affb278fa442dea258e62969cf90c..0c46fb358b664fd7627504be56707574749325e6 100644 (file)
@@ -4,6 +4,7 @@
 from __future__ import absolute_import, division, print_function, with_statement
 from tornado import netutil
 from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
+from tornado import gen
 from tornado.http1connection import HTTP1Connection
 from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
@@ -11,10 +12,10 @@ from tornado.iostream import IOStream
 from tornado.log import gen_log
 from tornado.netutil import ssl_options_to_context
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest, skipOnTravis
 from tornado.util import u, bytes_type
-from tornado.web import Application, RequestHandler, asynchronous
+from tornado.web import Application, RequestHandler, asynchronous, stream_request_body
 from contextlib import closing
 import datetime
 import gzip
@@ -891,3 +892,114 @@ class IdleTimeoutTest(AsyncHTTPTestCase):
         # Now let the timeout trigger and close the connection.
         data = self.wait()
         self.assertEqual(data, "closed")
+
+
+class BodyLimitsTest(AsyncHTTPTestCase):
+    def get_app(self):
+        class BufferedHandler(RequestHandler):
+            def put(self):
+                self.write(str(len(self.request.body)))
+
+        @stream_request_body
+        class StreamingHandler(RequestHandler):
+            def initialize(self):
+                self.bytes_read = 0
+
+            def prepare(self):
+                if 'expected_size' in self.request.arguments:
+                    self.request.connection.set_max_body_size(
+                        int(self.get_argument('expected_size')))
+                if 'body_timeout' in self.request.arguments:
+                    self.request.connection.set_body_timeout(
+                        float(self.get_argument('body_timeout')))
+
+            def data_received(self, data):
+                self.bytes_read += len(data)
+
+            def put(self):
+                self.write(str(self.bytes_read))
+
+        return Application([('/buffered', BufferedHandler),
+                            ('/streaming', StreamingHandler)])
+
+    def get_httpserver_options(self):
+        return dict(body_timeout=3600, max_body_size=4096)
+
+    def get_http_client(self):
+        # body_producer doesn't work on curl_httpclient, so override the
+        # configured AsyncHTTPClient implementation.
+        return SimpleAsyncHTTPClient(io_loop=self.io_loop)
+
+    def test_small_body(self):
+        response = self.fetch('/buffered', method='PUT', body=b'a'*4096)
+        self.assertEqual(response.body, b'4096')
+        response = self.fetch('/streaming', method='PUT', body=b'a'*4096)
+        self.assertEqual(response.body, b'4096')
+
+    def test_large_body_buffered(self):
+        with ExpectLog(gen_log, '.*Content-Length too long'):
+            response = self.fetch('/buffered', method='PUT', body=b'a'*10240)
+        self.assertEqual(response.code, 599)
+
+    def test_large_body_buffered_chunked(self):
+        with ExpectLog(gen_log, '.*chunked body too large'):
+            response = self.fetch('/buffered', method='PUT',
+                                  body_producer=lambda write: write(b'a'*10240))
+        self.assertEqual(response.code, 599)
+
+    def test_large_body_streaming(self):
+        with ExpectLog(gen_log, '.*Content-Length too long'):
+            response = self.fetch('/streaming', method='PUT', body=b'a'*10240)
+        self.assertEqual(response.code, 599)
+
+    def test_large_body_streaming_chunked(self):
+        with ExpectLog(gen_log, '.*chunked body too large'):
+            response = self.fetch('/streaming', method='PUT',
+                                  body_producer=lambda write: write(b'a'*10240))
+        self.assertEqual(response.code, 599)
+
+    def test_large_body_streaming_override(self):
+        response = self.fetch('/streaming?expected_size=10240', method='PUT',
+                              body=b'a'*10240)
+        self.assertEqual(response.body, b'10240')
+
+    def test_large_body_streaming_chunked_override(self):
+        response = self.fetch('/streaming?expected_size=10240', method='PUT',
+                              body_producer=lambda write: write(b'a'*10240))
+        self.assertEqual(response.body, b'10240')
+
+    @gen_test
+    def test_timeout(self):
+        stream = IOStream(socket.socket())
+        try:
+            yield stream.connect(('127.0.0.1', self.get_http_port()))
+            # Use a raw stream because AsyncHTTPClient won't let us read a
+            # response without finishing a body.
+            stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
+                         b'Content-Length: 42\r\n\r\n')
+            with ExpectLog(gen_log, 'Timeout reading body'):
+                response = yield stream.read_until_close()
+            self.assertEqual(response, b'')
+        finally:
+            stream.close()
+
+    @gen_test
+    def test_body_size_override_reset(self):
+        # The max_body_size override is reset between requests.
+        stream = IOStream(socket.socket())
+        try:
+            yield stream.connect(('127.0.0.1', self.get_http_port()))
+            # Use a raw stream so we can make sure it's all on one connection.
+            stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
+                         b'Content-Length: 10240\r\n\r\n')
+            stream.write(b'a'*10240)
+            response = yield gen.Task(read_stream_body, stream)
+            self.assertEqual(response, b'10240')
+            # Without the ?expected_size parameter, we get the old default value
+            stream.write(b'PUT /streaming HTTP/1.1\r\n'
+                         b'Content-Length: 10240\r\n\r\n')
+            with ExpectLog(gen_log, '.*Content-Length too long'):
+                data = yield stream.read_until_close()
+            self.assertEqual(data, b'')
+        finally:
+            stream.close()