From: Ben Darnell Date: Sat, 19 May 2018 15:11:21 +0000 (-0400) Subject: websocket: Limit post-decompression size of received messages X-Git-Tag: v5.1.0b1~14^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F2391%2Fhead;p=thirdparty%2Ftornado.git websocket: Limit post-decompression size of received messages Protects against memory exhaustion DoS attacks. --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index a6439b9fb..ea7b1e4ce 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -553,10 +553,22 @@ class CompressionTestMixin(object): def get_app(self): self.close_future = Future() + + class LimitedHandler(TestWebSocketHandler): + @property + def max_message_size(self): + return 1024 + + def on_message(self, message): + self.write_message(str(len(message))) + return Application([ ('/echo', EchoHandler, dict( close_future=self.close_future, compression_options=self.get_server_compression_options())), + ('/limited', LimitedHandler, dict( + close_future=self.close_future, + compression_options=self.get_server_compression_options())), ]) def get_server_compression_options(self): @@ -582,6 +594,22 @@ class CompressionTestMixin(object): ws.protocol._wire_bytes_out) yield self.close(ws) + @gen_test + def test_size_limit(self): + ws = yield self.ws_connect( + '/limited', + compression_options=self.get_client_compression_options()) + # Small messages pass through. + ws.write_message('a' * 128) + response = yield ws.read_message() + self.assertEqual(response, '128') + # This message is too big after decompression, but it compresses + # down to a size that will pass the initial checks. + ws.write_message('a' * 2048) + response = yield ws.read_message() + self.assertIsNone(response) + yield self.close(ws) + class UncompressedTestMixin(CompressionTestMixin): """Specialization of CompressionTestMixin when we expect no compression.""" diff --git a/tornado/websocket.py b/tornado/websocket.py index 6e5b10300..0b994fc12 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -44,6 +44,8 @@ if PY3: else: from urlparse import urlparse # py3 +_default_max_message_size = 10 * 1024 * 1024 + class WebSocketError(Exception): pass @@ -57,6 +59,10 @@ class WebSocketClosedError(WebSocketError): pass +class _DecompressTooLargeError(Exception): + pass + + class WebSocketHandler(tornado.web.RequestHandler): """Subclass this class to create a basic WebSocket handler. @@ -225,7 +231,7 @@ class WebSocketHandler(tornado.web.RequestHandler): Default is 10MiB. """ - return self.settings.get('websocket_max_message_size', None) + return self.settings.get('websocket_max_message_size', _default_max_message_size) def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -596,7 +602,8 @@ class _PerMessageDeflateCompressor(object): class _PerMessageDeflateDecompressor(object): - def __init__(self, persistent, max_wbits, compression_options=None): + def __init__(self, persistent, max_wbits, max_message_size, compression_options=None): + self._max_message_size = max_message_size if max_wbits is None: max_wbits = zlib.MAX_WBITS if not (8 <= max_wbits <= zlib.MAX_WBITS): @@ -613,7 +620,10 @@ class _PerMessageDeflateDecompressor(object): def decompress(self, data): decompressor = self._decompressor or self._create_decompressor() - return decompressor.decompress(data + b'\x00\x00\xff\xff') + result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size) + if decompressor.unconsumed_tail: + raise _DecompressTooLargeError() + return result class WebSocketProtocol13(WebSocketProtocol): @@ -801,6 +811,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._compressor = _PerMessageDeflateCompressor( **self._get_compressor_options(side, agreed_parameters, compression_options)) self._decompressor = _PerMessageDeflateDecompressor( + max_message_size=self.handler.max_message_size, **self._get_compressor_options(other_side, agreed_parameters, compression_options)) def _write_frame(self, fin, opcode, data, flags=0): @@ -920,7 +931,7 @@ class WebSocketProtocol13(WebSocketProtocol): new_len = payloadlen if self._fragmented_message_buffer is not None: new_len += len(self._fragmented_message_buffer) - if new_len > (self.handler.max_message_size or 10 * 1024 * 1024): + if new_len > self.handler.max_message_size: self.close(1009, "message too big") self._abort() return @@ -971,7 +982,12 @@ class WebSocketProtocol13(WebSocketProtocol): return if self._frame_compressed: - data = self._decompressor.decompress(data) + try: + data = self._decompressor.decompress(data) + except _DecompressTooLargeError: + self.close(1009, "message too big after decompression") + self._abort() + return if opcode == 0x1: # UTF-8 data @@ -1260,7 +1276,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def websocket_connect(url, callback=None, connect_timeout=None, on_message_callback=None, compression_options=None, ping_interval=None, ping_timeout=None, - max_message_size=None, subprotocols=None): + max_message_size=_default_max_message_size, subprotocols=None): """Client-side websocket support. Takes a url and returns a Future whose result is a