]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Impose a size limit on incoming websocket messages (#1997)
authorBen Darnell <ben@bendarnell.com>
Sun, 2 Apr 2017 14:52:43 +0000 (10:52 -0400)
committerGitHub <noreply@github.com>
Sun, 2 Apr 2017 14:52:43 +0000 (10:52 -0400)
docs/releases/v4.5.0.rst
tornado/test/websocket_test.py
tornado/websocket.py

index 627f714d90306985c695c094faab6df4ab5a5a56..1632d5251f946f4d6cb9142770bfe80c9eb4645e 100644 (file)
@@ -4,6 +4,12 @@ What's new in Tornado 4.5
 Xxx XX, 2017
 ------------
 
+Backwards-compatibility warning
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- The `tornado.websocket` module now imposes a limit on the size of incoming
+  messages, which defaults to 10MiB.
+
 New module
 ~~~~~~~~~~
 
@@ -146,6 +152,8 @@ General changes
   application settings can now be used to enable a periodic ping of
   the websocket connection, allowing dropped connections to be
   detected and closed.
+- The new ``websocket_max_message_size`` setting defaults to 10MiB.
+  The connection will be closed if messages larger than this are received.
 - Headers set by `.RequestHandler.prepare` or
   `.RequestHandler.set_default_headers` are now sent as a part of the
   websocket handshake.
index 0875e4b91e27fc4a0bc0968023216043905c5c62..7bdca8773cd92b8bfc272c6cc99a9d277cb3d5d4 100644 (file)
@@ -579,3 +579,31 @@ class ClientPeriodicPingTest(WebSocketBaseTestCase):
             self.assertEqual(response, "got ping")
         yield self.close(ws)
         # TODO: test that the connection gets closed if ping responses stop.
+
+
+class MaxMessageSizeTest(WebSocketBaseTestCase):
+    def get_app(self):
+        self.close_future = Future()
+        return Application([
+            ('/', EchoHandler, dict(close_future=self.close_future)),
+        ], websocket_max_message_size=1024)
+
+    @gen_test
+    def test_large_message(self):
+        ws = yield self.ws_connect('/')
+
+        # Write a message that is allowed.
+        msg = 'a' * 1024
+        ws.write_message(msg)
+        resp = yield ws.read_message()
+        self.assertEqual(resp, msg)
+
+        # Write a message that is too large.
+        ws.write_message(msg + 'b')
+        resp = yield ws.read_message()
+        # A message of None means the other side closed the connection.
+        self.assertIs(resp, None)
+        self.assertEqual(ws.close_code, 1009)
+        self.assertEqual(ws.close_reason, "message too big")
+        # TODO: Needs tests of messages split over multiple
+        # continuation frames.
index 65243572fc7abc26ddac84d1166b995a3de95854..0af9e8f8e4ed137ed15f7b7d23f6412ed5e03f05 100644 (file)
@@ -131,8 +131,12 @@ class WebSocketHandler(tornado.web.RequestHandler):
     value, a ping will be sent periodically, and the connection will be
     closed if a response is not received before the ``websocket_ping_timeout``.
 
+    Messages larger than the ``websocket_max_message_size`` application setting
+    (default 10MiB) will not be accepted.
+
     .. versionchanged:: 4.5
-       Added ``websocket_ping_interval`` and ``websocket_ping_timeout``.
+       Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
+       ``websocket_max_message_size``.
     """
     def __init__(self, application, request, **kwargs):
         super(WebSocketHandler, self).__init__(application, request, **kwargs)
@@ -213,6 +217,17 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         return self.settings.get('websocket_ping_timeout', None)
 
+    @property
+    def max_message_size(self):
+        """Maximum allowed message size.
+
+        If the remote peer sends a message larger than this, the connection
+        will be closed.
+
+        Default is 10MiB.
+        """
+        return self.settings.get('websocket_max_message_size', None)
+
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket.
 
@@ -799,8 +814,7 @@ class WebSocketProtocol13(WebSocketProtocol):
                 if self._masked_frame:
                     self.stream.read_bytes(4, self._on_masking_key)
                 else:
-                    self.stream.read_bytes(self._frame_length,
-                                           self._on_frame_data)
+                    self._read_frame_data(False)
             elif payloadlen == 126:
                 self.stream.read_bytes(2, self._on_frame_length_16)
             elif payloadlen == 127:
@@ -808,6 +822,17 @@ class WebSocketProtocol13(WebSocketProtocol):
         except StreamClosedError:
             self._abort()
 
+    def _read_frame_data(self, masked):
+        new_len = self._frame_length
+        if self._fragmented_message_buffer is not None:
+            new_len += len(self._fragmented_message_buffer)
+        if new_len > (self.handler.max_message_size or 10*1024*1024):
+            self.close(1009, "message too big")
+            return
+        self.stream.read_bytes(
+            self._frame_length,
+            self._on_masked_frame_data if masked else self._on_frame_data)
+
     def _on_frame_length_16(self, data):
         self._wire_bytes_in += len(data)
         self._frame_length = struct.unpack("!H", data)[0]
@@ -815,7 +840,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             if self._masked_frame:
                 self.stream.read_bytes(4, self._on_masking_key)
             else:
-                self.stream.read_bytes(self._frame_length, self._on_frame_data)
+                self._read_frame_data(False)
         except StreamClosedError:
             self._abort()
 
@@ -826,7 +851,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             if self._masked_frame:
                 self.stream.read_bytes(4, self._on_masking_key)
             else:
-                self.stream.read_bytes(self._frame_length, self._on_frame_data)
+                self._read_frame_data(False)
         except StreamClosedError:
             self._abort()
 
@@ -834,8 +859,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._wire_bytes_in += len(data)
         self._frame_mask = data
         try:
-            self.stream.read_bytes(self._frame_length,
-                                   self._on_masked_frame_data)
+            self._read_frame_data(True)
         except StreamClosedError:
             self._abort()
 
@@ -1007,7 +1031,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     `websocket_connect` function instead.
     """
     def __init__(self, io_loop, request, on_message_callback=None,
-                 compression_options=None, ping_interval=None, ping_timeout=None):
+                 compression_options=None, ping_interval=None, ping_timeout=None,
+                 max_message_size=None):
         self.compression_options = compression_options
         self.connect_future = TracebackFuture()
         self.protocol = None
@@ -1018,6 +1043,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.close_code = self.close_reason = None
         self.ping_interval = ping_interval
         self.ping_timeout = ping_timeout
+        self.max_message_size = max_message_size
 
         scheme, sep, rest = request.url.partition(':')
         scheme = {'ws': 'http', 'wss': 'https'}[scheme]
@@ -1145,7 +1171,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
 
 def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
                       on_message_callback=None, compression_options=None,
-                      ping_interval=None, ping_timeout=None):
+                      ping_interval=None, ping_timeout=None,
+                      max_message_size=None):
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a
@@ -1174,6 +1201,10 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
     .. versionchanged:: 4.1
        Added ``compression_options`` and ``on_message_callback``.
        The ``io_loop`` argument is deprecated.
+
+    .. versionchanged:: 4.5
+       Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size``
+       arguments, which have the same meaning as in `WebSocketHandler`.
     """
     if io_loop is None:
         io_loop = IOLoop.current()
@@ -1191,7 +1222,8 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
                                      on_message_callback=on_message_callback,
                                      compression_options=compression_options,
                                      ping_interval=ping_interval,
-                                     ping_timeout=ping_timeout)
+                                     ping_timeout=ping_timeout,
+                                     max_message_size=max_message_size)
     if callback is not None:
         io_loop.add_future(conn.connect_future, callback)
     return conn.connect_future