# Author: Jacob Kristhammar, 2010
import array
+import base64
+import collections
import functools
import hashlib
+import logging
+import os
+import re
+import socket
import struct
import time
-import base64
import tornado.escape
import tornado.web
+from tornado.concurrent import Future, return_future
+from tornado.escape import utf8, to_unicode, native_str
+from tornado.httputil import HTTPHeaders
+from tornado.ioloop import IOLoop
+from tornado.iostream import IOStream, SSLIOStream
from tornado.log import gen_log, app_log
+from tornado import simple_httpclient
from tornado.util import bytes_type
try:
except NameError:
xrange = range # py3
+try:
+ import urlparse # py2
+except ImportError:
+ import urllib.parse as urlparse # py3
+
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
- def __init__(self, handler):
+ def __init__(self, handler, mask_outgoing=False):
WebSocketProtocol.__init__(self, handler)
+ self.mask_outgoing = mask_outgoing
self._final_frame = False
self._frame_opcode = None
+ self._masked_frame = None
self._frame_mask = None
self._frame_length = None
self._fragmented_message_buffer = None
self._handle_websocket_headers()
self._accept_connection()
except ValueError:
- gen_log.debug("Malformed WebSocket request received")
+ gen_log.debug("Malformed WebSocket request received", exc_info=True)
self._abort()
return
if not all(map(lambda f: self.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
- def _challenge_response(self):
+ @staticmethod
+ def compute_accept_value(key):
+ """Computes the value for the Sec-WebSocket-Accept header,
+ given the value for Sec-WebSocket-Key.
+ """
sha1 = hashlib.sha1()
- sha1.update(tornado.escape.utf8(
- self.request.headers.get("Sec-Websocket-Key")))
+ sha1.update(utf8(key))
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
- return tornado.escape.native_str(base64.b64encode(sha1.digest()))
+ return native_str(base64.b64encode(sha1.digest()))
+
+ def _challenge_response(self):
+ return WebSocketProtocol13.compute_accept_value(
+ self.request.headers.get("Sec-Websocket-Key"))
+
def _accept_connection(self):
subprotocol_header = ''
finbit = 0
frame = struct.pack("B", finbit | opcode)
l = len(data)
+ if self.mask_outgoing:
+ mask_bit = 0x80
+ else:
+ mask_bit = 0
if l < 126:
- frame += struct.pack("B", l)
+ frame += struct.pack("B", l | mask_bit)
elif l <= 0xFFFF:
- frame += struct.pack("!BH", 126, l)
+ frame += struct.pack("!BH", 126 | mask_bit, l)
else:
- frame += struct.pack("!BQ", 127, l)
+ frame += struct.pack("!BQ", 127 | mask_bit, l)
+ if self.mask_outgoing:
+ mask = os.urandom(4)
+ data = mask + self._apply_mask(mask, data)
frame += data
self.stream.write(frame)
# client is using as-yet-undefined extensions; abort
self._abort()
return
- if not (payloadlen & 0x80):
- # Unmasked frame -> abort connection
- self._abort()
- return
+ self._masked_frame = bool(payloadlen & 0x80)
payloadlen = payloadlen & 0x7f
if self._frame_opcode_is_control and payloadlen >= 126:
# control frames must have payload < 126
return
if payloadlen < 126:
self._frame_length = payloadlen
- self.stream.read_bytes(4, self._on_masking_key)
+ if self._masked_frame:
+ self.stream.read_bytes(4, self._on_masking_key)
+ else:
+ self.stream.read_bytes(self._frame_length, self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
def _on_frame_length_16(self, data):
self._frame_length = struct.unpack("!H", data)[0]
- self.stream.read_bytes(4, self._on_masking_key)
+ if self._masked_frame:
+ self.stream.read_bytes(4, self._on_masking_key)
+ else:
+ self.stream.read_bytes(self._frame_length, self._on_frame_data)
def _on_frame_length_64(self, data):
self._frame_length = struct.unpack("!Q", data)[0]
- self.stream.read_bytes(4, self._on_masking_key)
+ if self._masked_frame:
+ self.stream.read_bytes(4, self._on_masking_key)
+ else:
+ self.stream.read_bytes(self._frame_length, self._on_frame_data)
def _on_masking_key(self, data):
- self._frame_mask = array.array("B", data)
- self.stream.read_bytes(self._frame_length, self._on_frame_data)
+ self._frame_mask = data
+ self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
- def _on_frame_data(self, data):
+ def _apply_mask(self, mask, data):
+ mask = array.array("B", mask)
unmasked = array.array("B", data)
for i in xrange(len(data)):
- unmasked[i] = unmasked[i] ^ self._frame_mask[i % 4]
+ unmasked[i] = unmasked[i] ^ mask[i % 4]
+ if hasattr(unmasked, 'tobytes'):
+ # tostring was deprecated in py32. It hasn't been removed,
+ # but since we turn on deprecation warnings in our tests
+ # we need to use the right one.
+ return unmasked.tobytes()
+ else:
+ return unmasked.tostring()
+
+ def _on_masked_frame_data(self, data):
+ self._on_frame_data(self._apply_mask(self._frame_mask, data))
+ def _on_frame_data(self, data):
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
# nothing to continue
self._abort()
return
- self._fragmented_message_buffer += unmasked
+ self._fragmented_message_buffer += data
if self._final_frame:
opcode = self._fragmented_message_opcode
- unmasked = self._fragmented_message_buffer
+ data = self._fragmented_message_buffer
self._fragmented_message_buffer = None
else: # start of new data message
if self._fragmented_message_buffer is not None:
opcode = self._frame_opcode
else:
self._fragmented_message_opcode = self._frame_opcode
- self._fragmented_message_buffer = unmasked
+ self._fragmented_message_buffer = data
if self._final_frame:
- self._handle_message(opcode, unmasked.tostring())
+ self._handle_message(opcode, data)
if not self.client_terminated:
self._receive_frame()
# otherwise just close the connection.
self._waiting = self.stream.io_loop.add_timeout(
self.stream.io_loop.time() + 5, self._abort)
+
+
+class _WebSocketClientConnection(simple_httpclient._HTTPConnection):
+ def __init__(self, io_loop, client, request):
+ self.connect_future = Future()
+ self.read_future = None
+ self.read_queue = collections.deque()
+ self.key = base64.b64encode(os.urandom(16))
+
+ scheme, sep, rest = request.url.partition(':')
+ scheme = {'ws': 'http', 'wss': 'https'}[scheme]
+ request.url = scheme + sep + rest
+ request.headers.update({
+ 'Upgrade': 'websocket',
+ 'Connection': 'Upgrade',
+ 'Sec-WebSocket-Key': self.key,
+ 'Sec-WebSocket-Version': '13',
+ })
+
+ super(_WebSocketClientConnection, self).__init__(
+ io_loop, client, request, lambda: None, lambda response: None,
+ 104857600)
+
+ def _on_close(self):
+ self.on_message(None)
+
+
+ def _handle_1xx(self, code):
+ assert code == 101
+ assert self.headers['Upgrade'].lower() == 'websocket'
+ assert self.headers['Connection'].lower() == 'upgrade'
+ accept = WebSocketProtocol13.compute_accept_value(self.key)
+ assert self.headers['Sec-Websocket-Accept'] == accept
+
+ self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
+ self.protocol._receive_frame()
+
+ if self._timeout is not None:
+ self.io_loop.remove_timeout(self._timeout)
+ self._timeout = None
+
+ self.connect_future.set_result(self)
+
+ def write_message(self, message, binary=False):
+ self.protocol.write_message(message, binary)
+
+ def read_message(self, callback=None):
+ assert self.read_future is None
+ future = Future()
+ if self.read_queue:
+ future.set_result(self.read_queue.popleft())
+ else:
+ self.read_future = future
+ if callback is not None:
+ self.io_loop.add_future(future, callback)
+ return future
+
+ def on_message(self, message):
+ if self.read_future is not None:
+ self.read_future.set_result(message)
+ self.read_future = None
+ else:
+ self.read_queue.append(message)
+
+ def on_pong(self, data):
+ pass
+
+
+def WebSocketConnect(url, io_loop=None, callback=None):
+ if io_loop is None:
+ io_loop = IOLoop.instance()
+ request = simple_httpclient.HTTPRequest(url)
+ request = simple_httpclient._RequestProxy(
+ request, simple_httpclient.HTTPRequest._DEFAULTS)
+ from tornado.util import ObjectDict
+ from tornado.netutil import Resolver
+ # TODO: refactor _HTTPConnection's client parameter
+ client = ObjectDict(resolver=Resolver(io_loop), hostname_mapping=None)
+ conn = _WebSocketClientConnection(io_loop, client, request)
+ if callback is not None:
+ io_loop.add_future(conn.connect_future, callback)
+ return conn.connect_future