import traceback
from tornado.concurrent import Future
+from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
This allows for deterministic cleanup of the associated socket.
"""
- def initialize(self, close_future):
+ def initialize(self, close_future, compression_options=None):
self.close_future = close_future
+ self.compression_options = compression_options
+
+ def get_compression_options(self):
+ return self.compression_options
def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason))
self.close(1001, "goodbye")
-class WebSocketTest(AsyncHTTPTestCase):
+class WebSocketBaseTestCase(AsyncHTTPTestCase):
+ @gen.coroutine
+ def ws_connect(self, path, compression_options=None):
+ ws = yield websocket_connect(
+ 'ws://localhost:%d%s' % (self.get_http_port(), path),
+ compression_options=compression_options)
+ raise gen.Return(ws)
+
+ @gen.coroutine
+ def close(self, ws):
+ """Close a websocket connection and wait for the server side.
+
+ If we don't wait here, there are sometimes leak warnings in the
+ tests.
+ """
+ ws.close()
+ yield self.close_future
+
+class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
return Application([
@gen_test
def test_websocket_gen(self):
- ws = yield websocket_connect(
- 'ws://localhost:%d/echo' % self.get_http_port(),
- io_loop=self.io_loop)
+ ws = yield self.ws_connect('/echo')
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
- ws.close()
- yield self.close_future
+ yield self.close(ws)
def test_websocket_callbacks(self):
websocket_connect(
@gen_test
def test_binary_message(self):
- ws = yield websocket_connect(
- 'ws://localhost:%d/echo' % self.get_http_port())
+ ws = yield self.ws_connect('/echo')
ws.write_message(b'hello \xe9', binary=True)
response = yield ws.read_message()
self.assertEqual(response, b'hello \xe9')
- ws.close()
- yield self.close_future
+ yield self.close(ws)
@gen_test
def test_unicode_message(self):
- ws = yield websocket_connect(
- 'ws://localhost:%d/echo' % self.get_http_port())
+ ws = yield self.ws_connect('/echo')
ws.write_message(u('hello \u00e9'))
response = yield ws.read_message()
self.assertEqual(response, u('hello \u00e9'))
- ws.close()
- yield self.close_future
+ yield self.close(ws)
@gen_test
def test_error_in_on_message(self):
- ws = yield websocket_connect(
- 'ws://localhost:%d/error_in_on_message' % self.get_http_port())
+ ws = yield self.ws_connect('/error_in_on_message')
ws.write_message('hello')
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
- ws.close()
- yield self.close_future
+ yield self.close(ws)
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
- yield websocket_connect(
- 'ws://localhost:%d/notfound' % self.get_http_port(),
- io_loop=self.io_loop)
+ yield self.ws_connect('/notfound')
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
- yield websocket_connect(
- 'ws://localhost:%d/non_ws' % self.get_http_port(),
- io_loop=self.io_loop)
+ yield self.ws_connect('/non_ws')
@gen_test
def test_websocket_network_fail(self):
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
+ # Close the underlying stream.
ws.stream.close()
yield self.close_future
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
- ws.close()
- yield self.close_future
+ yield self.close(ws)
@gen_test
def test_server_close_reason(self):
- ws = yield websocket_connect(
- 'ws://localhost:%d/close_reason' % self.get_http_port())
+ ws = yield self.ws_connect('/close_reason')
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
@gen_test
def test_client_close_reason(self):
- ws = yield websocket_connect(
- 'ws://localhost:%d/echo' % self.get_http_port())
+ ws = yield self.ws_connect('/echo')
ws.close(1001, 'goodbye')
code, reason = yield self.close_future
self.assertEqual(code, 1001)
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
- ws.close()
- yield self.close_future
+ yield self.close(ws)
@gen_test
def test_check_origin_valid_with_path(self):
ws.write_message('hello')
response = yield ws.read_message()
self.assertEqual(response, 'hello')
- ws.close()
- yield self.close_future
+ yield self.close(ws)
@gen_test
def test_check_origin_invalid_partial_url(self):
self.assertEqual(cm.exception.code, 403)
+class CompressionTestMixin(object):
+ MESSAGE = 'Hello world. Testing 123 123'
+
+ def get_app(self):
+ self.close_future = Future()
+ return Application([
+ ('/echo', EchoHandler, dict(
+ close_future=self.close_future,
+ compression_options=self.get_server_compression_options())),
+ ])
+
+ def get_server_compression_options(self):
+ return None
+
+ def get_client_compression_options(self):
+ return None
+
+ @gen_test
+ def test_message_sizes(self):
+ ws = yield self.ws_connect(
+ '/echo',
+ compression_options=self.get_client_compression_options())
+ # Send the same message three times so we can measure the
+ # effect of the context_takeover options.
+ for i in range(3):
+ ws.write_message(self.MESSAGE)
+ response = yield ws.read_message()
+ self.assertEqual(response, self.MESSAGE)
+ self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
+ self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
+ self.verify_wire_bytes(ws.protocol._wire_bytes_in,
+ ws.protocol._wire_bytes_out)
+ yield self.close(ws)
+
+
+class UncompressedTestMixin(CompressionTestMixin):
+ """Specialization of CompressionTestMixin when we expect no compression."""
+ def verify_wire_bytes(self, bytes_in, bytes_out):
+ # Bytes out includes the 4-byte mask key per message.
+ self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
+ self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
+
+
+class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
+ pass
+
+
+# If only one side tries to compress, the extension is not negotiated.
+class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
+ def get_server_compression_options(self):
+ return {}
+
+
+class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
+ def get_client_compression_options(self):
+ return {}
+
+
+class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
+ def get_server_compression_options(self):
+ return {}
+
+ def get_client_compression_options(self):
+ return {}
+
+ def verify_wire_bytes(self, bytes_in, bytes_out):
+ self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
+ self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
+ # Bytes out includes the 4 bytes mask key per message.
+ self.assertEqual(bytes_out, bytes_in + 12)
+
+
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):
import struct
import tornado.escape
import tornado.web
+import zlib
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
- self.ws_connection = WebSocketProtocol13(self)
+ self.ws_connection = WebSocketProtocol13(
+ self, compression_options=self.get_compression_options())
self.ws_connection.accept_connection()
else:
self.stream.write(tornado.escape.utf8(
"""
return None
+ def get_compression_options(self):
+ return None
+
def open(self):
"""Invoked when a new WebSocket is opened.
self.close() # let the subclass cleanup
+class _PerMessageDeflateCompressor(object):
+ def __init__(self, persistent):
+ if persistent:
+ self._compressor = self._create_compressor()
+ else:
+ self._compressor = None
+
+ def _create_compressor(self):
+ return zlib.compressobj(-1, zlib.DEFLATED, -zlib.MAX_WBITS)
+
+ def compress(self, data):
+ compressor = self._compressor or self._create_compressor()
+ data = (compressor.compress(data) +
+ compressor.flush(zlib.Z_SYNC_FLUSH))
+ assert data.endswith(b'\x00\x00\xff\xff')
+ return data[:-4]
+
+
+class _PerMessageDeflateDecompressor(object):
+ def __init__(self, persistent):
+ if persistent:
+ self._decompressor = self._create_decompressor()
+ else:
+ self._decompressor = None
+
+ def _create_decompressor(self):
+ return zlib.decompressobj(-zlib.MAX_WBITS)
+
+ def decompress(self, data):
+ decompressor = self._decompressor or self._create_decompressor()
+ return decompressor.decompress(data + b'\x00\x00\xff\xff')
+
+
class WebSocketProtocol13(WebSocketProtocol):
"""Implementation of the WebSocket protocol from RFC 6455.
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
- def __init__(self, handler, mask_outgoing=False):
+ # Bit masks for the first byte of a frame.
+ FIN = 0x80
+ RSV1 = 0x40
+ RSV2 = 0x20
+ RSV3 = 0x10
+ RSV_MASK = RSV1 | RSV2 | RSV3
+ OPCODE_MASK = 0x0f
+
+ def __init__(self, handler, mask_outgoing=False,
+ compression_options=None):
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None
+ self._compression_options = compression_options
+ self._decompressor = None
+ self._compressor = None
+ self._frame_compressed = None
+ # The total uncompressed size of all messages received or sent.
+ # Unicode messages are encoded to utf8.
+ # Only for testing; subject to change.
+ self._message_bytes_in = 0
+ self._message_bytes_out = 0
+ # The total size of all packets received or sent. Includes
+ # the effect of compression, frame overhead, and control frames.
+ self._wire_bytes_in = 0
+ self._wire_bytes_out = 0
def accept_connection(self):
try:
assert selected in subprotocols
subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
+ extension_header = ''
+ extensions = self._parse_extensions_header(self.request.headers)
+ for ext in extensions:
+ if (ext[0] == 'permessage-deflate' and
+ self._compression_options is not None):
+ # TODO: negotiate parameters. For now, only
+ # allow the base extension.
+ extension_header = (
+ 'Sec-WebSocket-Extensions: permessage-deflate\r\n')
+ self._create_compressors('server', {})
+ break
+
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Accept: %s\r\n"
- "%s"
- "\r\n" % (self._challenge_response(), subprotocol_header)))
+ "%s%s"
+ "\r\n" % (self._challenge_response(),
+ subprotocol_header, extension_header)))
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
self._receive_frame()
- def _write_frame(self, fin, opcode, data):
+ def _parse_extensions_header(self, headers):
+ extensions = headers.get("Sec-WebSocket-Extensions", '')
+ if extensions:
+ return [httputil._parse_header(e.strip())
+ for e in extensions.split(',')]
+ return []
+
+ def _process_server_headers(self, key, headers):
+ """Process the headers sent by the server to this client connection.
+
+ 'key' is the websocket handshake challenge/response key.
+ """
+ assert headers['Upgrade'].lower() == 'websocket'
+ assert headers['Connection'].lower() == 'upgrade'
+ accept = self.compute_accept_value(key)
+ assert headers['Sec-Websocket-Accept'] == accept
+
+ extensions = self._parse_extensions_header(headers)
+ for ext in extensions:
+ if (ext[0] == 'permessage-deflate' and
+ self._compression_options is not None):
+ self._create_compressors('client', ext[1])
+ else:
+ raise ValueError("unsupported extension %r", ext)
+
+ def _create_compressors(self, side, agreed_parameters):
+ # TODO: support the max_wbits parameters.
+ other_side = 'client' if (side == 'server') else 'server'
+ self._compressor = _PerMessageDeflateCompressor(
+ persistent=(side + '_no_context_takeover') not in agreed_parameters)
+ self._decompressor = _PerMessageDeflateDecompressor(
+ persistent=((other_side + '_no_context_takeover')
+ not in agreed_parameters))
+
+ def _write_frame(self, fin, opcode, data, flags=0):
if fin:
- finbit = 0x80
+ finbit = self.FIN
else:
finbit = 0
- frame = struct.pack("B", finbit | opcode)
+ frame = struct.pack("B", finbit | opcode | flags)
l = len(data)
if self.mask_outgoing:
mask_bit = 0x80
mask = os.urandom(4)
data = mask + _websocket_mask(mask, data)
frame += data
+ self._wire_bytes_out += len(frame)
self.stream.write(frame)
def write_message(self, message, binary=False):
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
+ self._message_bytes_out += len(message)
+ flags = 0
+ if self._compressor:
+ message = self._compressor.compress(message)
+ flags |= self.RSV1
try:
- self._write_frame(True, opcode, message)
+ self._write_frame(True, opcode, message, flags=flags)
except StreamClosedError:
self._abort()
self._abort()
def _on_frame_start(self, data):
+ self._wire_bytes_in += len(data)
header, payloadlen = struct.unpack("BB", data)
- self._final_frame = header & 0x80
- reserved_bits = header & 0x70
- self._frame_opcode = header & 0xf
+ self._final_frame = header & self.FIN
+ reserved_bits = header & self.RSV_MASK
+ self._frame_opcode = header & self.OPCODE_MASK
self._frame_opcode_is_control = self._frame_opcode & 0x8
+ if self._decompressor is not None:
+ self._frame_compressed = bool(reserved_bits & self.RSV1)
+ reserved_bits &= ~self.RSV1
if reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
self._abort()
def _on_frame_length_16(self, data):
+ self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!H", data)[0]
try:
if self._masked_frame:
self._abort()
def _on_frame_length_64(self, data):
+ self._wire_bytes_in += len(data)
self._frame_length = struct.unpack("!Q", data)[0]
try:
if self._masked_frame:
self._abort()
def _on_masking_key(self, data):
+ self._wire_bytes_in += len(data)
self._frame_mask = data
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
self._abort()
def _on_masked_frame_data(self, data):
+ # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
self._on_frame_data(_websocket_mask(self._frame_mask, data))
def _on_frame_data(self, data):
+ self._wire_bytes_in += len(data)
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
if self.client_terminated:
return
+ if self._frame_compressed:
+ data = self._decompressor.decompress(data)
+
if opcode == 0x1:
# UTF-8 data
+ self._message_bytes_in += len(data)
try:
decoded = data.decode("utf-8")
except UnicodeDecodeError:
self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
+ self._message_bytes_in += len(data)
self._run_callback(self.handler.on_message, data)
elif opcode == 0x8:
# Close
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
- def __init__(self, io_loop, request):
+ def __init__(self, io_loop, request, compression_options=None):
+ self.compression_options = compression_options
self.connect_future = TracebackFuture()
self.read_future = None
self.read_queue = collections.deque()
'Sec-WebSocket-Key': self.key,
'Sec-WebSocket-Version': '13',
})
+ if self.compression_options is not None:
+ # TODO: offer parameters for the deflate extension.
+ request.headers['Sec-WebSocket-Extensions'] = 'permessage-deflate'
self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
start_line, headers)
self.headers = headers
- assert self.headers['Upgrade'].lower() == 'websocket'
- assert self.headers['Connection'].lower() == 'upgrade'
- accept = WebSocketProtocol13.compute_accept_value(self.key)
- assert self.headers['Sec-Websocket-Accept'] == accept
-
- self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
+ self.protocol = WebSocketProtocol13(
+ self, mask_outgoing=True,
+ compression_options=self.compression_options)
+ self.protocol._process_server_headers(self.key, self.headers)
self.protocol._receive_frame()
if self._timeout is not None:
pass
-def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
+def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
+ compression_options=None):
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
- conn = WebSocketClientConnection(io_loop, request)
+ conn = WebSocketClientConnection(io_loop, request, compression_options)
if callback is not None:
io_loop.add_future(conn.connect_future, callback)
return conn.connect_future