def get_app(self):
self.close_future = Future()
+
+ class LimitedHandler(TestWebSocketHandler):
+ @property
+ def max_message_size(self):
+ return 1024
+
+ def on_message(self, message):
+ self.write_message(str(len(message)))
+
return Application([
('/echo', EchoHandler, dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options())),
+ ('/limited', LimitedHandler, dict(
+ close_future=self.close_future,
+ compression_options=self.get_server_compression_options())),
])
def get_server_compression_options(self):
ws.protocol._wire_bytes_out)
yield self.close(ws)
+ @gen_test
+ def test_size_limit(self):
+ ws = yield self.ws_connect(
+ '/limited',
+ compression_options=self.get_client_compression_options())
+ # Small messages pass through.
+ ws.write_message('a' * 128)
+ response = yield ws.read_message()
+ self.assertEqual(response, '128')
+ # This message is too big after decompression, but it compresses
+ # down to a size that will pass the initial checks.
+ ws.write_message('a' * 2048)
+ response = yield ws.read_message()
+ self.assertIsNone(response)
+ yield self.close(ws)
+
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
else:
from urlparse import urlparse # py3
+_default_max_message_size = 10 * 1024 * 1024
+
class WebSocketError(Exception):
pass
pass
+class _DecompressTooLargeError(Exception):
+ pass
+
+
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
Default is 10MiB.
"""
- return self.settings.get('websocket_max_message_size', None)
+ return self.settings.get('websocket_max_message_size', _default_max_message_size)
def write_message(self, message, binary=False):
"""Sends the given message to the client of this Web Socket.
class _PerMessageDeflateDecompressor(object):
- def __init__(self, persistent, max_wbits, compression_options=None):
+ def __init__(self, persistent, max_wbits, max_message_size, compression_options=None):
+ self._max_message_size = max_message_size
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
- return decompressor.decompress(data + b'\x00\x00\xff\xff')
+ result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
+ if decompressor.unconsumed_tail:
+ raise _DecompressTooLargeError()
+ return result
class WebSocketProtocol13(WebSocketProtocol):
self._compressor = _PerMessageDeflateCompressor(
**self._get_compressor_options(side, agreed_parameters, compression_options))
self._decompressor = _PerMessageDeflateDecompressor(
+ max_message_size=self.handler.max_message_size,
**self._get_compressor_options(other_side, agreed_parameters, compression_options))
def _write_frame(self, fin, opcode, data, flags=0):
new_len = payloadlen
if self._fragmented_message_buffer is not None:
new_len += len(self._fragmented_message_buffer)
- if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
+ if new_len > self.handler.max_message_size:
self.close(1009, "message too big")
self._abort()
return
return
if self._frame_compressed:
- data = self._decompressor.decompress(data)
+ try:
+ data = self._decompressor.decompress(data)
+ except _DecompressTooLargeError:
+ self.close(1009, "message too big after decompression")
+ self._abort()
+ return
if opcode == 0x1:
# UTF-8 data
def websocket_connect(url, callback=None, connect_timeout=None,
on_message_callback=None, compression_options=None,
ping_interval=None, ping_timeout=None,
- max_message_size=None, subprotocols=None):
+ max_message_size=_default_max_message_size, subprotocols=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a