from __future__ import absolute_import, division, print_function
import base64
-import contextlib
import hashlib
import os
import struct
import tornado.escape
import tornado.web
-import warnings
import zlib
from tornado.concurrent import Future, future_set_result_unless_cancelled
from urlparse import urlparse # py3
-@contextlib.contextmanager
-def ignore_deprecation():
- """Context manager to ignore deprecation warnings."""
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
- yield
-
-
class WebSocketError(Exception):
pass
self.start_pinging()
self._run_callback(self.handler.open, *self.handler.open_args,
**self.handler.open_kwargs)
- self._receive_frame()
+ IOLoop.current().add_callback(self._receive_frame_loop)
def _parse_extensions_header(self, headers):
extensions = headers.get("Sec-WebSocket-Extensions", '')
assert isinstance(data, bytes)
self._write_frame(True, 0x9, data)
- def _receive_frame(self):
+ @gen.coroutine
+ def _receive_frame_loop(self):
try:
- with ignore_deprecation():
- self.stream.read_bytes(2, self._on_frame_start)
+ while not self.client_terminated:
+ yield self._receive_frame()
except StreamClosedError:
self._abort()
- def _on_frame_start(self, data):
- self._wire_bytes_in += len(data)
- header, payloadlen = struct.unpack("BB", data)
- self._final_frame = header & self.FIN
+ def _read_bytes(self, n):
+ self._wire_bytes_in += n
+ return self.stream.read_bytes(n)
+
+ @gen.coroutine
+ def _receive_frame(self):
+ # Read the frame header.
+ data = yield self._read_bytes(2)
+ header, mask_payloadlen = struct.unpack("BB", data)
+ is_final_frame = header & self.FIN
reserved_bits = header & self.RSV_MASK
- self._frame_opcode = header & self.OPCODE_MASK
- self._frame_opcode_is_control = self._frame_opcode & 0x8
- if self._decompressor is not None and self._frame_opcode != 0:
+ opcode = header & self.OPCODE_MASK
+ opcode_is_control = opcode & 0x8
+ if self._decompressor is not None and opcode != 0:
+ # Compression flag is present in the first frame's header,
+ # but we can't decompress until we have all the frames of
+ # the message.
self._frame_compressed = bool(reserved_bits & self.RSV1)
reserved_bits &= ~self.RSV1
if reserved_bits:
# client is using as-yet-undefined extensions; abort
self._abort()
return
- self._masked_frame = bool(payloadlen & 0x80)
- payloadlen = payloadlen & 0x7f
- if self._frame_opcode_is_control and payloadlen >= 126:
+ is_masked = bool(mask_payloadlen & 0x80)
+ payloadlen = mask_payloadlen & 0x7f
+
+ # Parse and validate the length.
+ if opcode_is_control and payloadlen >= 126:
# control frames must have payload < 126
self._abort()
return
- try:
- with ignore_deprecation():
- if payloadlen < 126:
- self._frame_length = payloadlen
- if self._masked_frame:
- self.stream.read_bytes(4, self._on_masking_key)
- else:
- self._read_frame_data(False)
- elif payloadlen == 126:
- self.stream.read_bytes(2, self._on_frame_length_16)
- elif payloadlen == 127:
- self.stream.read_bytes(8, self._on_frame_length_64)
- except StreamClosedError:
- self._abort()
-
- def _read_frame_data(self, masked):
- new_len = self._frame_length
+ if payloadlen < 126:
+ self._frame_length = payloadlen
+ elif payloadlen == 126:
+ data = yield self._read_bytes(2)
+ payloadlen = struct.unpack("!H", data)[0]
+ elif payloadlen == 127:
+ data = yield self._read_bytes(8)
+ payloadlen = struct.unpack("!Q", data)[0]
+ new_len = payloadlen
if self._fragmented_message_buffer is not None:
new_len += len(self._fragmented_message_buffer)
if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
self.close(1009, "message too big")
- return
- with ignore_deprecation():
- self.stream.read_bytes(
- self._frame_length,
- self._on_masked_frame_data if masked else self._on_frame_data)
-
- def _on_frame_length_16(self, data):
- self._wire_bytes_in += len(data)
- self._frame_length = struct.unpack("!H", data)[0]
- try:
- if self._masked_frame:
- with ignore_deprecation():
- self.stream.read_bytes(4, self._on_masking_key)
- else:
- self._read_frame_data(False)
- except StreamClosedError:
self._abort()
+ return
- def _on_frame_length_64(self, data):
- self._wire_bytes_in += len(data)
- self._frame_length = struct.unpack("!Q", data)[0]
- try:
- if self._masked_frame:
- with ignore_deprecation():
- self.stream.read_bytes(4, self._on_masking_key)
- else:
- self._read_frame_data(False)
- except StreamClosedError:
- self._abort()
-
- def _on_masking_key(self, data):
- self._wire_bytes_in += len(data)
- self._frame_mask = data
- try:
- self._read_frame_data(True)
- except StreamClosedError:
- self._abort()
-
- def _on_masked_frame_data(self, data):
- # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
- self._on_frame_data(_websocket_mask(self._frame_mask, data))
-
- def _on_frame_data(self, data):
- handled_future = None
+ # Read the payload, unmasking if necessary.
+ if is_masked:
+ self._frame_mask = yield self._read_bytes(4)
+ data = yield self._read_bytes(payloadlen)
+ if is_masked:
+ data = _websocket_mask(self._frame_mask, data)
- self._wire_bytes_in += len(data)
- if self._frame_opcode_is_control:
+ # Decide what to do with this frame.
+ if opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
# self._fragmented_*
- if not self._final_frame:
+ if not is_final_frame:
# control frames must not be fragmented
self._abort()
return
- opcode = self._frame_opcode
- elif self._frame_opcode == 0: # continuation frame
+ elif opcode == 0: # continuation frame
if self._fragmented_message_buffer is None:
# nothing to continue
self._abort()
return
self._fragmented_message_buffer += data
- if self._final_frame:
+ if is_final_frame:
opcode = self._fragmented_message_opcode
data = self._fragmented_message_buffer
self._fragmented_message_buffer = None
# can't start new message until the old one is finished
self._abort()
return
- if self._final_frame:
- opcode = self._frame_opcode
- else:
- self._fragmented_message_opcode = self._frame_opcode
+ if not is_final_frame:
+ self._fragmented_message_opcode = opcode
self._fragmented_message_buffer = data
- if self._final_frame:
+ if is_final_frame:
handled_future = self._handle_message(opcode, data)
-
- if not self.client_terminated:
- if handled_future:
- # on_message is a coroutine, process more frames once it's done.
- handled_future.add_done_callback(
- lambda future: self._receive_frame())
- else:
- self._receive_frame()
+ if handled_future is not None:
+ yield handled_future
def _handle_message(self, opcode, data):
"""Execute on_message, returning its Future if it is a coroutine."""
self.protocol = self.get_websocket_protocol()
self.protocol._process_server_headers(self.key, self.headers)
self.protocol.start_pinging()
- self.protocol._receive_frame()
+ IOLoop.current().add_callback(self.protocol._receive_frame_loop)
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)