From: Ben Darnell Date: Sun, 27 Jul 2014 05:26:11 +0000 (-0400) Subject: Implement permessage-deflate websocket extension. X-Git-Tag: v4.1.0b1~120 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=576c1c43bc73006f80f661a9238b718ef4b51901;p=thirdparty%2Ftornado.git Implement permessage-deflate websocket extension. Parameters to the extension are not fully supported (the client side supports client_no_context_takeover which is mandatory to implement, but the server rejects any parameters offered by the client, and neither side supports setting wbits). Closes #308. Closes #668. --- diff --git a/demos/websocket/chatdemo.py b/demos/websocket/chatdemo.py index c1067e9e9..ad5b2c2e1 100755 --- a/demos/websocket/chatdemo.py +++ b/demos/websocket/chatdemo.py @@ -56,6 +56,10 @@ class ChatSocketHandler(tornado.websocket.WebSocketHandler): cache = [] cache_size = 200 + def get_compression_options(self): + # Non-None enables compression with default options. + return {} + def open(self): ChatSocketHandler.waiters.add(self) diff --git a/maint/test/websocket/client.py b/maint/test/websocket/client.py index 91bcd2845..9df1a82a4 100644 --- a/maint/test/websocket/client.py +++ b/maint/test/websocket/client.py @@ -22,7 +22,7 @@ def run_tests(): for i in range(1, num_tests + 1): logging.info('running test case %d', i) url = options.url + '/runCase?case=%d&agent=%s' % (i, options.name) - test_ws = yield websocket_connect(url, None) + test_ws = yield websocket_connect(url, None, compression_options={}) while True: message = yield test_ws.read_message() if message is None: diff --git a/maint/test/websocket/fuzzingclient.json b/maint/test/websocket/fuzzingclient.json index 759963f44..9e07e830d 100644 --- a/maint/test/websocket/fuzzingclient.json +++ b/maint/test/websocket/fuzzingclient.json @@ -14,6 +14,6 @@ ], "cases": ["*"], - "exclude-cases": ["9.*"], + "exclude-cases": ["9.*", "12.*.1","12.2.*", "12.3.*", "12.4.*", "12.5.*", "13.*.1"], "exclude-agent-cases": {} } diff --git a/maint/test/websocket/fuzzingserver.json b/maint/test/websocket/fuzzingserver.json index 8fc4ab60f..28d541c29 100644 --- a/maint/test/websocket/fuzzingserver.json +++ b/maint/test/websocket/fuzzingserver.json @@ -7,6 +7,6 @@ "webport": 8080, "cases": ["*"], - "exclude-cases": ["9.*"], + "exclude-cases": ["9.*", "12.*.1","12.2.*", "12.3.*", "12.4.*", "12.5.*", "13.*.1"], "exclude-agent-cases": {} } diff --git a/maint/test/websocket/server.py b/maint/test/websocket/server.py index b44056cd6..305bd7468 100644 --- a/maint/test/websocket/server.py +++ b/maint/test/websocket/server.py @@ -12,6 +12,9 @@ class EchoHandler(WebSocketHandler): def on_message(self, message): self.write_message(message, binary=isinstance(message, bytes_type)) + def get_compression_options(self): + return {} + if __name__ == '__main__': parse_command_line() app = Application([ diff --git a/tornado/httputil.py b/tornado/httputil.py index 1c5753822..efe9f653f 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -803,6 +803,8 @@ def parse_response_start_line(line): # _parseparam and _parse_header are copied and modified from python2.7's cgi.py # The original 2.7 version of this code did not correctly support some # combinations of semicolons and double quotes. +# It has also been modified to support valueless parameters as seen in +# websocket extension negotiations. def _parseparam(s): @@ -836,6 +838,8 @@ def _parse_header(line): value = value[1:-1] value = value.replace('\\\\', '\\').replace('\\"', '"') pdict[name] = value + else: + pdict[p] = None return key, pdict diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index a1f85cf5c..f8cd163bd 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, with_statement import traceback from tornado.concurrent import Future +from tornado import gen from tornado.httpclient import HTTPError, HTTPRequest from tornado.log import gen_log, app_log from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog @@ -34,8 +35,12 @@ class TestWebSocketHandler(WebSocketHandler): This allows for deterministic cleanup of the associated socket. """ - def initialize(self, close_future): + def initialize(self, close_future, compression_options=None): self.close_future = close_future + self.compression_options = compression_options + + def get_compression_options(self): + return self.compression_options def on_close(self): self.close_future.set_result((self.close_code, self.close_reason)) @@ -73,7 +78,25 @@ class CloseReasonHandler(TestWebSocketHandler): self.close(1001, "goodbye") -class WebSocketTest(AsyncHTTPTestCase): +class WebSocketBaseTestCase(AsyncHTTPTestCase): + @gen.coroutine + def ws_connect(self, path, compression_options=None): + ws = yield websocket_connect( + 'ws://localhost:%d%s' % (self.get_http_port(), path), + compression_options=compression_options) + raise gen.Return(ws) + + @gen.coroutine + def close(self, ws): + """Close a websocket connection and wait for the server side. + + If we don't wait here, there are sometimes leak warnings in the + tests. + """ + ws.close() + yield self.close_future + +class WebSocketTest(WebSocketBaseTestCase): def get_app(self): self.close_future = Future() return Application([ @@ -93,14 +116,11 @@ class WebSocketTest(AsyncHTTPTestCase): @gen_test def test_websocket_gen(self): - ws = yield websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port(), - io_loop=self.io_loop) + ws = yield self.ws_connect('/echo') ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') - ws.close() - yield self.close_future + yield self.close(ws) def test_websocket_callbacks(self): websocket_connect( @@ -117,49 +137,39 @@ class WebSocketTest(AsyncHTTPTestCase): @gen_test def test_binary_message(self): - ws = yield websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port()) + ws = yield self.ws_connect('/echo') ws.write_message(b'hello \xe9', binary=True) response = yield ws.read_message() self.assertEqual(response, b'hello \xe9') - ws.close() - yield self.close_future + yield self.close(ws) @gen_test def test_unicode_message(self): - ws = yield websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port()) + ws = yield self.ws_connect('/echo') ws.write_message(u('hello \u00e9')) response = yield ws.read_message() self.assertEqual(response, u('hello \u00e9')) - ws.close() - yield self.close_future + yield self.close(ws) @gen_test def test_error_in_on_message(self): - ws = yield websocket_connect( - 'ws://localhost:%d/error_in_on_message' % self.get_http_port()) + ws = yield self.ws_connect('/error_in_on_message') ws.write_message('hello') with ExpectLog(app_log, "Uncaught exception"): response = yield ws.read_message() self.assertIs(response, None) - ws.close() - yield self.close_future + yield self.close(ws) @gen_test def test_websocket_http_fail(self): with self.assertRaises(HTTPError) as cm: - yield websocket_connect( - 'ws://localhost:%d/notfound' % self.get_http_port(), - io_loop=self.io_loop) + yield self.ws_connect('/notfound') self.assertEqual(cm.exception.code, 404) @gen_test def test_websocket_http_success(self): with self.assertRaises(WebSocketError): - yield websocket_connect( - 'ws://localhost:%d/non_ws' % self.get_http_port(), - io_loop=self.io_loop) + yield self.ws_connect('/non_ws') @gen_test def test_websocket_network_fail(self): @@ -178,6 +188,7 @@ class WebSocketTest(AsyncHTTPTestCase): 'ws://localhost:%d/echo' % self.get_http_port()) ws.write_message('hello') ws.write_message('world') + # Close the underlying stream. ws.stream.close() yield self.close_future @@ -189,13 +200,11 @@ class WebSocketTest(AsyncHTTPTestCase): headers={'X-Test': 'hello'})) response = yield ws.read_message() self.assertEqual(response, 'hello') - ws.close() - yield self.close_future + yield self.close(ws) @gen_test def test_server_close_reason(self): - ws = yield websocket_connect( - 'ws://localhost:%d/close_reason' % self.get_http_port()) + ws = yield self.ws_connect('/close_reason') msg = yield ws.read_message() # A message of None means the other side closed the connection. self.assertIs(msg, None) @@ -204,8 +213,7 @@ class WebSocketTest(AsyncHTTPTestCase): @gen_test def test_client_close_reason(self): - ws = yield websocket_connect( - 'ws://localhost:%d/echo' % self.get_http_port()) + ws = yield self.ws_connect('/echo') ws.close(1001, 'goodbye') code, reason = yield self.close_future self.assertEqual(code, 1001) @@ -223,8 +231,7 @@ class WebSocketTest(AsyncHTTPTestCase): ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') - ws.close() - yield self.close_future + yield self.close(ws) @gen_test def test_check_origin_valid_with_path(self): @@ -238,8 +245,7 @@ class WebSocketTest(AsyncHTTPTestCase): ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') - ws.close() - yield self.close_future + yield self.close(ws) @gen_test def test_check_origin_invalid_partial_url(self): @@ -284,6 +290,78 @@ class WebSocketTest(AsyncHTTPTestCase): self.assertEqual(cm.exception.code, 403) +class CompressionTestMixin(object): + MESSAGE = 'Hello world. Testing 123 123' + + def get_app(self): + self.close_future = Future() + return Application([ + ('/echo', EchoHandler, dict( + close_future=self.close_future, + compression_options=self.get_server_compression_options())), + ]) + + def get_server_compression_options(self): + return None + + def get_client_compression_options(self): + return None + + @gen_test + def test_message_sizes(self): + ws = yield self.ws_connect( + '/echo', + compression_options=self.get_client_compression_options()) + # Send the same message three times so we can measure the + # effect of the context_takeover options. + for i in range(3): + ws.write_message(self.MESSAGE) + response = yield ws.read_message() + self.assertEqual(response, self.MESSAGE) + self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3) + self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3) + self.verify_wire_bytes(ws.protocol._wire_bytes_in, + ws.protocol._wire_bytes_out) + yield self.close(ws) + + +class UncompressedTestMixin(CompressionTestMixin): + """Specialization of CompressionTestMixin when we expect no compression.""" + def verify_wire_bytes(self, bytes_in, bytes_out): + # Bytes out includes the 4-byte mask key per message. + self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6)) + self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2)) + + +class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): + pass + + +# If only one side tries to compress, the extension is not negotiated. +class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): + def get_server_compression_options(self): + return {} + + +class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): + def get_client_compression_options(self): + return {} + + +class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase): + def get_server_compression_options(self): + return {} + + def get_client_compression_options(self): + return {} + + def verify_wire_bytes(self, bytes_in, bytes_out): + self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6)) + self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2)) + # Bytes out includes the 4 bytes mask key per message. + self.assertEqual(bytes_out, bytes_in + 12) + + class MaskFunctionMixin(object): # Subclasses should define self.mask(mask, data) def test_mask(self): diff --git a/tornado/websocket.py b/tornado/websocket.py index a77e02c49..014eafe28 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -26,6 +26,7 @@ import os import struct import tornado.escape import tornado.web +import zlib from tornado.concurrent import TracebackFuture from tornado.escape import utf8, native_str, to_unicode @@ -171,7 +172,8 @@ class WebSocketHandler(tornado.web.RequestHandler): self.stream.set_close_callback(self.on_connection_close) if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): - self.ws_connection = WebSocketProtocol13(self) + self.ws_connection = WebSocketProtocol13( + self, compression_options=self.get_compression_options()) self.ws_connection.accept_connection() else: self.stream.write(tornado.escape.utf8( @@ -213,6 +215,9 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return None + def get_compression_options(self): + return None + def open(self): """Invoked when a new WebSocket is opened. @@ -383,13 +388,55 @@ class WebSocketProtocol(object): self.close() # let the subclass cleanup +class _PerMessageDeflateCompressor(object): + def __init__(self, persistent): + if persistent: + self._compressor = self._create_compressor() + else: + self._compressor = None + + def _create_compressor(self): + return zlib.compressobj(-1, zlib.DEFLATED, -zlib.MAX_WBITS) + + def compress(self, data): + compressor = self._compressor or self._create_compressor() + data = (compressor.compress(data) + + compressor.flush(zlib.Z_SYNC_FLUSH)) + assert data.endswith(b'\x00\x00\xff\xff') + return data[:-4] + + +class _PerMessageDeflateDecompressor(object): + def __init__(self, persistent): + if persistent: + self._decompressor = self._create_decompressor() + else: + self._decompressor = None + + def _create_decompressor(self): + return zlib.decompressobj(-zlib.MAX_WBITS) + + def decompress(self, data): + decompressor = self._decompressor or self._create_decompressor() + return decompressor.decompress(data + b'\x00\x00\xff\xff') + + class WebSocketProtocol13(WebSocketProtocol): """Implementation of the WebSocket protocol from RFC 6455. This class supports versions 7 and 8 of the protocol in addition to the final version 13. """ - def __init__(self, handler, mask_outgoing=False): + # Bit masks for the first byte of a frame. + FIN = 0x80 + RSV1 = 0x40 + RSV2 = 0x20 + RSV3 = 0x10 + RSV_MASK = RSV1 | RSV2 | RSV3 + OPCODE_MASK = 0x0f + + def __init__(self, handler, mask_outgoing=False, + compression_options=None): WebSocketProtocol.__init__(self, handler) self.mask_outgoing = mask_outgoing self._final_frame = False @@ -400,6 +447,19 @@ class WebSocketProtocol13(WebSocketProtocol): self._fragmented_message_buffer = None self._fragmented_message_opcode = None self._waiting = None + self._compression_options = compression_options + self._decompressor = None + self._compressor = None + self._frame_compressed = None + # The total uncompressed size of all messages received or sent. + # Unicode messages are encoded to utf8. + # Only for testing; subject to change. + self._message_bytes_in = 0 + self._message_bytes_out = 0 + # The total size of all packets received or sent. Includes + # the effect of compression, frame overhead, and control frames. + self._wire_bytes_in = 0 + self._wire_bytes_out = 0 def accept_connection(self): try: @@ -444,24 +504,71 @@ class WebSocketProtocol13(WebSocketProtocol): assert selected in subprotocols subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected + extension_header = '' + extensions = self._parse_extensions_header(self.request.headers) + for ext in extensions: + if (ext[0] == 'permessage-deflate' and + self._compression_options is not None): + # TODO: negotiate parameters. For now, only + # allow the base extension. + extension_header = ( + 'Sec-WebSocket-Extensions: permessage-deflate\r\n') + self._create_compressors('server', {}) + break + self.stream.write(tornado.escape.utf8( "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: %s\r\n" - "%s" - "\r\n" % (self._challenge_response(), subprotocol_header))) + "%s%s" + "\r\n" % (self._challenge_response(), + subprotocol_header, extension_header))) self._run_callback(self.handler.open, *self.handler.open_args, **self.handler.open_kwargs) self._receive_frame() - def _write_frame(self, fin, opcode, data): + def _parse_extensions_header(self, headers): + extensions = headers.get("Sec-WebSocket-Extensions", '') + if extensions: + return [httputil._parse_header(e.strip()) + for e in extensions.split(',')] + return [] + + def _process_server_headers(self, key, headers): + """Process the headers sent by the server to this client connection. + + 'key' is the websocket handshake challenge/response key. + """ + assert headers['Upgrade'].lower() == 'websocket' + assert headers['Connection'].lower() == 'upgrade' + accept = self.compute_accept_value(key) + assert headers['Sec-Websocket-Accept'] == accept + + extensions = self._parse_extensions_header(headers) + for ext in extensions: + if (ext[0] == 'permessage-deflate' and + self._compression_options is not None): + self._create_compressors('client', ext[1]) + else: + raise ValueError("unsupported extension %r", ext) + + def _create_compressors(self, side, agreed_parameters): + # TODO: support the max_wbits parameters. + other_side = 'client' if (side == 'server') else 'server' + self._compressor = _PerMessageDeflateCompressor( + persistent=(side + '_no_context_takeover') not in agreed_parameters) + self._decompressor = _PerMessageDeflateDecompressor( + persistent=((other_side + '_no_context_takeover') + not in agreed_parameters)) + + def _write_frame(self, fin, opcode, data, flags=0): if fin: - finbit = 0x80 + finbit = self.FIN else: finbit = 0 - frame = struct.pack("B", finbit | opcode) + frame = struct.pack("B", finbit | opcode | flags) l = len(data) if self.mask_outgoing: mask_bit = 0x80 @@ -477,6 +584,7 @@ class WebSocketProtocol13(WebSocketProtocol): mask = os.urandom(4) data = mask + _websocket_mask(mask, data) frame += data + self._wire_bytes_out += len(frame) self.stream.write(frame) def write_message(self, message, binary=False): @@ -487,8 +595,13 @@ class WebSocketProtocol13(WebSocketProtocol): opcode = 0x1 message = tornado.escape.utf8(message) assert isinstance(message, bytes_type) + self._message_bytes_out += len(message) + flags = 0 + if self._compressor: + message = self._compressor.compress(message) + flags |= self.RSV1 try: - self._write_frame(True, opcode, message) + self._write_frame(True, opcode, message, flags=flags) except StreamClosedError: self._abort() @@ -504,11 +617,15 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() def _on_frame_start(self, data): + self._wire_bytes_in += len(data) header, payloadlen = struct.unpack("BB", data) - self._final_frame = header & 0x80 - reserved_bits = header & 0x70 - self._frame_opcode = header & 0xf + self._final_frame = header & self.FIN + reserved_bits = header & self.RSV_MASK + self._frame_opcode = header & self.OPCODE_MASK self._frame_opcode_is_control = self._frame_opcode & 0x8 + if self._decompressor is not None: + self._frame_compressed = bool(reserved_bits & self.RSV1) + reserved_bits &= ~self.RSV1 if reserved_bits: # client is using as-yet-undefined extensions; abort self._abort() @@ -534,6 +651,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() def _on_frame_length_16(self, data): + self._wire_bytes_in += len(data) self._frame_length = struct.unpack("!H", data)[0] try: if self._masked_frame: @@ -544,6 +662,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() def _on_frame_length_64(self, data): + self._wire_bytes_in += len(data) self._frame_length = struct.unpack("!Q", data)[0] try: if self._masked_frame: @@ -554,6 +673,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() def _on_masking_key(self, data): + self._wire_bytes_in += len(data) self._frame_mask = data try: self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) @@ -561,9 +681,11 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() def _on_masked_frame_data(self, data): + # Don't touch _wire_bytes_in; we'll do it in _on_frame_data. self._on_frame_data(_websocket_mask(self._frame_mask, data)) def _on_frame_data(self, data): + self._wire_bytes_in += len(data) if self._frame_opcode_is_control: # control frames may be interleaved with a series of fragmented # data frames, so control frames must not interact with @@ -604,8 +726,12 @@ class WebSocketProtocol13(WebSocketProtocol): if self.client_terminated: return + if self._frame_compressed: + data = self._decompressor.decompress(data) + if opcode == 0x1: # UTF-8 data + self._message_bytes_in += len(data) try: decoded = data.decode("utf-8") except UnicodeDecodeError: @@ -614,6 +740,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._run_callback(self.handler.on_message, decoded) elif opcode == 0x2: # Binary data + self._message_bytes_in += len(data) self._run_callback(self.handler.on_message, data) elif opcode == 0x8: # Close @@ -664,7 +791,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): This class should not be instantiated directly; use the `websocket_connect` function instead. """ - def __init__(self, io_loop, request): + def __init__(self, io_loop, request, compression_options=None): + self.compression_options = compression_options self.connect_future = TracebackFuture() self.read_future = None self.read_queue = collections.deque() @@ -679,6 +807,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): 'Sec-WebSocket-Key': self.key, 'Sec-WebSocket-Version': '13', }) + if self.compression_options is not None: + # TODO: offer parameters for the deflate extension. + request.headers['Sec-WebSocket-Extensions'] = 'permessage-deflate' self.tcp_client = TCPClient(io_loop=io_loop) super(WebSocketClientConnection, self).__init__( @@ -720,12 +851,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): start_line, headers) self.headers = headers - assert self.headers['Upgrade'].lower() == 'websocket' - assert self.headers['Connection'].lower() == 'upgrade' - accept = WebSocketProtocol13.compute_accept_value(self.key) - assert self.headers['Sec-Websocket-Accept'] == accept - - self.protocol = WebSocketProtocol13(self, mask_outgoing=True) + self.protocol = WebSocketProtocol13( + self, mask_outgoing=True, + compression_options=self.compression_options) + self.protocol._process_server_headers(self.key, self.headers) self.protocol._receive_frame() if self._timeout is not None: @@ -770,7 +899,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): pass -def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None): +def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, + compression_options=None): """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -791,7 +921,7 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None): request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) request = httpclient._RequestProxy( request, httpclient.HTTPRequest._DEFAULTS) - conn = WebSocketClientConnection(io_loop, request) + conn = WebSocketClientConnection(io_loop, request, compression_options) if callback is not None: io_loop.add_future(conn.connect_future, callback) return conn.connect_future