From: Ben Darnell Date: Sun, 2 Apr 2017 14:52:43 +0000 (-0400) Subject: websocket: Impose a size limit on incoming websocket messages (#1997) X-Git-Tag: v4.5.0~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=104a302b750131d463f7e3b8e0f71dd334e5f904;p=thirdparty%2Ftornado.git websocket: Impose a size limit on incoming websocket messages (#1997) --- diff --git a/docs/releases/v4.5.0.rst b/docs/releases/v4.5.0.rst index 627f714d9..1632d5251 100644 --- a/docs/releases/v4.5.0.rst +++ b/docs/releases/v4.5.0.rst @@ -4,6 +4,12 @@ What's new in Tornado 4.5 Xxx XX, 2017 ------------ +Backwards-compatibility warning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- The `tornado.websocket` module now imposes a limit on the size of incoming + messages, which defaults to 10MiB. + New module ~~~~~~~~~~ @@ -146,6 +152,8 @@ General changes application settings can now be used to enable a periodic ping of the websocket connection, allowing dropped connections to be detected and closed. +- The new ``websocket_max_message_size`` setting defaults to 10MiB. + The connection will be closed if messages larger than this are received. - Headers set by `.RequestHandler.prepare` or `.RequestHandler.set_default_headers` are now sent as a part of the websocket handshake. diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 0875e4b91..7bdca8773 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -579,3 +579,31 @@ class ClientPeriodicPingTest(WebSocketBaseTestCase): self.assertEqual(response, "got ping") yield self.close(ws) # TODO: test that the connection gets closed if ping responses stop. + + +class MaxMessageSizeTest(WebSocketBaseTestCase): + def get_app(self): + self.close_future = Future() + return Application([ + ('/', EchoHandler, dict(close_future=self.close_future)), + ], websocket_max_message_size=1024) + + @gen_test + def test_large_message(self): + ws = yield self.ws_connect('/') + + # Write a message that is allowed. + msg = 'a' * 1024 + ws.write_message(msg) + resp = yield ws.read_message() + self.assertEqual(resp, msg) + + # Write a message that is too large. + ws.write_message(msg + 'b') + resp = yield ws.read_message() + # A message of None means the other side closed the connection. + self.assertIs(resp, None) + self.assertEqual(ws.close_code, 1009) + self.assertEqual(ws.close_reason, "message too big") + # TODO: Needs tests of messages split over multiple + # continuation frames. diff --git a/tornado/websocket.py b/tornado/websocket.py index 65243572f..0af9e8f8e 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -131,8 +131,12 @@ class WebSocketHandler(tornado.web.RequestHandler): value, a ping will be sent periodically, and the connection will be closed if a response is not received before the ``websocket_ping_timeout``. + Messages larger than the ``websocket_max_message_size`` application setting + (default 10MiB) will not be accepted. + .. versionchanged:: 4.5 - Added ``websocket_ping_interval`` and ``websocket_ping_timeout``. + Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and + ``websocket_max_message_size``. """ def __init__(self, application, request, **kwargs): super(WebSocketHandler, self).__init__(application, request, **kwargs) @@ -213,6 +217,17 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return self.settings.get('websocket_ping_timeout', None) + @property + def max_message_size(self): + """Maximum allowed message size. + + If the remote peer sends a message larger than this, the connection + will be closed. + + Default is 10MiB. + """ + return self.settings.get('websocket_max_message_size', None) + def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -799,8 +814,7 @@ class WebSocketProtocol13(WebSocketProtocol): if self._masked_frame: self.stream.read_bytes(4, self._on_masking_key) else: - self.stream.read_bytes(self._frame_length, - self._on_frame_data) + self._read_frame_data(False) elif payloadlen == 126: self.stream.read_bytes(2, self._on_frame_length_16) elif payloadlen == 127: @@ -808,6 +822,17 @@ class WebSocketProtocol13(WebSocketProtocol): except StreamClosedError: self._abort() + def _read_frame_data(self, masked): + new_len = self._frame_length + if self._fragmented_message_buffer is not None: + new_len += len(self._fragmented_message_buffer) + if new_len > (self.handler.max_message_size or 10*1024*1024): + self.close(1009, "message too big") + return + self.stream.read_bytes( + self._frame_length, + self._on_masked_frame_data if masked else self._on_frame_data) + def _on_frame_length_16(self, data): self._wire_bytes_in += len(data) self._frame_length = struct.unpack("!H", data)[0] @@ -815,7 +840,7 @@ class WebSocketProtocol13(WebSocketProtocol): if self._masked_frame: self.stream.read_bytes(4, self._on_masking_key) else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + self._read_frame_data(False) except StreamClosedError: self._abort() @@ -826,7 +851,7 @@ class WebSocketProtocol13(WebSocketProtocol): if self._masked_frame: self.stream.read_bytes(4, self._on_masking_key) else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + self._read_frame_data(False) except StreamClosedError: self._abort() @@ -834,8 +859,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._wire_bytes_in += len(data) self._frame_mask = data try: - self.stream.read_bytes(self._frame_length, - self._on_masked_frame_data) + self._read_frame_data(True) except StreamClosedError: self._abort() @@ -1007,7 +1031,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): `websocket_connect` function instead. """ def __init__(self, io_loop, request, on_message_callback=None, - compression_options=None, ping_interval=None, ping_timeout=None): + compression_options=None, ping_interval=None, ping_timeout=None, + max_message_size=None): self.compression_options = compression_options self.connect_future = TracebackFuture() self.protocol = None @@ -1018,6 +1043,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.close_code = self.close_reason = None self.ping_interval = ping_interval self.ping_timeout = ping_timeout + self.max_message_size = max_message_size scheme, sep, rest = request.url.partition(':') scheme = {'ws': 'http', 'wss': 'https'}[scheme] @@ -1145,7 +1171,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, on_message_callback=None, compression_options=None, - ping_interval=None, ping_timeout=None): + ping_interval=None, ping_timeout=None, + max_message_size=None): """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -1174,6 +1201,10 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, .. versionchanged:: 4.1 Added ``compression_options`` and ``on_message_callback``. The ``io_loop`` argument is deprecated. + + .. versionchanged:: 4.5 + Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size`` + arguments, which have the same meaning as in `WebSocketHandler`. """ if io_loop is None: io_loop = IOLoop.current() @@ -1191,7 +1222,8 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, on_message_callback=on_message_callback, compression_options=compression_options, ping_interval=ping_interval, - ping_timeout=ping_timeout) + ping_timeout=ping_timeout, + max_message_size=max_message_size) if callback is not None: io_loop.add_future(conn.connect_future, callback) return conn.connect_future