]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Convert legacy coroutines to native 2560/head
authorBen Darnell <ben@bendarnell.com>
Mon, 24 Dec 2018 18:03:18 +0000 (13:03 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Dec 2018 03:17:57 +0000 (22:17 -0500)
tornado/iostream.py
tornado/test/websocket_test.py
tornado/websocket.py

index c57d771170fc345130cf60c523e2a556b109e1b7..a4025d35fbe422dc08b03f56e1a618365b2f1da1 100644 (file)
@@ -23,6 +23,7 @@ Contents:
 * `PipeIOStream`: Pipe-based IOStream implementation.
 """
 
+import asyncio
 import collections
 import errno
 import io
@@ -629,7 +630,13 @@ class BaseIOStream(object):
         for future in futures:
             if not future.done():
                 future.set_exception(StreamClosedError(real_error=self.error))
-            future.exception()
+            # Reference the exception to silence warnings. Annoyingly,
+            # this raises if the future was cancelled, but just
+            # returns any other error.
+            try:
+                future.exception()
+            except asyncio.CancelledError:
+                pass
         if self._ssl_connect_future is not None:
             # _ssl_connect_future expects to see the real exception (typically
             # an ssl.SSLError), not just StreamClosedError.
@@ -778,6 +785,8 @@ class BaseIOStream(object):
             pos = self._read_to_buffer_loop()
         except UnsatisfiableReadError:
             raise
+        except asyncio.CancelledError:
+            raise
         except Exception as e:
             gen_log.warning("error on read: %s" % e)
             self.close(exc_info=e)
index d63a665af8105c20346d3b8e83e39b3531f39757..1a5a27281f00b41c613831b2ac22fb392b650796 100644 (file)
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 import traceback
 import unittest
@@ -61,6 +62,8 @@ class EchoHandler(TestWebSocketHandler):
     def on_message(self, message):
         try:
             yield self.write_message(message, isinstance(message, bytes))
+        except asyncio.CancelledError:
+            pass
         except WebSocketClosedError:
             pass
 
index 8848ca8099162e170b8c45bb68bdcef35dde4588..c77b1999ba789ec4f375f1e8aedaa7cefc115713 100644 (file)
@@ -49,7 +49,6 @@ from typing import (
     List,
     Awaitable,
     Callable,
-    Generator,
     Tuple,
     Type,
 )
@@ -712,7 +711,7 @@ class WebSocketProtocol(abc.ABC):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def _receive_frame_loop(self) -> "Future[None]":
+    async def _receive_frame_loop(self) -> None:
         raise NotImplementedError()
 
 
@@ -1092,37 +1091,35 @@ class WebSocketProtocol13(WebSocketProtocol):
         except StreamClosedError:
             raise WebSocketClosedError()
 
-        @gen.coroutine
-        def wrapper() -> Generator[Any, Any, None]:
+        async def wrapper() -> None:
             try:
-                yield fut
+                await fut
             except StreamClosedError:
                 raise WebSocketClosedError()
 
-        return wrapper()
+        return asyncio.ensure_future(wrapper())
 
     def write_ping(self, data: bytes) -> None:
         """Send ping frame."""
         assert isinstance(data, bytes)
         self._write_frame(True, 0x9, data)
 
-    @gen.coroutine
-    def _receive_frame_loop(self) -> Generator[Any, Any, None]:
+    async def _receive_frame_loop(self) -> None:
         try:
             while not self.client_terminated:
-                yield self._receive_frame()
+                await self._receive_frame()
         except StreamClosedError:
             self._abort()
         self.handler.on_ws_connection_close(self.close_code, self.close_reason)
 
-    def _read_bytes(self, n: int) -> Awaitable[bytes]:
+    async def _read_bytes(self, n: int) -> bytes:
+        data = await self.stream.read_bytes(n)
         self._wire_bytes_in += n
-        return self.stream.read_bytes(n)
+        return data
 
-    @gen.coroutine
-    def _receive_frame(self) -> Generator[Any, Any, None]:
+    async def _receive_frame(self) -> None:
         # Read the frame header.
-        data = yield self._read_bytes(2)
+        data = await self._read_bytes(2)
         header, mask_payloadlen = struct.unpack("BB", data)
         is_final_frame = header & self.FIN
         reserved_bits = header & self.RSV_MASK
@@ -1149,10 +1146,10 @@ class WebSocketProtocol13(WebSocketProtocol):
         if payloadlen < 126:
             self._frame_length = payloadlen
         elif payloadlen == 126:
-            data = yield self._read_bytes(2)
+            data = await self._read_bytes(2)
             payloadlen = struct.unpack("!H", data)[0]
         elif payloadlen == 127:
-            data = yield self._read_bytes(8)
+            data = await self._read_bytes(8)
             payloadlen = struct.unpack("!Q", data)[0]
         new_len = payloadlen
         if self._fragmented_message_buffer is not None:
@@ -1164,8 +1161,8 @@ class WebSocketProtocol13(WebSocketProtocol):
 
         # Read the payload, unmasking if necessary.
         if is_masked:
-            self._frame_mask = yield self._read_bytes(4)
-        data = yield self._read_bytes(payloadlen)
+            self._frame_mask = await self._read_bytes(4)
+        data = await self._read_bytes(payloadlen)
         if is_masked:
             assert self._frame_mask is not None
             data = _websocket_mask(self._frame_mask, data)
@@ -1201,7 +1198,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         if is_final_frame:
             handled_future = self._handle_message(opcode, data)
             if handled_future is not None:
-                yield handled_future
+                await handled_future
 
     def _handle_message(self, opcode: int, data: bytes) -> Optional["Future[None]"]:
         """Execute on_message, returning its Future if it is a coroutine."""