]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Refactor implementation to use coroutines
authorBen Darnell <ben@bendarnell.com>
Sun, 22 Apr 2018 03:31:49 +0000 (23:31 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 23 Apr 2018 13:08:26 +0000 (09:08 -0400)
This avoids the deprecated IOStream interfaces and simplifies things a
bit.

tornado/websocket.py

index ced6d657f6adbcbd923ea9822fe768df48b7b826..0507a92c67ea82ec5900d0378ca9a89b0abab8c8 100644 (file)
@@ -19,13 +19,11 @@ the protocol (known as "draft 76") and are not compatible with this module.
 from __future__ import absolute_import, division, print_function
 
 import base64
-import contextlib
 import hashlib
 import os
 import struct
 import tornado.escape
 import tornado.web
-import warnings
 import zlib
 
 from tornado.concurrent import Future, future_set_result_unless_cancelled
@@ -46,14 +44,6 @@ else:
     from urlparse import urlparse  # py3
 
 
-@contextlib.contextmanager
-def ignore_deprecation():
-    """Context manager to ignore deprecation warnings."""
-    with warnings.catch_warnings():
-        warnings.simplefilter('ignore', DeprecationWarning)
-        yield
-
-
 class WebSocketError(Exception):
     pass
 
@@ -723,7 +713,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         self.start_pinging()
         self._run_callback(self.handler.open, *self.handler.open_args,
                            **self.handler.open_kwargs)
-        self._receive_frame()
+        IOLoop.current().add_callback(self._receive_frame_loop)
 
     def _parse_extensions_header(self, headers):
         extensions = headers.get("Sec-WebSocket-Extensions", '')
@@ -846,116 +836,84 @@ class WebSocketProtocol13(WebSocketProtocol):
         assert isinstance(data, bytes)
         self._write_frame(True, 0x9, data)
 
-    def _receive_frame(self):
+    @gen.coroutine
+    def _receive_frame_loop(self):
         try:
-            with ignore_deprecation():
-                self.stream.read_bytes(2, self._on_frame_start)
+            while not self.client_terminated:
+                yield self._receive_frame()
         except StreamClosedError:
             self._abort()
 
-    def _on_frame_start(self, data):
-        self._wire_bytes_in += len(data)
-        header, payloadlen = struct.unpack("BB", data)
-        self._final_frame = header & self.FIN
+    def _read_bytes(self, n):
+        self._wire_bytes_in += n
+        return self.stream.read_bytes(n)
+
+    @gen.coroutine
+    def _receive_frame(self):
+        # Read the frame header.
+        data = yield self._read_bytes(2)
+        header, mask_payloadlen = struct.unpack("BB", data)
+        is_final_frame = header & self.FIN
         reserved_bits = header & self.RSV_MASK
-        self._frame_opcode = header & self.OPCODE_MASK
-        self._frame_opcode_is_control = self._frame_opcode & 0x8
-        if self._decompressor is not None and self._frame_opcode != 0:
+        opcode = header & self.OPCODE_MASK
+        opcode_is_control = opcode & 0x8
+        if self._decompressor is not None and opcode != 0:
+            # Compression flag is present in the first frame's header,
+            # but we can't decompress until we have all the frames of
+            # the message.
             self._frame_compressed = bool(reserved_bits & self.RSV1)
             reserved_bits &= ~self.RSV1
         if reserved_bits:
             # client is using as-yet-undefined extensions; abort
             self._abort()
             return
-        self._masked_frame = bool(payloadlen & 0x80)
-        payloadlen = payloadlen & 0x7f
-        if self._frame_opcode_is_control and payloadlen >= 126:
+        is_masked = bool(mask_payloadlen & 0x80)
+        payloadlen = mask_payloadlen & 0x7f
+
+        # Parse and validate the length.
+        if opcode_is_control and payloadlen >= 126:
             # control frames must have payload < 126
             self._abort()
             return
-        try:
-            with ignore_deprecation():
-                if payloadlen < 126:
-                    self._frame_length = payloadlen
-                    if self._masked_frame:
-                        self.stream.read_bytes(4, self._on_masking_key)
-                    else:
-                        self._read_frame_data(False)
-                elif payloadlen == 126:
-                    self.stream.read_bytes(2, self._on_frame_length_16)
-                elif payloadlen == 127:
-                    self.stream.read_bytes(8, self._on_frame_length_64)
-        except StreamClosedError:
-            self._abort()
-
-    def _read_frame_data(self, masked):
-        new_len = self._frame_length
+        if payloadlen < 126:
+            self._frame_length = payloadlen
+        elif payloadlen == 126:
+            data = yield self._read_bytes(2)
+            payloadlen = struct.unpack("!H", data)[0]
+        elif payloadlen == 127:
+            data = yield self._read_bytes(8)
+            payloadlen = struct.unpack("!Q", data)[0]
+        new_len = payloadlen
         if self._fragmented_message_buffer is not None:
             new_len += len(self._fragmented_message_buffer)
         if new_len > (self.handler.max_message_size or 10 * 1024 * 1024):
             self.close(1009, "message too big")
-            return
-        with ignore_deprecation():
-            self.stream.read_bytes(
-                self._frame_length,
-                self._on_masked_frame_data if masked else self._on_frame_data)
-
-    def _on_frame_length_16(self, data):
-        self._wire_bytes_in += len(data)
-        self._frame_length = struct.unpack("!H", data)[0]
-        try:
-            if self._masked_frame:
-                with ignore_deprecation():
-                    self.stream.read_bytes(4, self._on_masking_key)
-            else:
-                self._read_frame_data(False)
-        except StreamClosedError:
             self._abort()
+            return
 
-    def _on_frame_length_64(self, data):
-        self._wire_bytes_in += len(data)
-        self._frame_length = struct.unpack("!Q", data)[0]
-        try:
-            if self._masked_frame:
-                with ignore_deprecation():
-                    self.stream.read_bytes(4, self._on_masking_key)
-            else:
-                self._read_frame_data(False)
-        except StreamClosedError:
-            self._abort()
-
-    def _on_masking_key(self, data):
-        self._wire_bytes_in += len(data)
-        self._frame_mask = data
-        try:
-            self._read_frame_data(True)
-        except StreamClosedError:
-            self._abort()
-
-    def _on_masked_frame_data(self, data):
-        # Don't touch _wire_bytes_in; we'll do it in _on_frame_data.
-        self._on_frame_data(_websocket_mask(self._frame_mask, data))
-
-    def _on_frame_data(self, data):
-        handled_future = None
+        # Read the payload, unmasking if necessary.
+        if is_masked:
+            self._frame_mask = yield self._read_bytes(4)
+        data = yield self._read_bytes(payloadlen)
+        if is_masked:
+            data = _websocket_mask(self._frame_mask, data)
 
-        self._wire_bytes_in += len(data)
-        if self._frame_opcode_is_control:
+        # Decide what to do with this frame.
+        if opcode_is_control:
             # control frames may be interleaved with a series of fragmented
             # data frames, so control frames must not interact with
             # self._fragmented_*
-            if not self._final_frame:
+            if not is_final_frame:
                 # control frames must not be fragmented
                 self._abort()
                 return
-            opcode = self._frame_opcode
-        elif self._frame_opcode == 0:  # continuation frame
+        elif opcode == 0:  # continuation frame
             if self._fragmented_message_buffer is None:
                 # nothing to continue
                 self._abort()
                 return
             self._fragmented_message_buffer += data
-            if self._final_frame:
+            if is_final_frame:
                 opcode = self._fragmented_message_opcode
                 data = self._fragmented_message_buffer
                 self._fragmented_message_buffer = None
@@ -964,22 +922,14 @@ class WebSocketProtocol13(WebSocketProtocol):
                 # can't start new message until the old one is finished
                 self._abort()
                 return
-            if self._final_frame:
-                opcode = self._frame_opcode
-            else:
-                self._fragmented_message_opcode = self._frame_opcode
+            if not is_final_frame:
+                self._fragmented_message_opcode = opcode
                 self._fragmented_message_buffer = data
 
-        if self._final_frame:
+        if is_final_frame:
             handled_future = self._handle_message(opcode, data)
-
-        if not self.client_terminated:
-            if handled_future:
-                # on_message is a coroutine, process more frames once it's done.
-                handled_future.add_done_callback(
-                    lambda future: self._receive_frame())
-            else:
-                self._receive_frame()
+            if handled_future is not None:
+                yield handled_future
 
     def _handle_message(self, opcode, data):
         """Execute on_message, returning its Future if it is a coroutine."""
@@ -1182,7 +1132,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.protocol = self.get_websocket_protocol()
         self.protocol._process_server_headers(self.key, self.headers)
         self.protocol.start_pinging()
-        self.protocol._receive_frame()
+        IOLoop.current().add_callback(self.protocol._receive_frame_loop)
 
         if self._timeout is not None:
             self.io_loop.remove_timeout(self._timeout)