]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement permessage-deflate websocket extension.
authorBen Darnell <ben@bendarnell.com>
Sun, 27 Jul 2014 05:26:11 +0000 (01:26 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 27 Jul 2014 05:32:09 +0000 (01:32 -0400)
Parameters to the extension are not fully supported (the client side
supports client_no_context_takeover which is mandatory to implement,
but the server rejects any parameters offered by the client, and neither
side supports setting wbits).

Closes #308.
Closes #668.

demos/websocket/chatdemo.py
maint/test/websocket/client.py
maint/test/websocket/fuzzingclient.json
maint/test/websocket/fuzzingserver.json
maint/test/websocket/server.py
tornado/httputil.py
tornado/test/websocket_test.py
tornado/websocket.py

index c1067e9e97572f6d7a2b84bfb9a1e704db8477d1..ad5b2c2e1a288296afb39efeafad6c01e5bf4f87 100755 (executable)
@@ -56,6 +56,10 @@ class ChatSocketHandler(tornado.websocket.WebSocketHandler):
     cache = []
     cache_size = 200
 
+    def get_compression_options(self):
+        # Non-None enables compression with default options.
+        return {}
+
     def open(self):
         ChatSocketHandler.waiters.add(self)
 
index 91bcd28459db3ea8709fc01771e95be5c464011b..9df1a82a449e516d5bf7d6b2c6841bcc78ce295a 100644 (file)
@@ -22,7 +22,7 @@ def run_tests():
     for i in range(1, num_tests + 1):
         logging.info('running test case %d', i)
         url = options.url + '/runCase?case=%d&agent=%s' % (i, options.name)
-        test_ws = yield websocket_connect(url, None)
+        test_ws = yield websocket_connect(url, None, compression_options={})
         while True:
             message = yield test_ws.read_message()
             if message is None:
index 759963f441a5478b88caa4ec5ddccaf57d2d89c7..9e07e830dcb60ae8270934003fb8840d24be2d42 100644 (file)
@@ -14,6 +14,6 @@
        ],
 
    "cases": ["*"],
-   "exclude-cases": ["9.*"],
+   "exclude-cases": ["9.*", "12.*.1","12.2.*", "12.3.*", "12.4.*", "12.5.*", "13.*.1"],
    "exclude-agent-cases": {}
 }
index 8fc4ab60f7fec4cf921e9ebebecd89f01dd1f79f..28d541c2924bbd427317a0ada89f0e42b11b5455 100644 (file)
@@ -7,6 +7,6 @@
    "webport": 8080,
 
    "cases": ["*"],
-   "exclude-cases": ["9.*"],
+   "exclude-cases": ["9.*", "12.*.1","12.2.*", "12.3.*", "12.4.*", "12.5.*", "13.*.1"],
    "exclude-agent-cases": {}
 }
index b44056cd63151e2463ab3c4a4fd2535517fe43da..305bd7468038ae50ba93524a645edc4916715cde 100644 (file)
@@ -12,6 +12,9 @@ class EchoHandler(WebSocketHandler):
     def on_message(self, message):
         self.write_message(message, binary=isinstance(message, bytes_type))
 
+    def get_compression_options(self):
+        return {}
+
 if __name__ == '__main__':
     parse_command_line()
     app = Application([
index 1c5753822666ec40b9a5f1c3bd788b132c4f4b05..efe9f653f794f2936d2eba9a05f408f7b1760a44 100644 (file)
@@ -803,6 +803,8 @@ def parse_response_start_line(line):
 # _parseparam and _parse_header are copied and modified from python2.7's cgi.py
 # The original 2.7 version of this code did not correctly support some
 # combinations of semicolons and double quotes.
+# It has also been modified to support valueless parameters as seen in
+# websocket extension negotiations.
 
 
 def _parseparam(s):
@@ -836,6 +838,8 @@ def _parse_header(line):
                 value = value[1:-1]
                 value = value.replace('\\\\', '\\').replace('\\"', '"')
             pdict[name] = value
+        else:
+            pdict[p] = None
     return key, pdict
 
 
index a1f85cf5c17d30c7f0c7bddb46604d012ad48972..f8cd163bd33b661726491963fc4e832840a8b922 100644 (file)
@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 import traceback
 
 from tornado.concurrent import Future
+from tornado import gen
 from tornado.httpclient import HTTPError, HTTPRequest
 from tornado.log import gen_log, app_log
 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
@@ -34,8 +35,12 @@ class TestWebSocketHandler(WebSocketHandler):
 
     This allows for deterministic cleanup of the associated socket.
     """
-    def initialize(self, close_future):
+    def initialize(self, close_future, compression_options=None):
         self.close_future = close_future
+        self.compression_options = compression_options
+
+    def get_compression_options(self):
+        return self.compression_options
 
     def on_close(self):
         self.close_future.set_result((self.close_code, self.close_reason))
@@ -73,7 +78,25 @@ class CloseReasonHandler(TestWebSocketHandler):
         self.close(1001, "goodbye")
 
 
-class WebSocketTest(AsyncHTTPTestCase):
+class WebSocketBaseTestCase(AsyncHTTPTestCase):
+    @gen.coroutine
+    def ws_connect(self, path, compression_options=None):
+        ws = yield websocket_connect(
+            'ws://localhost:%d%s' % (self.get_http_port(), path),
+            compression_options=compression_options)
+        raise gen.Return(ws)
+
+    @gen.coroutine
+    def close(self, ws):
+        """Close a websocket connection and wait for the server side.
+
+        If we don't wait here, there are sometimes leak warnings in the
+        tests.
+        """
+        ws.close()
+        yield self.close_future
+
+class WebSocketTest(WebSocketBaseTestCase):
     def get_app(self):
         self.close_future = Future()
         return Application([
@@ -93,14 +116,11 @@ class WebSocketTest(AsyncHTTPTestCase):
 
     @gen_test
     def test_websocket_gen(self):
-        ws = yield websocket_connect(
-            'ws://localhost:%d/echo' % self.get_http_port(),
-            io_loop=self.io_loop)
+        ws = yield self.ws_connect('/echo')
         ws.write_message('hello')
         response = yield ws.read_message()
         self.assertEqual(response, 'hello')
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     def test_websocket_callbacks(self):
         websocket_connect(
@@ -117,49 +137,39 @@ class WebSocketTest(AsyncHTTPTestCase):
 
     @gen_test
     def test_binary_message(self):
-        ws = yield websocket_connect(
-            'ws://localhost:%d/echo' % self.get_http_port())
+        ws = yield self.ws_connect('/echo')
         ws.write_message(b'hello \xe9', binary=True)
         response = yield ws.read_message()
         self.assertEqual(response, b'hello \xe9')
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     @gen_test
     def test_unicode_message(self):
-        ws = yield websocket_connect(
-            'ws://localhost:%d/echo' % self.get_http_port())
+        ws = yield self.ws_connect('/echo')
         ws.write_message(u('hello \u00e9'))
         response = yield ws.read_message()
         self.assertEqual(response, u('hello \u00e9'))
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     @gen_test
     def test_error_in_on_message(self):
-        ws = yield websocket_connect(
-            'ws://localhost:%d/error_in_on_message' % self.get_http_port())
+        ws = yield self.ws_connect('/error_in_on_message')
         ws.write_message('hello')
         with ExpectLog(app_log, "Uncaught exception"):
             response = yield ws.read_message()
         self.assertIs(response, None)
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     @gen_test
     def test_websocket_http_fail(self):
         with self.assertRaises(HTTPError) as cm:
-            yield websocket_connect(
-                'ws://localhost:%d/notfound' % self.get_http_port(),
-                io_loop=self.io_loop)
+            yield self.ws_connect('/notfound')
         self.assertEqual(cm.exception.code, 404)
 
     @gen_test
     def test_websocket_http_success(self):
         with self.assertRaises(WebSocketError):
-            yield websocket_connect(
-                'ws://localhost:%d/non_ws' % self.get_http_port(),
-                io_loop=self.io_loop)
+            yield self.ws_connect('/non_ws')
 
     @gen_test
     def test_websocket_network_fail(self):
@@ -178,6 +188,7 @@ class WebSocketTest(AsyncHTTPTestCase):
             'ws://localhost:%d/echo' % self.get_http_port())
         ws.write_message('hello')
         ws.write_message('world')
+        # Close the underlying stream.
         ws.stream.close()
         yield self.close_future
 
@@ -189,13 +200,11 @@ class WebSocketTest(AsyncHTTPTestCase):
                         headers={'X-Test': 'hello'}))
         response = yield ws.read_message()
         self.assertEqual(response, 'hello')
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     @gen_test
     def test_server_close_reason(self):
-        ws = yield websocket_connect(
-            'ws://localhost:%d/close_reason' % self.get_http_port())
+        ws = yield self.ws_connect('/close_reason')
         msg = yield ws.read_message()
         # A message of None means the other side closed the connection.
         self.assertIs(msg, None)
@@ -204,8 +213,7 @@ class WebSocketTest(AsyncHTTPTestCase):
 
     @gen_test
     def test_client_close_reason(self):
-        ws = yield websocket_connect(
-            'ws://localhost:%d/echo' % self.get_http_port())
+        ws = yield self.ws_connect('/echo')
         ws.close(1001, 'goodbye')
         code, reason = yield self.close_future
         self.assertEqual(code, 1001)
@@ -223,8 +231,7 @@ class WebSocketTest(AsyncHTTPTestCase):
         ws.write_message('hello')
         response = yield ws.read_message()
         self.assertEqual(response, 'hello')
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     @gen_test
     def test_check_origin_valid_with_path(self):
@@ -238,8 +245,7 @@ class WebSocketTest(AsyncHTTPTestCase):
         ws.write_message('hello')
         response = yield ws.read_message()
         self.assertEqual(response, 'hello')
-        ws.close()
-        yield self.close_future
+        yield self.close(ws)
 
     @gen_test
     def test_check_origin_invalid_partial_url(self):
@@ -284,6 +290,78 @@ class WebSocketTest(AsyncHTTPTestCase):
         self.assertEqual(cm.exception.code, 403)
 
 
+class CompressionTestMixin(object):
+    MESSAGE = 'Hello world. Testing 123 123'
+
+    def get_app(self):
+        self.close_future = Future()
+        return Application([
+            ('/echo', EchoHandler, dict(
+                close_future=self.close_future,
+                compression_options=self.get_server_compression_options())),
+        ])
+
+    def get_server_compression_options(self):
+        return None
+
+    def get_client_compression_options(self):
+        return None
+
+    @gen_test
+    def test_message_sizes(self):
+        ws = yield self.ws_connect(
+            '/echo',
+            compression_options=self.get_client_compression_options())
+        # Send the same message three times so we can measure the
+        # effect of the context_takeover options.
+        for i in range(3):
+            ws.write_message(self.MESSAGE)
+            response = yield ws.read_message()
+            self.assertEqual(response, self.MESSAGE)
+        self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
+        self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
+        self.verify_wire_bytes(ws.protocol._wire_bytes_in,
+                               ws.protocol._wire_bytes_out)
+        yield self.close(ws)
+
+
+class UncompressedTestMixin(CompressionTestMixin):
+    """Specialization of CompressionTestMixin when we expect no compression."""
+    def verify_wire_bytes(self, bytes_in, bytes_out):
+        # Bytes out includes the 4-byte mask key per message.
+        self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
+        self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
+
+
+class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
+    pass
+
+
+# If only one side tries to compress, the extension is not negotiated.
+class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
+    def get_server_compression_options(self):
+        return {}
+
+
+class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
+    def get_client_compression_options(self):
+        return {}
+
+
+class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
+    def get_server_compression_options(self):
+        return {}
+
+    def get_client_compression_options(self):
+        return {}
+
+    def verify_wire_bytes(self, bytes_in, bytes_out):
+        self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
+        self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
+        # Bytes out includes the 4 bytes mask key per message.
+        self.assertEqual(bytes_out, bytes_in + 12)
+
+
 class MaskFunctionMixin(object):
     # Subclasses should define self.mask(mask, data)
     def test_mask(self):
index a77e02c49357754c9ab8328b23ba605f3adacc69..014eafe286f1fe0d8b285e17ba7127f89b0ea549 100644 (file)
@@ -26,6 +26,7 @@ import os
 import struct
 import tornado.escape
 import tornado.web
+import zlib
 
 from tornado.concurrent import TracebackFuture
 from tornado.escape import utf8, native_str, to_unicode
@@ -171,7 +172,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
         self.stream.set_close_callback(self.on_connection_close)
 
         if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
-            self.ws_connection = WebSocketProtocol13(self)
+            self.ws_connection = WebSocketProtocol13(
+                self, compression_options=self.get_compression_options())
             self.ws_connection.accept_connection()
         else:
             self.stream.write(tornado.escape.utf8(
@@ -213,6 +215,9 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         return None
 
+    def get_compression_options(self):
+        return None
+
     def open(self):
         """Invoked when a new WebSocket is opened.
 
@@ -383,13 +388,55 @@ class WebSocketProtocol(object):
         self.close()  # let the subclass cleanup
 
 
+class _PerMessageDeflateCompressor(object):
+    def __init__(self, persistent):
+        if persistent:
+            self._compressor = self._create_compressor()
+        else:
+            self._compressor = None
+
+    def _create_compressor(self):
+        return zlib.compressobj(-1, zlib.DEFLATED, -zlib.MAX_WBITS)
+
+    def compress(self, data):
+        compressor = self._compressor or self._create_compressor()
+        data = (compressor.compress(data) +
+                compressor.flush(zlib.Z_SYNC_FLUSH))
+        assert data.endswith(b'\x00\x00\xff\xff')
+        return data[:-4]
+
+
+class _PerMessageDeflateDecompressor(object):
+    def __init__(self, persistent):
+        if persistent:
+            self._decompressor = self._create_decompressor()
+        else:
+            self._decompressor = None
+
+    def _create_decompressor(self):
+        return zlib.decompressobj(-zlib.MAX_WBITS)
+
+    def decompress(self, data):
+        decompressor = self._decompressor or self._create_decompressor()
+        return decompressor.decompress(data + b'\x00\x00\xff\xff')
+
+
 class WebSocketProtocol13(WebSocketProtocol):
     """Implementation of the WebSocket protocol from RFC 6455.
 
     This class supports versions 7 and 8 of the protocol in addition to the
     final version 13.
     """
-    def __init__(self, handler, mask_outgoing=False):
+    # Bit masks for the first byte of a frame.
+    FIN = 0x80
+    RSV1 = 0x40
+    RSV2 = 0x20
+    RSV3 = 0x10
+    RSV_MASK = RSV1 | RSV2 | RSV3
+    OPCODE_MASK = 0x0f
+
+    def __init__(self, handler, mask_outgoing=False,
+                 compression_options=None):
         WebSocketProtocol.__init__(self, handler)
         self.mask_outgoing = mask_outgoing
         self._final_frame = False
@@ -400,6 +447,19 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._fragmented_message_buffer = None
         self._fragmented_message_opcode = None
         self._waiting = None
+        self._compression_options = compression_options
+        self._decompressor = None
+        self._compressor = None
+        self._frame_compressed = None
+        # The total uncompressed size of all messages received or sent.
+        # Unicode messages are encoded to utf8.
+        # Only for testing; subject to change.
+        self._message_bytes_in = 0
+        self._message_bytes_out = 0
+        # The total size of all packets received or sent.  Includes
+        # the effect of compression, frame overhead, and control frames.
+        self._wire_bytes_in = 0
+        self._wire_bytes_out = 0
 
     def accept_connection(self):
         try:
@@ -444,24 +504,71 @@ class WebSocketProtocol13(WebSocketProtocol):
                 assert selected in subprotocols
                 subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
 
+        extension_header = ''
+        extensions = self._parse_extensions_header(self.request.headers)
+        for ext in extensions:
+            if (ext[0] == 'permessage-deflate' and
+                self._compression_options is not None):
+                # TODO: negotiate parameters.  For now, only
+                # allow the base extension.
+                extension_header = (
+                    'Sec-WebSocket-Extensions: permessage-deflate\r\n')
+                self._create_compressors('server', {})
+                break
+
         self.stream.write(tornado.escape.utf8(
             "HTTP/1.1 101 Switching Protocols\r\n"
             "Upgrade: websocket\r\n"
             "Connection: Upgrade\r\n"
             "Sec-WebSocket-Accept: %s\r\n"
-            "%s"
-            "\r\n" % (self._challenge_response(), subprotocol_header)))
+            "%s%s"
+            "\r\n" % (self._challenge_response(),
+                      subprotocol_header, extension_header)))
 
         self._run_callback(self.handler.open, *self.handler.open_args,
                            **self.handler.open_kwargs)
         self._receive_frame()
 
-    def _write_frame(self, fin, opcode, data):
+    def _parse_extensions_header(self, headers):
+        extensions = headers.get("Sec-WebSocket-Extensions", '')
+        if extensions:
+            return [httputil._parse_header(e.strip())
+                    for e in extensions.split(',')]
+        return []
+
+    def _process_server_headers(self, key, headers):
+        """Process the headers sent by the server to this client connection.
+
+        'key' is the websocket handshake challenge/response key.
+        """
+        assert headers['Upgrade'].lower() == 'websocket'
+        assert headers['Connection'].lower() == 'upgrade'
+        accept = self.compute_accept_value(key)
+        assert headers['Sec-Websocket-Accept'] == accept
+
+        extensions = self._parse_extensions_header(headers)
+        for ext in extensions:
+            if (ext[0] == 'permessage-deflate' and
+                self._compression_options is not None):
+                self._create_compressors('client', ext[1])
+            else:
+                raise ValueError("unsupported extension %r", ext)
+
+    def _create_compressors(self, side, agreed_parameters):
+        # TODO: support the max_wbits parameters.
+        other_side = 'client' if (side == 'server') else 'server'
+        self._compressor = _PerMessageDeflateCompressor(
+            persistent=(side + '_no_context_takeover') not in agreed_parameters)
+        self._decompressor = _PerMessageDeflateDecompressor(
+            persistent=((other_side + '_no_context_takeover')
+                        not in agreed_parameters))
+
+    def _write_frame(self, fin, opcode, data, flags=0):
         if fin:
-            finbit = 0x80
+            finbit = self.FIN
         else:
             finbit = 0
-        frame = struct.pack("B", finbit | opcode)
+        frame = struct.pack("B", finbit | opcode | flags)
         l = len(data)
         if self.mask_outgoing:
             mask_bit = 0x80
@@ -477,6 +584,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             mask = os.urandom(4)
             data = mask + _websocket_mask(mask, data)
         frame += data
+        self._wire_bytes_out += len(frame)
         self.stream.write(frame)
 
     def write_message(self, message, binary=False):
@@ -487,8 +595,13 @@ class WebSocketProtocol13(WebSocketProtocol):
             opcode = 0x1
         message = tornado.escape.utf8(message)
         assert isinstance(message, bytes_type)
+        self._message_bytes_out += len(message)
+        flags = 0
+        if self._compressor:
+            message = self._compressor.compress(message)
+            flags |= self.RSV1
         try:
-            self._write_frame(True, opcode, message)
+            self._write_frame(True, opcode, message, flags=flags)
         except StreamClosedError:
             self._abort()
 
@@ -504,11 +617,15 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._abort()
 
     def _on_frame_start(self, data):
+        self._wire_bytes_in += len(data)
         header, payloadlen = struct.unpack("BB", data)
-        self._final_frame = header & 0x80
-        reserved_bits = header & 0x70
-        self._frame_opcode = header & 0xf
+        self._final_frame = header & self.FIN
+        reserved_bits = header & self.RSV_MASK
+        self._frame_opcode = header & self.OPCODE_MASK
         self._frame_opcode_is_control = self._frame_opcode & 0x8
+        if self._decompressor is not None:
+            self._frame_compressed = bool(reserved_bits & self.RSV1)
+            reserved_bits &= ~self.RSV1
         if reserved_bits:
             # client is using as-yet-undefined extensions; abort
             self._abort()
@@ -534,6 +651,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._abort()
 
     def _on_frame_length_16(self, data):
+        self._wire_bytes_in += len(data)
         self._frame_length = struct.unpack("!H", data)[0]
         try:
             if self._masked_frame:
@@ -544,6 +662,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._abort()
 
     def _on_frame_length_64(self, data):
+        self._wire_bytes_in += len(data)
         self._frame_length = struct.unpack("!Q", data)[0]
         try:
             if self._masked_frame:
@@ -554,6 +673,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._abort()
 
     def _on_masking_key(self, data):
+        self._wire_bytes_in += len(data)
         self._frame_mask = data
         try:
             self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
@@ -561,9 +681,11 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._abort()
 
     def _on_masked_frame_data(self, data):
+        # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
         self._on_frame_data(_websocket_mask(self._frame_mask, data))
 
     def _on_frame_data(self, data):
+        self._wire_bytes_in += len(data)
         if self._frame_opcode_is_control:
             # control frames may be interleaved with a series of fragmented
             # data frames, so control frames must not interact with
@@ -604,8 +726,12 @@ class WebSocketProtocol13(WebSocketProtocol):
         if self.client_terminated:
             return
 
+        if self._frame_compressed:
+            data = self._decompressor.decompress(data)
+
         if opcode == 0x1:
             # UTF-8 data
+            self._message_bytes_in += len(data)
             try:
                 decoded = data.decode("utf-8")
             except UnicodeDecodeError:
@@ -614,6 +740,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._run_callback(self.handler.on_message, decoded)
         elif opcode == 0x2:
             # Binary data
+            self._message_bytes_in += len(data)
             self._run_callback(self.handler.on_message, data)
         elif opcode == 0x8:
             # Close
@@ -664,7 +791,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     This class should not be instantiated directly; use the
     `websocket_connect` function instead.
     """
-    def __init__(self, io_loop, request):
+    def __init__(self, io_loop, request, compression_options=None):
+        self.compression_options = compression_options
         self.connect_future = TracebackFuture()
         self.read_future = None
         self.read_queue = collections.deque()
@@ -679,6 +807,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             'Sec-WebSocket-Key': self.key,
             'Sec-WebSocket-Version': '13',
         })
+        if self.compression_options is not None:
+            # TODO: offer parameters for the deflate extension.
+            request.headers['Sec-WebSocket-Extensions'] = 'permessage-deflate'
 
         self.tcp_client = TCPClient(io_loop=io_loop)
         super(WebSocketClientConnection, self).__init__(
@@ -720,12 +851,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
                 start_line, headers)
 
         self.headers = headers
-        assert self.headers['Upgrade'].lower() == 'websocket'
-        assert self.headers['Connection'].lower() == 'upgrade'
-        accept = WebSocketProtocol13.compute_accept_value(self.key)
-        assert self.headers['Sec-Websocket-Accept'] == accept
-
-        self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
+        self.protocol = WebSocketProtocol13(
+            self, mask_outgoing=True,
+            compression_options=self.compression_options)
+        self.protocol._process_server_headers(self.key, self.headers)
         self.protocol._receive_frame()
 
         if self._timeout is not None:
@@ -770,7 +899,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         pass
 
 
-def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
+def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
+                      compression_options=None):
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a
@@ -791,7 +921,7 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
         request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
     request = httpclient._RequestProxy(
         request, httpclient.HTTPRequest._DEFAULTS)
-    conn = WebSocketClientConnection(io_loop, request)
+    conn = WebSocketClientConnection(io_loop, request, compression_options)
     if callback is not None:
         io_loop.add_future(conn.connect_future, callback)
     return conn.connect_future