From: Ben Darnell Date: Sun, 27 Jul 2014 22:29:06 +0000 (-0400) Subject: Support the max_wbits websocket deflate parameters. X-Git-Tag: v4.1.0b1~119 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fec81f65503a48edf92dc6e2740f77b4082b9033;p=thirdparty%2Ftornado.git Support the max_wbits websocket deflate parameters. Slightly improve error handling in the websocket handshake. --- diff --git a/tornado/httputil.py b/tornado/httputil.py index efe9f653f..c472998fa 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -843,6 +843,26 @@ def _parse_header(line): return key, pdict +def _encode_header(key, pdict): + """Inverse of _parse_header. + + >>> _encode_header('permessage-deflate', + ... {'client_max_window_bits': 15, 'client_no_context_takeover': None}) + 'permessage-deflate; client_max_window_bits=15; client_no_context_takeover' + """ + if not pdict: + return key + out = [key] + # Sort the parameters just to make it easy to test. + for k, v in sorted(pdict.items()): + if v is None: + out.append(k) + else: + # TODO: quote if necessary. + out.append('%s=%s' % (k, v)) + return '; '.join(out) + + def doctests(): import doctest return doctest.DocTestSuite() diff --git a/tornado/websocket.py b/tornado/websocket.py index 014eafe28..601443222 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -341,6 +341,15 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection = None self.on_close() + def send_error(self, *args, **kwargs): + if self.stream is None: + super(WebSocketHandler, self).send_error(*args, **kwargs) + else: + # If we get an uncaught exception during the handshake, + # we have no choice but to abruptly close the connection. + # TODO: for uncaught exceptions after the handshake, + # we can close the connection more gracefully. + self.stream.close() def _wrap_method(method): def _disallow_for_websocket(self, *args, **kwargs): @@ -349,7 +358,7 @@ def _wrap_method(method): else: raise RuntimeError("Method not supported for Web Sockets") return _disallow_for_websocket -for method in ["write", "redirect", "set_header", "send_error", "set_cookie", +for method in ["write", "redirect", "set_header", "set_cookie", "set_status", "flush", "finish"]: setattr(WebSocketHandler, method, _wrap_method(getattr(WebSocketHandler, method))) @@ -389,14 +398,21 @@ class WebSocketProtocol(object): class _PerMessageDeflateCompressor(object): - def __init__(self, persistent): + def __init__(self, persistent, max_wbits): + if max_wbits is None: + max_wbits = zlib.MAX_WBITS + # There is no symbolic constant for the minimum wbits value. + if not (8 <= max_wbits <= zlib.MAX_WBITS): + raise ValueError("Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, zlib.MAX_WBITS) + self._max_wbits = max_wbits if persistent: self._compressor = self._create_compressor() else: self._compressor = None def _create_compressor(self): - return zlib.compressobj(-1, zlib.DEFLATED, -zlib.MAX_WBITS) + return zlib.compressobj(-1, zlib.DEFLATED, -self._max_wbits) def compress(self, data): compressor = self._compressor or self._create_compressor() @@ -407,14 +423,20 @@ class _PerMessageDeflateCompressor(object): class _PerMessageDeflateDecompressor(object): - def __init__(self, persistent): + def __init__(self, persistent, max_wbits): + if max_wbits is None: + max_wbits = zlib.MAX_WBITS + if not (8 <= max_wbits <= zlib.MAX_WBITS): + raise ValueError("Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, zlib.MAX_WBITS) + self._max_wbits = max_wbits if persistent: self._decompressor = self._create_decompressor() else: self._decompressor = None def _create_decompressor(self): - return zlib.decompressobj(-zlib.MAX_WBITS) + return zlib.decompressobj(-self._max_wbits) def decompress(self, data): decompressor = self._decompressor or self._create_decompressor() @@ -509,11 +531,17 @@ class WebSocketProtocol13(WebSocketProtocol): for ext in extensions: if (ext[0] == 'permessage-deflate' and self._compression_options is not None): - # TODO: negotiate parameters. For now, only - # allow the base extension. - extension_header = ( - 'Sec-WebSocket-Extensions: permessage-deflate\r\n') - self._create_compressors('server', {}) + # TODO: negotiate parameters if compression_options + # specifies limits. + self._create_compressors('server', ext[1]) + if ('client_max_window_bits' in ext[1] and + ext[1]['client_max_window_bits'] is None): + # Don't echo an offered client_max_window_bits + # parameter with no value. + del ext[1]['client_max_window_bits'] + extension_header = ('Sec-WebSocket-Extensions: %s\r\n' % + httputil._encode_header( + 'permessage-deflate', ext[1])) break self.stream.write(tornado.escape.utf8( @@ -554,14 +582,33 @@ class WebSocketProtocol13(WebSocketProtocol): else: raise ValueError("unsupported extension %r", ext) + def _get_compressor_options(self, side, agreed_parameters): + """Converts a websocket agreed_parameters set to keyword arguments + for our compressor objects. + """ + options = dict( + persistent=(side + '_no_context_takeover') not in agreed_parameters) + wbits_header = agreed_parameters.get(side + '_max_window_bits', None) + if wbits_header is None: + options['max_wbits'] = zlib.MAX_WBITS + else: + options['max_wbits'] = int(wbits_header) + return options + def _create_compressors(self, side, agreed_parameters): - # TODO: support the max_wbits parameters. + # TODO: handle invalid parameters gracefully + allowed_keys = set(['server_no_context_takeover', + 'client_no_context_takeover', + 'server_max_window_bits', + 'client_max_window_bits']) + for key in agreed_parameters: + if key not in allowed_keys: + raise ValueError("unsupported compression parameter %r" % key) other_side = 'client' if (side == 'server') else 'server' self._compressor = _PerMessageDeflateCompressor( - persistent=(side + '_no_context_takeover') not in agreed_parameters) + **self._get_compressor_options(side, agreed_parameters)) self._decompressor = _PerMessageDeflateDecompressor( - persistent=((other_side + '_no_context_takeover') - not in agreed_parameters)) + **self._get_compressor_options(other_side, agreed_parameters)) def _write_frame(self, fin, opcode, data, flags=0): if fin: @@ -808,8 +855,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): 'Sec-WebSocket-Version': '13', }) if self.compression_options is not None: - # TODO: offer parameters for the deflate extension. - request.headers['Sec-WebSocket-Extensions'] = 'permessage-deflate' + # Always offer to let the server set our max_wbits (and even though + # we don't offer it, we will accept a client_no_context_takeover + # from the server). + # TODO: set server parameters for deflate extension + # if requested in self.compression_options. + request.headers['Sec-WebSocket-Extensions'] = ( + 'permessage-deflate; client_max_window_bits') self.tcp_client = TCPClient(io_loop=io_loop) super(WebSocketClientConnection, self).__init__( @@ -833,6 +885,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.protocol = None def _on_close(self): + if not self.connect_future.done(): + self.connect_future.set_exception(StreamClosedError()) self.on_message(None) self.resolver.close() super(WebSocketClientConnection, self)._on_close()