From: Ben Darnell
Date: Sat, 29 Mar 2014 16:22:55 +0000 (+0000)
Subject: Add a chunk_size limitation to HTTP1Connection.
X-Git-Tag: v4.0.0b1~91^2~33
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=800d1dcec43f979fa3eb8c316ce5606431d23fff;p=thirdparty%2Ftornado.git

Add a chunk_size limitation to HTTP1Connection.

HTTPMessageDelegate will no longer receive more than chunk_size bytes
in a single call to data_received, whether the body arrives with a
fixed Content-Length, with chunked transfer encoding, or gzipped.
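For illustration, here is how a server might opt in to a smaller limit
(when chunk_size is omitted, HTTP1Connection falls back to 65536 bytes).
This example is not part of the commit; the handler and port are
hypothetical:

    from tornado import httpserver, ioloop, web

    class EchoLengthHandler(web.RequestHandler):
        def post(self):
            # The handler still sees the fully assembled body; the
            # chunk_size cap applies to the connection's internal
            # data_received calls while the body is being read.
            self.write(str(len(self.request.body)))

    application = web.Application([(r"/", EchoLengthHandler)])
    server = httpserver.HTTPServer(application, chunk_size=4096)
    server.listen(8888)
    ioloop.IOLoop.instance().start()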
""" - def __init__(self, delegate): + def __init__(self, delegate, chunk_size): self._delegate = delegate + self._chunk_size = chunk_size self._decompressor = None def headers_received(self, start_line, headers): @@ -375,10 +387,19 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate): del headers["Content-Encoding"] return self._delegate.headers_received(start_line, headers) + @gen.coroutine def data_received(self, chunk): if self._decompressor: - chunk = self._decompressor.decompress(chunk) - return self._delegate.data_received(chunk) + compressed_data = chunk + while compressed_data: + decompressed = self._decompressor.decompress( + compressed_data, self._chunk_size) + if decompressed: + yield gen.maybe_future( + self._delegate.data_received(decompressed)) + compressed_data = self._decompressor.unconsumed_tail + else: + yield gen.maybe_future(self._delegate.data_received(chunk)) def finish(self): if self._decompressor is not None: diff --git a/tornado/httpserver.py b/tornado/httpserver.py index d83f54c46..1515150b4 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -136,19 +136,21 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): """ def __init__(self, request_callback, no_keep_alive=False, io_loop=None, xheaders=False, ssl_options=None, protocol=None, gzip=False, - **kwargs): + chunk_size=None, **kwargs): self.request_callback = request_callback self.no_keep_alive = no_keep_alive self.xheaders = xheaders self.protocol = protocol self.gzip = gzip + self.chunk_size = chunk_size TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options, **kwargs) def handle_stream(self, stream, address): conn = HTTP1Connection(stream, address=address, is_client=False, no_keep_alive=self.no_keep_alive, - protocol=self.protocol) + protocol=self.protocol, + chunk_size=self.chunk_size) conn.start_serving(self, gzip=self.gzip) def start_request(self, connection): diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 343a8e4c0..9cadb042b 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -3,13 +3,14 @@ from __future__ import absolute_import, division, print_function, with_statement from tornado import netutil -from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str +from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str from tornado.http1connection import HTTP1Connection from tornado.httpserver import HTTPServer -from tornado.httputil import HTTPHeaders, HTTPMessageDelegate +from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine from tornado.iostream import IOStream from tornado.log import gen_log from tornado.netutil import ssl_options_to_context +from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog from tornado.test.util import unittest from tornado.util import u, bytes_type @@ -737,3 +738,91 @@ class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase): with ExpectLog(gen_log, "Unsupported Content-Encoding"): response = self.post_gzip('foo=bar') self.assertEquals(json_decode(response.body), {}) + + +class StreamingChunkSizeTest(AsyncHTTPTestCase): + # 50 characters long, and repetitive so it can be compressed. 
diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py
index 343a8e4c0..9cadb042b 100644
--- a/tornado/test/httpserver_test.py
+++ b/tornado/test/httpserver_test.py
@@ -3,13 +3,14 @@
 from __future__ import absolute_import, division, print_function, with_statement

 from tornado import netutil
-from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
+from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
 from tornado.http1connection import HTTP1Connection
 from tornado.httpserver import HTTPServer
-from tornado.httputil import HTTPHeaders, HTTPMessageDelegate
+from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
 from tornado.iostream import IOStream
 from tornado.log import gen_log
 from tornado.netutil import ssl_options_to_context
+from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
 from tornado.test.util import unittest
 from tornado.util import u, bytes_type
@@ -737,3 +738,91 @@ class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
         with ExpectLog(gen_log, "Unsupported Content-Encoding"):
             response = self.post_gzip('foo=bar')
         self.assertEquals(json_decode(response.body), {})
+
+
+class StreamingChunkSizeTest(AsyncHTTPTestCase):
+    # 50 characters long, and repetitive so it can be compressed.
+    BODY = b'01234567890123456789012345678901234567890123456789'
+    CHUNK_SIZE = 16
+
+    def get_http_client(self):
+        # body_producer doesn't work on curl_httpclient, so override the
+        # configured AsyncHTTPClient implementation.
+        return SimpleAsyncHTTPClient(io_loop=self.io_loop)
+
+    def get_httpserver_options(self):
+        return dict(chunk_size=self.CHUNK_SIZE, gzip=True)
+
+    class MessageDelegate(HTTPMessageDelegate):
+        def __init__(self, connection):
+            self.connection = connection
+
+        def headers_received(self, start_line, headers):
+            self.chunk_lengths = []
+
+        def data_received(self, chunk):
+            self.chunk_lengths.append(len(chunk))
+
+        def finish(self):
+            response_body = utf8(json_encode(self.chunk_lengths))
+            self.connection.write_headers(
+                ResponseStartLine('HTTP/1.1', 200, 'OK'),
+                HTTPHeaders({'Content-Length': str(len(response_body))}))
+            self.connection.write(response_body)
+            self.connection.finish()
+
+    def get_app(self):
+        class App(HTTPServerConnectionDelegate):
+            def start_request(self, connection):
+                return StreamingChunkSizeTest.MessageDelegate(connection)
+        return App()
+
+    def fetch_chunk_sizes(self, **kwargs):
+        response = self.fetch('/', method='POST', **kwargs)
+        response.rethrow()
+        chunks = json_decode(response.body)
+        self.assertEqual(len(self.BODY), sum(chunks))
+        for chunk_size in chunks:
+            self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
+                                 'oversized chunk: ' + str(chunks))
+            self.assertGreater(chunk_size, 0,
+                               'empty chunk: ' + str(chunks))
+        return chunks
+
+    def compress(self, body):
+        bytesio = BytesIO()
+        gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
+        gzfile.write(body)
+        gzfile.close()
+        compressed = bytesio.getvalue()
+        if len(compressed) >= len(body):
+            raise Exception("body did not shrink when compressed")
+        return compressed
+
+    def test_regular_body(self):
+        chunks = self.fetch_chunk_sizes(body=self.BODY)
+        # Without compression we know exactly what to expect.
+        self.assertEqual([16, 16, 16, 2], chunks)
+
+    def test_compressed_body(self):
+        self.fetch_chunk_sizes(body=self.compress(self.BODY),
+                               headers={'Content-Encoding': 'gzip'})
+        # Compression creates irregular boundaries so the assertions
+        # in fetch_chunk_sizes are as specific as we can get.
+
+    def test_chunked_body(self):
+        def body_producer(write):
+            write(self.BODY[:20])
+            write(self.BODY[20:])
+        chunks = self.fetch_chunk_sizes(body_producer=body_producer)
+        # HTTP chunk boundaries translate to application-visible breaks
+        self.assertEqual([16, 4, 16, 14], chunks)
+
+    def test_chunked_compressed(self):
+        compressed = self.compress(self.BODY)
+        self.assertGreater(len(compressed), 20)
+        def body_producer(write):
+            write(compressed[:20])
+            write(compressed[20:])
+        self.fetch_chunk_sizes(body_producer=body_producer,
+                               headers={'Content-Encoding': 'gzip'})
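The expected chunk lists in these tests follow directly from
CHUNK_SIZE = 16: each write (equivalently, each HTTP chunk) is
delivered as full-sized pieces followed by the remainder. A
hypothetical helper, for illustration only, reproduces the arithmetic:

    def expected_chunks(write_sizes, chunk_size=16):
        # Split each write the way the connection does: whole
        # chunk_size pieces first, then the non-empty remainder.
        out = []
        for n in write_sizes:
            while n > chunk_size:
                out.append(chunk_size)
                n -= chunk_size
            out.append(n)
        return out

    expected_chunks([50])      # [16, 16, 16, 2], as in test_regular_body
    expected_chunks([20, 30])  # [16, 4, 16, 14], as in test_chunked_body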
diff --git a/tornado/util.py b/tornado/util.py
index cc5322229..f4b2482fc 100644
--- a/tornado/util.py
+++ b/tornado/util.py
@@ -33,7 +33,7 @@ class ObjectDict(dict):
 class GzipDecompressor(object):
     """Streaming gzip decompressor.

-    The interface is like that of `zlib.decompressobj` (without the
+    The interface is like that of `zlib.decompressobj` (without some of the
     optional arguments), but it understands gzip headers and checksums.
     """
     def __init__(self):
@@ -42,14 +42,24 @@
         # This works on cpython and pypy, but not jython.
         self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)

-    def decompress(self, value):
+    def decompress(self, value, max_length=0):
         """Decompress a chunk, returning newly-available data.

         Some data may be buffered for later processing; `flush` must
         be called when there is no more input data to ensure that
         all data was processed.
+
+        If ``max_length`` is given, some input data may be left over
+        in ``unconsumed_tail``; you must retrieve this value and pass
+        it back to a future call to `decompress` if it is not empty.
+        """
+        return self.decompressobj.decompress(value, max_length)
+
+    @property
+    def unconsumed_tail(self):
+        """Returns the unconsumed portion left over from a `decompress` call.
         """
-        return self.decompressobj.decompress(value)
+        return self.decompressobj.unconsumed_tail

     def flush(self):
         """Return any remaining buffered data not yet returned by decompress.
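The new ``max_length``/``unconsumed_tail`` pair follows the usual zlib
protocol described in the docstring above. A usage sketch; the function
below is illustrative and not part of the commit:

    from tornado.util import GzipDecompressor

    def decompress_bounded(compressed, chunk_size):
        # Yields decompressed pieces of at most chunk_size bytes each
        # (the final flush may be larger, as it is not length-limited).
        decompressor = GzipDecompressor()
        data = compressed
        while data:
            piece = decompressor.decompress(data, chunk_size)
            if piece:
                yield piece
            # Pass any unconsumed input back in, per the docstring.
            data = decompressor.unconsumed_tail
        tail = decompressor.flush()
        if tail:
            yield tail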