self.ws_connection = None
self.on_close()
+ def send_error(self, *args, **kwargs):
+ if self.stream is None:
+ super(WebSocketHandler, self).send_error(*args, **kwargs)
+ else:
+ # If we get an uncaught exception during the handshake,
+ # we have no choice but to abruptly close the connection.
+ # TODO: for uncaught exceptions after the handshake,
+ # we can close the connection more gracefully.
+ self.stream.close()
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
else:
raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket
-for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
+for method in ["write", "redirect", "set_header", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method)))
class _PerMessageDeflateCompressor(object):
- def __init__(self, persistent):
+ def __init__(self, persistent, max_wbits):
+ if max_wbits is None:
+ max_wbits = zlib.MAX_WBITS
+ # There is no symbolic constant for the minimum wbits value.
+ if not (8 <= max_wbits <= zlib.MAX_WBITS):
+ raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
+ max_wbits, zlib.MAX_WBITS)
+ self._max_wbits = max_wbits
if persistent:
self._compressor = self._create_compressor()
else:
self._compressor = None
def _create_compressor(self):
- return zlib.compressobj(-1, zlib.DEFLATED, -zlib.MAX_WBITS)
+ return zlib.compressobj(-1, zlib.DEFLATED, -self._max_wbits)
def compress(self, data):
compressor = self._compressor or self._create_compressor()
class _PerMessageDeflateDecompressor(object):
- def __init__(self, persistent):
+ def __init__(self, persistent, max_wbits):
+ if max_wbits is None:
+ max_wbits = zlib.MAX_WBITS
+ if not (8 <= max_wbits <= zlib.MAX_WBITS):
+ raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
+ max_wbits, zlib.MAX_WBITS)
+ self._max_wbits = max_wbits
if persistent:
self._decompressor = self._create_decompressor()
else:
self._decompressor = None
def _create_decompressor(self):
- return zlib.decompressobj(-zlib.MAX_WBITS)
+ return zlib.decompressobj(-self._max_wbits)
def decompress(self, data):
decompressor = self._decompressor or self._create_decompressor()
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
- # TODO: negotiate parameters. For now, only
- # allow the base extension.
- extension_header = (
- 'Sec-WebSocket-Extensions: permessage-deflate\r\n')
- self._create_compressors('server', {})
+ # TODO: negotiate parameters if compression_options
+ # specifies limits.
+ self._create_compressors('server', ext[1])
+ if ('client_max_window_bits' in ext[1] and
+ ext[1]['client_max_window_bits'] is None):
+ # Don't echo an offered client_max_window_bits
+ # parameter with no value.
+ del ext[1]['client_max_window_bits']
+ extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
+ httputil._encode_header(
+ 'permessage-deflate', ext[1]))
break
self.stream.write(tornado.escape.utf8(
else:
raise ValueError("unsupported extension %r", ext)
+ def _get_compressor_options(self, side, agreed_parameters):
+ """Converts a websocket agreed_parameters set to keyword arguments
+ for our compressor objects.
+ """
+ options = dict(
+ persistent=(side + '_no_context_takeover') not in agreed_parameters)
+ wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
+ if wbits_header is None:
+ options['max_wbits'] = zlib.MAX_WBITS
+ else:
+ options['max_wbits'] = int(wbits_header)
+ return options
+
def _create_compressors(self, side, agreed_parameters):
- # TODO: support the max_wbits parameters.
+ # TODO: handle invalid parameters gracefully
+ allowed_keys = set(['server_no_context_takeover',
+ 'client_no_context_takeover',
+ 'server_max_window_bits',
+ 'client_max_window_bits'])
+ for key in agreed_parameters:
+ if key not in allowed_keys:
+ raise ValueError("unsupported compression parameter %r" % key)
other_side = 'client' if (side == 'server') else 'server'
self._compressor = _PerMessageDeflateCompressor(
- persistent=(side + '_no_context_takeover') not in agreed_parameters)
+ **self._get_compressor_options(side, agreed_parameters))
self._decompressor = _PerMessageDeflateDecompressor(
- persistent=((other_side + '_no_context_takeover')
- not in agreed_parameters))
+ **self._get_compressor_options(other_side, agreed_parameters))
def _write_frame(self, fin, opcode, data, flags=0):
if fin:
'Sec-WebSocket-Version': '13',
})
if self.compression_options is not None:
- # TODO: offer parameters for the deflate extension.
- request.headers['Sec-WebSocket-Extensions'] = 'permessage-deflate'
+ # Always offer to let the server set our max_wbits (and even though
+ # we don't offer it, we will accept a client_no_context_takeover
+ # from the server).
+ # TODO: set server parameters for deflate extension
+ # if requested in self.compression_options.
+ request.headers['Sec-WebSocket-Extensions'] = (
+ 'permessage-deflate; client_max_window_bits')
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
self.protocol = None
def _on_close(self):
+ if not self.connect_future.done():
+ self.connect_future.set_exception(StreamClosedError())
self.on_message(None)
self.resolver.close()
super(WebSocketClientConnection, self)._on_close()