From: Ben Darnell Date: Sun, 22 Apr 2018 03:31:49 +0000 (-0400) Subject: websocket: Refactor implementation to use coroutines X-Git-Tag: v5.1.0b1~23^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4635a34049f25ef3c2dd71d50164e67d0b68c7d8;p=thirdparty%2Ftornado.git websocket: Refactor implementation to use coroutines This avoids the deprecated IOStream interfaces and simplifies things a bit. --- diff --git a/tornado/websocket.py b/tornado/websocket.py index ced6d657f..0507a92c6 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -19,13 +19,11 @@ the protocol (known as "draft 76") and are not compatible with this module. from __future__ import absolute_import, division, print_function import base64 -import contextlib import hashlib import os import struct import tornado.escape import tornado.web -import warnings import zlib from tornado.concurrent import Future, future_set_result_unless_cancelled @@ -46,14 +44,6 @@ else: from urlparse import urlparse # py3 -@contextlib.contextmanager -def ignore_deprecation(): - """Context manager to ignore deprecation warnings.""" - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - yield - - class WebSocketError(Exception): pass @@ -723,7 +713,7 @@ class WebSocketProtocol13(WebSocketProtocol): self.start_pinging() self._run_callback(self.handler.open, *self.handler.open_args, **self.handler.open_kwargs) - self._receive_frame() + IOLoop.current().add_callback(self._receive_frame_loop) def _parse_extensions_header(self, headers): extensions = headers.get("Sec-WebSocket-Extensions", '') @@ -846,116 +836,84 @@ class WebSocketProtocol13(WebSocketProtocol): assert isinstance(data, bytes) self._write_frame(True, 0x9, data) - def _receive_frame(self): + @gen.coroutine + def _receive_frame_loop(self): try: - with ignore_deprecation(): - self.stream.read_bytes(2, self._on_frame_start) + while not self.client_terminated: + yield self._receive_frame() except StreamClosedError: self._abort() - def _on_frame_start(self, data): - self._wire_bytes_in += len(data) - header, payloadlen = struct.unpack("BB", data) - self._final_frame = header & self.FIN + def _read_bytes(self, n): + self._wire_bytes_in += n + return self.stream.read_bytes(n) + + @gen.coroutine + def _receive_frame(self): + # Read the frame header. + data = yield self._read_bytes(2) + header, mask_payloadlen = struct.unpack("BB", data) + is_final_frame = header & self.FIN reserved_bits = header & self.RSV_MASK - self._frame_opcode = header & self.OPCODE_MASK - self._frame_opcode_is_control = self._frame_opcode & 0x8 - if self._decompressor is not None and self._frame_opcode != 0: + opcode = header & self.OPCODE_MASK + opcode_is_control = opcode & 0x8 + if self._decompressor is not None and opcode != 0: + # Compression flag is present in the first frame's header, + # but we can't decompress until we have all the frames of + # the message. self._frame_compressed = bool(reserved_bits & self.RSV1) reserved_bits &= ~self.RSV1 if reserved_bits: # client is using as-yet-undefined extensions; abort self._abort() return - self._masked_frame = bool(payloadlen & 0x80) - payloadlen = payloadlen & 0x7f - if self._frame_opcode_is_control and payloadlen >= 126: + is_masked = bool(mask_payloadlen & 0x80) + payloadlen = mask_payloadlen & 0x7f + + # Parse and validate the length. + if opcode_is_control and payloadlen >= 126: # control frames must have payload < 126 self._abort() return - try: - with ignore_deprecation(): - if payloadlen < 126: - self._frame_length = payloadlen - if self._masked_frame: - self.stream.read_bytes(4, self._on_masking_key) - else: - self._read_frame_data(False) - elif payloadlen == 126: - self.stream.read_bytes(2, self._on_frame_length_16) - elif payloadlen == 127: - self.stream.read_bytes(8, self._on_frame_length_64) - except StreamClosedError: - self._abort() - - def _read_frame_data(self, masked): - new_len = self._frame_length + if payloadlen < 126: + self._frame_length = payloadlen + elif payloadlen == 126: + data = yield self._read_bytes(2) + payloadlen = struct.unpack("!H", data)[0] + elif payloadlen == 127: + data = yield self._read_bytes(8) + payloadlen = struct.unpack("!Q", data)[0] + new_len = payloadlen if self._fragmented_message_buffer is not None: new_len += len(self._fragmented_message_buffer) if new_len > (self.handler.max_message_size or 10 * 1024 * 1024): self.close(1009, "message too big") - return - with ignore_deprecation(): - self.stream.read_bytes( - self._frame_length, - self._on_masked_frame_data if masked else self._on_frame_data) - - def _on_frame_length_16(self, data): - self._wire_bytes_in += len(data) - self._frame_length = struct.unpack("!H", data)[0] - try: - if self._masked_frame: - with ignore_deprecation(): - self.stream.read_bytes(4, self._on_masking_key) - else: - self._read_frame_data(False) - except StreamClosedError: self._abort() + return - def _on_frame_length_64(self, data): - self._wire_bytes_in += len(data) - self._frame_length = struct.unpack("!Q", data)[0] - try: - if self._masked_frame: - with ignore_deprecation(): - self.stream.read_bytes(4, self._on_masking_key) - else: - self._read_frame_data(False) - except StreamClosedError: - self._abort() - - def _on_masking_key(self, data): - self._wire_bytes_in += len(data) - self._frame_mask = data - try: - self._read_frame_data(True) - except StreamClosedError: - self._abort() - - def _on_masked_frame_data(self, data): - # Don't touch _wire_bytes_in; we'll do it in _on_frame_data. - self._on_frame_data(_websocket_mask(self._frame_mask, data)) - - def _on_frame_data(self, data): - handled_future = None + # Read the payload, unmasking if necessary. + if is_masked: + self._frame_mask = yield self._read_bytes(4) + data = yield self._read_bytes(payloadlen) + if is_masked: + data = _websocket_mask(self._frame_mask, data) - self._wire_bytes_in += len(data) - if self._frame_opcode_is_control: + # Decide what to do with this frame. + if opcode_is_control: # control frames may be interleaved with a series of fragmented # data frames, so control frames must not interact with # self._fragmented_* - if not self._final_frame: + if not is_final_frame: # control frames must not be fragmented self._abort() return - opcode = self._frame_opcode - elif self._frame_opcode == 0: # continuation frame + elif opcode == 0: # continuation frame if self._fragmented_message_buffer is None: # nothing to continue self._abort() return self._fragmented_message_buffer += data - if self._final_frame: + if is_final_frame: opcode = self._fragmented_message_opcode data = self._fragmented_message_buffer self._fragmented_message_buffer = None @@ -964,22 +922,14 @@ class WebSocketProtocol13(WebSocketProtocol): # can't start new message until the old one is finished self._abort() return - if self._final_frame: - opcode = self._frame_opcode - else: - self._fragmented_message_opcode = self._frame_opcode + if not is_final_frame: + self._fragmented_message_opcode = opcode self._fragmented_message_buffer = data - if self._final_frame: + if is_final_frame: handled_future = self._handle_message(opcode, data) - - if not self.client_terminated: - if handled_future: - # on_message is a coroutine, process more frames once it's done. - handled_future.add_done_callback( - lambda future: self._receive_frame()) - else: - self._receive_frame() + if handled_future is not None: + yield handled_future def _handle_message(self, opcode, data): """Execute on_message, returning its Future if it is a coroutine.""" @@ -1182,7 +1132,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) self.protocol.start_pinging() - self.protocol._receive_frame() + IOLoop.current().add_callback(self.protocol._receive_frame_loop) if self._timeout is not None: self.io_loop.remove_timeout(self._timeout)