]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Support the max_wbits websocket deflate parameters.
authorBen Darnell <ben@bendarnell.com>
Sun, 27 Jul 2014 22:29:06 +0000 (18:29 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 27 Jul 2014 22:29:06 +0000 (18:29 -0400)
Slightly improve error handling in the websocket handshake.

tornado/httputil.py
tornado/websocket.py

index efe9f653f794f2936d2eba9a05f408f7b1760a44..c472998faac056ad82458ba9aed2112ce69c7901 100644 (file)
@@ -843,6 +843,26 @@ def _parse_header(line):
     return key, pdict
 
 
+def _encode_header(key, pdict):
+    """Inverse of _parse_header.
+
+    >>> _encode_header('permessage-deflate',
+    ...     {'client_max_window_bits': 15, 'client_no_context_takeover': None})
+    'permessage-deflate; client_max_window_bits=15; client_no_context_takeover'
+    """
+    if not pdict:
+        return key
+    out = [key]
+    # Sort the parameters just to make it easy to test.
+    for k, v in sorted(pdict.items()):
+        if v is None:
+            out.append(k)
+        else:
+            # TODO: quote if necessary.
+            out.append('%s=%s' % (k, v))
+    return '; '.join(out)
+
+
 def doctests():
     import doctest
     return doctest.DocTestSuite()
index 014eafe286f1fe0d8b285e17ba7127f89b0ea549..60144322243ff8a8ec2cf749ee944f8803459f59 100644 (file)
@@ -341,6 +341,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.ws_connection = None
             self.on_close()
 
+    def send_error(self, *args, **kwargs):
+        if self.stream is None:
+            super(WebSocketHandler, self).send_error(*args, **kwargs)
+        else:
+            # If we get an uncaught exception during the handshake,
+            # we have no choice but to abruptly close the connection.
+            # TODO: for uncaught exceptions after the handshake,
+            # we can close the connection more gracefully.
+            self.stream.close()
 
 def _wrap_method(method):
     def _disallow_for_websocket(self, *args, **kwargs):
@@ -349,7 +358,7 @@ def _wrap_method(method):
         else:
             raise RuntimeError("Method not supported for Web Sockets")
     return _disallow_for_websocket
-for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
+for method in ["write", "redirect", "set_header", "set_cookie",
                "set_status", "flush", "finish"]:
     setattr(WebSocketHandler, method,
             _wrap_method(getattr(WebSocketHandler, method)))
@@ -389,14 +398,21 @@ class WebSocketProtocol(object):
 
 
 class _PerMessageDeflateCompressor(object):
-    def __init__(self, persistent):
+    def __init__(self, persistent, max_wbits):
+        if max_wbits is None:
+            max_wbits = zlib.MAX_WBITS
+        # There is no symbolic constant for the minimum wbits value.
+        if not (8 <= max_wbits <= zlib.MAX_WBITS):
+            raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
+                             max_wbits, zlib.MAX_WBITS)
+        self._max_wbits = max_wbits
         if persistent:
             self._compressor = self._create_compressor()
         else:
             self._compressor = None
 
     def _create_compressor(self):
-        return zlib.compressobj(-1, zlib.DEFLATED, -zlib.MAX_WBITS)
+        return zlib.compressobj(-1, zlib.DEFLATED, -self._max_wbits)
 
     def compress(self, data):
         compressor = self._compressor or self._create_compressor()
@@ -407,14 +423,20 @@ class _PerMessageDeflateCompressor(object):
 
 
 class _PerMessageDeflateDecompressor(object):
-    def __init__(self, persistent):
+    def __init__(self, persistent, max_wbits):
+        if max_wbits is None:
+            max_wbits = zlib.MAX_WBITS
+        if not (8 <= max_wbits <= zlib.MAX_WBITS):
+            raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
+                             max_wbits, zlib.MAX_WBITS)
+        self._max_wbits = max_wbits
         if persistent:
             self._decompressor = self._create_decompressor()
         else:
             self._decompressor = None
 
     def _create_decompressor(self):
-        return zlib.decompressobj(-zlib.MAX_WBITS)
+        return zlib.decompressobj(-self._max_wbits)
 
     def decompress(self, data):
         decompressor = self._decompressor or self._create_decompressor()
@@ -509,11 +531,17 @@ class WebSocketProtocol13(WebSocketProtocol):
         for ext in extensions:
             if (ext[0] == 'permessage-deflate' and
                 self._compression_options is not None):
-                # TODO: negotiate parameters.  For now, only
-                # allow the base extension.
-                extension_header = (
-                    'Sec-WebSocket-Extensions: permessage-deflate\r\n')
-                self._create_compressors('server', {})
+                # TODO: negotiate parameters if compression_options
+                # specifies limits.
+                self._create_compressors('server', ext[1])
+                if ('client_max_window_bits' in ext[1] and
+                    ext[1]['client_max_window_bits'] is None):
+                    # Don't echo an offered client_max_window_bits
+                    # parameter with no value.
+                    del ext[1]['client_max_window_bits']
+                extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
+                                    httputil._encode_header(
+                                        'permessage-deflate', ext[1]))
                 break
 
         self.stream.write(tornado.escape.utf8(
@@ -554,14 +582,33 @@ class WebSocketProtocol13(WebSocketProtocol):
             else:
                 raise ValueError("unsupported extension %r", ext)
 
+    def _get_compressor_options(self, side, agreed_parameters):
+        """Converts a websocket agreed_parameters set to keyword arguments
+        for our compressor objects.
+        """
+        options = dict(
+            persistent=(side + '_no_context_takeover') not in agreed_parameters)
+        wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
+        if wbits_header is None:
+            options['max_wbits'] = zlib.MAX_WBITS
+        else:
+            options['max_wbits'] = int(wbits_header)
+        return options
+
     def _create_compressors(self, side, agreed_parameters):
-        # TODO: support the max_wbits parameters.
+        # TODO: handle invalid parameters gracefully
+        allowed_keys = set(['server_no_context_takeover',
+                            'client_no_context_takeover',
+                            'server_max_window_bits',
+                            'client_max_window_bits'])
+        for key in agreed_parameters:
+            if key not in allowed_keys:
+                raise ValueError("unsupported compression parameter %r" % key)
         other_side = 'client' if (side == 'server') else 'server'
         self._compressor = _PerMessageDeflateCompressor(
-            persistent=(side + '_no_context_takeover') not in agreed_parameters)
+            **self._get_compressor_options(side, agreed_parameters))
         self._decompressor = _PerMessageDeflateDecompressor(
-            persistent=((other_side + '_no_context_takeover')
-                        not in agreed_parameters))
+            **self._get_compressor_options(other_side, agreed_parameters))
 
     def _write_frame(self, fin, opcode, data, flags=0):
         if fin:
@@ -808,8 +855,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             'Sec-WebSocket-Version': '13',
         })
         if self.compression_options is not None:
-            # TODO: offer parameters for the deflate extension.
-            request.headers['Sec-WebSocket-Extensions'] = 'permessage-deflate'
+            # Always offer to let the server set our max_wbits (and even though
+            # we don't offer it, we will accept a client_no_context_takeover
+            # from the server).
+            # TODO: set server parameters for deflate extension
+            # if requested in self.compression_options.
+            request.headers['Sec-WebSocket-Extensions'] = (
+                'permessage-deflate; client_max_window_bits')
 
         self.tcp_client = TCPClient(io_loop=io_loop)
         super(WebSocketClientConnection, self).__init__(
@@ -833,6 +885,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             self.protocol = None
 
     def _on_close(self):
+        if not self.connect_future.done():
+            self.connect_future.set_exception(StreamClosedError())
         self.on_message(None)
         self.resolver.close()
         super(WebSocketClientConnection, self)._on_close()