]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a chunk_size limitation to HTTP1Connection.
authorBen Darnell <ben@bendarnell.com>
Sat, 29 Mar 2014 16:22:55 +0000 (16:22 +0000)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Mar 2014 16:32:21 +0000 (16:32 +0000)
HTTPMessageDelegate will no longer receive more than chunk_size bytes
in a single call.

tornado/http1connection.py
tornado/httpserver.py
tornado/test/httpserver_test.py
tornado/util.py

index e46c04278601e51a7921bd67473695b892bf1705..571ed678d8bfe867f232d7e4e528e7ee5f8b7a15 100644 (file)
@@ -35,7 +35,7 @@ class HTTP1Connection(object):
     until the HTTP conection is closed.
     """
     def __init__(self, stream, address, is_client,
-                 no_keep_alive=False, protocol=None):
+                 no_keep_alive=False, protocol=None, chunk_size=None):
         self.is_client = is_client
         self.stream = stream
         self.address = address
@@ -57,6 +57,7 @@ class HTTP1Connection(object):
             self.protocol = "https"
         else:
             self.protocol = "http"
+        self._chunk_size = chunk_size or 65536
         self._disconnect_on_finish = False
         self._clear_request_state()
         self.stream.set_close_callback(self._on_connection_close)
@@ -83,7 +84,8 @@ class HTTP1Connection(object):
         while True:
             request_delegate = delegate.start_request(self)
             if gzip:
-                request_delegate = _GzipMessageDelegate(request_delegate)
+                request_delegate = _GzipMessageDelegate(request_delegate,
+                                                        self._chunk_size)
             try:
                 ret = yield self._read_message(request_delegate)
             except iostream.StreamClosedError:
@@ -94,7 +96,7 @@ class HTTP1Connection(object):
 
     def read_response(self, delegate, method, use_gzip=False):
         if use_gzip:
-            delegate = _GzipMessageDelegate(delegate)
+            delegate = _GzipMessageDelegate(delegate, self._chunk_size)
         return self._read_message(delegate, method=method)
 
     @gen.coroutine
@@ -335,8 +337,11 @@ class HTTP1Connection(object):
 
     @gen.coroutine
     def _read_fixed_body(self, content_length):
-        body = yield self.stream.read_bytes(content_length)
-        self.message_delegate.data_received(body)
+        while content_length > 0:
+            body = yield self.stream.read_bytes(
+                min(self._chunk_size, content_length), partial=True)
+            content_length -= len(body)
+            yield gen.maybe_future(self.message_delegate.data_received(body))
 
     @gen.coroutine
     def _read_chunked_body(self):
@@ -346,10 +351,16 @@ class HTTP1Connection(object):
             chunk_len = int(chunk_len.strip(), 16)
             if chunk_len == 0:
                 return
+            bytes_to_read = chunk_len
+            while bytes_to_read:
+                chunk = yield self.stream.read_bytes(
+                    min(bytes_to_read, self._chunk_size), partial=True)
+                bytes_to_read -= len(chunk)
+                yield gen.maybe_future(
+                    self.message_delegate.data_received(chunk))
             # chunk ends with \r\n
-            chunk = yield self.stream.read_bytes(chunk_len + 2)
-            assert chunk[-2:] == b"\r\n"
-            self.message_delegate.data_received(chunk[:-2])
+            crlf = yield self.stream.read_bytes(2)
+            assert crlf == b"\r\n"
 
     @gen.coroutine
     def _read_body_until_close(self):
@@ -360,8 +371,9 @@ class HTTP1Connection(object):
 class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
     """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
     """
-    def __init__(self, delegate):
+    def __init__(self, delegate, chunk_size):
         self._delegate = delegate
+        self._chunk_size = chunk_size
         self._decompressor = None
 
     def headers_received(self, start_line, headers):
@@ -375,10 +387,19 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
             del headers["Content-Encoding"]
         return self._delegate.headers_received(start_line, headers)
 
+    @gen.coroutine
     def data_received(self, chunk):
         if self._decompressor:
-            chunk = self._decompressor.decompress(chunk)
-        return self._delegate.data_received(chunk)
+            compressed_data = chunk
+            while compressed_data:
+                decompressed = self._decompressor.decompress(
+                    compressed_data, self._chunk_size)
+                if decompressed:
+                    yield gen.maybe_future(
+                        self._delegate.data_received(decompressed))
+                compressed_data = self._decompressor.unconsumed_tail
+        else:
+            yield gen.maybe_future(self._delegate.data_received(chunk))
 
     def finish(self):
         if self._decompressor is not None:
index d83f54c468134a3b0a146f54435f42bc5ffe4661..1515150b4df3d461f2e338b9a0cc835331676cdc 100644 (file)
@@ -136,19 +136,21 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
     """
     def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
                  xheaders=False, ssl_options=None, protocol=None, gzip=False,
-                 **kwargs):
+                 chunk_size=None, **kwargs):
         self.request_callback = request_callback
         self.no_keep_alive = no_keep_alive
         self.xheaders = xheaders
         self.protocol = protocol
         self.gzip = gzip
+        self.chunk_size = chunk_size
         TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
                            **kwargs)
 
     def handle_stream(self, stream, address):
         conn = HTTP1Connection(stream, address=address, is_client=False,
                                no_keep_alive=self.no_keep_alive,
-                               protocol=self.protocol)
+                               protocol=self.protocol,
+                               chunk_size=self.chunk_size)
         conn.start_serving(self, gzip=self.gzip)
 
     def start_request(self, connection):
index 343a8e4c072e55cadb27cfd17cda5fad88b4ed0a..9cadb042bba0b8feccc49bf0630e51a2d1ef63a9 100644 (file)
@@ -3,13 +3,14 @@
 
 from __future__ import absolute_import, division, print_function, with_statement
 from tornado import netutil
-from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
+from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
 from tornado.http1connection import HTTP1Connection
 from tornado.httpserver import HTTPServer
-from tornado.httputil import HTTPHeaders, HTTPMessageDelegate
+from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
 from tornado.iostream import IOStream
 from tornado.log import gen_log
 from tornado.netutil import ssl_options_to_context
+from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
 from tornado.test.util import unittest
 from tornado.util import u, bytes_type
@@ -737,3 +738,91 @@ class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
         with ExpectLog(gen_log, "Unsupported Content-Encoding"):
             response = self.post_gzip('foo=bar')
         self.assertEquals(json_decode(response.body), {})
+
+
+class StreamingChunkSizeTest(AsyncHTTPTestCase):
+    # 50 characters long, and repetitive so it can be compressed.
+    BODY = b'01234567890123456789012345678901234567890123456789'
+    CHUNK_SIZE = 16
+
+    def get_http_client(self):
+        # body_producer doesn't work on curl_httpclient, so override the
+        # configured AsyncHTTPClient implementation.
+        return SimpleAsyncHTTPClient(io_loop=self.io_loop)
+
+    def get_httpserver_options(self):
+        return dict(chunk_size=self.CHUNK_SIZE, gzip=True)
+
+    class MessageDelegate(HTTPMessageDelegate):
+        def __init__(self, connection):
+            self.connection = connection
+
+        def headers_received(self, start_line, headers):
+            self.chunk_lengths = []
+
+        def data_received(self, chunk):
+            self.chunk_lengths.append(len(chunk))
+
+        def finish(self):
+            response_body = utf8(json_encode(self.chunk_lengths))
+            self.connection.write_headers(
+                ResponseStartLine('HTTP/1.1', 200, 'OK'),
+                HTTPHeaders({'Content-Length': str(len(response_body))}))
+            self.connection.write(response_body)
+            self.connection.finish()
+
+    def get_app(self):
+        class App(HTTPServerConnectionDelegate):
+            def start_request(self, connection):
+                return StreamingChunkSizeTest.MessageDelegate(connection)
+        return App()
+
+    def fetch_chunk_sizes(self, **kwargs):
+        response = self.fetch('/', method='POST', **kwargs)
+        response.rethrow()
+        chunks = json_decode(response.body)
+        self.assertEqual(len(self.BODY), sum(chunks))
+        for chunk_size in chunks:
+            self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
+                                 'oversized chunk: ' + str(chunks))
+            self.assertGreater(chunk_size, 0,
+                               'empty chunk: ' + str(chunks))
+        return chunks
+
+    def compress(self, body):
+        bytesio = BytesIO()
+        gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
+        gzfile.write(body)
+        gzfile.close()
+        compressed = bytesio.getvalue()
+        if len(compressed) >= len(body):
+            raise Exception("body did not shrink when compressed")
+        return compressed
+
+    def test_regular_body(self):
+        chunks = self.fetch_chunk_sizes(body=self.BODY)
+        # Without compression we know exactly what to expect.
+        self.assertEqual([16, 16, 16, 2], chunks)
+
+    def test_compressed_body(self):
+         self.fetch_chunk_sizes(body=self.compress(self.BODY),
+                                headers={'Content-Encoding': 'gzip'})
+         # Compression creates irregular boundaries so the assertions
+         # in fetch_chunk_sizes are as specific as we can get.
+
+    def test_chunked_body(self):
+        def body_producer(write):
+            write(self.BODY[:20])
+            write(self.BODY[20:])
+        chunks = self.fetch_chunk_sizes(body_producer=body_producer)
+        # HTTP chunk boundaries translate to application-visible breaks
+        self.assertEqual([16, 4, 16, 14], chunks)
+
+    def test_chunked_compressed(self):
+        compressed = self.compress(self.BODY)
+        self.assertGreater(len(compressed), 20)
+        def body_producer(write):
+            write(compressed[:20])
+            write(compressed[20:])
+        self.fetch_chunk_sizes(body_producer=body_producer,
+                               headers={'Content-Encoding': 'gzip'})
index cc53222296e782d63981e5dc93f05b52bee23749..f4b2482fc54cbf0bcd5642367d7b4574f96df465 100644 (file)
@@ -33,7 +33,7 @@ class ObjectDict(dict):
 class GzipDecompressor(object):
     """Streaming gzip decompressor.
 
-    The interface is like that of `zlib.decompressobj` (without the
+    The interface is like that of `zlib.decompressobj` (without some of the
     optional arguments, but it understands gzip headers and checksums.
     """
     def __init__(self):
@@ -42,14 +42,24 @@ class GzipDecompressor(object):
         # This works on cpython and pypy, but not jython.
         self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
 
-    def decompress(self, value):
+    def decompress(self, value, max_length=None):
         """Decompress a chunk, returning newly-available data.
 
         Some data may be buffered for later processing; `flush` must
         be called when there is no more input data to ensure that
         all data was processed.
+
+        If ``max_length`` is given, some input data may be left over
+        in ``unconsumed_tail``; you must retrieve this value and pass
+        it back to a future call to `decompress` if it is not empty.
+        """
+        return self.decompressobj.decompress(value, max_length)
+
+    @property
+    def unconsumed_tail(self):
+        """Returns the unconsumed portion left over
         """
-        return self.decompressobj.decompress(value)
+        return self.decompressobj.unconsumed_tail
 
     def flush(self):
         """Return any remaining buffered data not yet returned by decompress.