]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Limit post-decompression size of received messages 2391/head
authorBen Darnell <ben@bendarnell.com>
Sat, 19 May 2018 15:11:21 +0000 (11:11 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 19 May 2018 15:34:39 +0000 (11:34 -0400)
Protects against memory exhaustion DoS attacks.

tornado/test/websocket_test.py
tornado/websocket.py

index a6439b9fb7ffcfdc0ee11c962e8c72db7b78e1e3..ea7b1e4ceff2b86134d70f605316f60c08006fcb 100644 (file)
@@ -553,10 +553,22 @@ class CompressionTestMixin(object):
 
     def get_app(self):
         self.close_future = Future()
+
+        class LimitedHandler(TestWebSocketHandler):
+            @property
+            def max_message_size(self):
+                return 1024
+
+            def on_message(self, message):
+                self.write_message(str(len(message)))
+
         return Application([
             ('/echo', EchoHandler, dict(
                 close_future=self.close_future,
                 compression_options=self.get_server_compression_options())),
+            ('/limited', LimitedHandler, dict(
+                close_future=self.close_future,
+                compression_options=self.get_server_compression_options())),
         ])
 
     def get_server_compression_options(self):
@@ -582,6 +594,22 @@ class CompressionTestMixin(object):
                                ws.protocol._wire_bytes_out)
         yield self.close(ws)
 
+    @gen_test
+    def test_size_limit(self):
+        ws = yield self.ws_connect(
+            '/limited',
+            compression_options=self.get_client_compression_options())
+        # Small messages pass through.
+        ws.write_message('a' * 128)
+        response = yield ws.read_message()
+        self.assertEqual(response, '128')
+        # This message is too big after decompression, but it compresses
+        # down to a size that will pass the initial checks.
+        ws.write_message('a' * 2048)
+        response = yield ws.read_message()
+        self.assertIsNone(response)
+        yield self.close(ws)
+
 
 class UncompressedTestMixin(CompressionTestMixin):
     """Specialization of CompressionTestMixin when we expect no compression."""
index 6e5b103009224f2898d4c103e6483ef95f643ae0..0b994fc123c4a3ee88a23a95a50300873b1e2992 100644 (file)
@@ -44,6 +44,8 @@ if PY3:
 else:
     from urlparse import urlparse  # py3
 
+_default_max_message_size = 10 * 1024 * 1024
+
 
 class WebSocketError(Exception):
     pass
@@ -57,6 +59,10 @@ class WebSocketClosedError(WebSocketError):
     pass
 
 
+class _DecompressTooLargeError(Exception):
+    pass
+
+
 class WebSocketHandler(tornado.web.RequestHandler):
     """Subclass this class to create a basic WebSocket handler.
 
@@ -225,7 +231,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
 
         Default is 10MiB.
         """
-        return self.settings.get('websocket_max_message_size', None)
+        return self.settings.get('websocket_max_message_size', _default_max_message_size)
 
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket.
@@ -596,7 +602,8 @@ class _PerMessageDeflateCompressor(object):
 
 
 class _PerMessageDeflateDecompressor(object):
-    def __init__(self, persistent, max_wbits, compression_options=None):
+    def __init__(self, persistent, max_wbits, max_message_size, compression_options=None):
+        self._max_message_size = max_message_size
         if max_wbits is None:
             max_wbits = zlib.MAX_WBITS
         if not (8 <= max_wbits <= zlib.MAX_WBITS):
@@ -613,7 +620,10 @@ class _PerMessageDeflateDecompressor(object):
 
     def decompress(self, data):
         decompressor = self._decompressor or self._create_decompressor()
-        return decompressor.decompress(data + b'\x00\x00\xff\xff')
+        result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
+        if decompressor.unconsumed_tail:
+            raise _DecompressTooLargeError()
+        return result
 
 
 class WebSocketProtocol13(WebSocketProtocol):
@@ -801,6 +811,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._compressor = _PerMessageDeflateCompressor(
             **self._get_compressor_options(side, agreed_parameters, compression_options))
         self._decompressor = _PerMessageDeflateDecompressor(
+            max_message_size=self.handler.max_message_size,
             **self._get_compressor_options(other_side, agreed_parameters, compression_options))
 
     def _write_frame(self, fin, opcode, data, flags=0):
@@ -920,7 +931,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         new_len = payloadlen
         if self._fragmented_message_buffer is not None:
             new_len += len(self._fragmented_message_buffer)
-        if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
+        if new_len > self.handler.max_message_size:
             self.close(1009, "message too big")
             self._abort()
             return
@@ -971,7 +982,12 @@ class WebSocketProtocol13(WebSocketProtocol):
             return
 
         if self._frame_compressed:
-            data = self._decompressor.decompress(data)
+            try:
+                data = self._decompressor.decompress(data)
+            except _DecompressTooLargeError:
+                self.close(1009, "message too big after decompression")
+                self._abort()
+                return
 
         if opcode == 0x1:
             # UTF-8 data
@@ -1260,7 +1276,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
 def websocket_connect(url, callback=None, connect_timeout=None,
                       on_message_callback=None, compression_options=None,
                       ping_interval=None, ping_timeout=None,
-                      max_message_size=None, subprotocols=None):
+                      max_message_size=_default_max_message_size, subprotocols=None):
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a