From 5b5635feec8a325ecbdf0abde26116e00a8d662a Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Mon, 24 Dec 2018 13:03:18 -0500 Subject: [PATCH] websocket: Convert legacy coroutines to native --- tornado/iostream.py | 11 ++++++++++- tornado/test/websocket_test.py | 3 +++ tornado/websocket.py | 35 ++++++++++++++++------------------ 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/tornado/iostream.py b/tornado/iostream.py index c57d77117..a4025d35f 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -23,6 +23,7 @@ Contents: * `PipeIOStream`: Pipe-based IOStream implementation. """ +import asyncio import collections import errno import io @@ -629,7 +630,13 @@ class BaseIOStream(object): for future in futures: if not future.done(): future.set_exception(StreamClosedError(real_error=self.error)) - future.exception() + # Reference the exception to silence warnings. Annoyingly, + # this raises if the future was cancelled, but just + # returns any other error. + try: + future.exception() + except asyncio.CancelledError: + pass if self._ssl_connect_future is not None: # _ssl_connect_future expects to see the real exception (typically # an ssl.SSLError), not just StreamClosedError. @@ -778,6 +785,8 @@ class BaseIOStream(object): pos = self._read_to_buffer_loop() except UnsatisfiableReadError: raise + except asyncio.CancelledError: + raise except Exception as e: gen_log.warning("error on read: %s" % e) self.close(exc_info=e) diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index d63a665af..1a5a27281 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -1,3 +1,4 @@ +import asyncio import functools import traceback import unittest @@ -61,6 +62,8 @@ class EchoHandler(TestWebSocketHandler): def on_message(self, message): try: yield self.write_message(message, isinstance(message, bytes)) + except asyncio.CancelledError: + pass except WebSocketClosedError: pass diff --git a/tornado/websocket.py b/tornado/websocket.py index 8848ca809..c77b1999b 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -49,7 +49,6 @@ from typing import ( List, Awaitable, Callable, - Generator, Tuple, Type, ) @@ -712,7 +711,7 @@ class WebSocketProtocol(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def _receive_frame_loop(self) -> "Future[None]": + async def _receive_frame_loop(self) -> None: raise NotImplementedError() @@ -1092,37 +1091,35 @@ class WebSocketProtocol13(WebSocketProtocol): except StreamClosedError: raise WebSocketClosedError() - @gen.coroutine - def wrapper() -> Generator[Any, Any, None]: + async def wrapper() -> None: try: - yield fut + await fut except StreamClosedError: raise WebSocketClosedError() - return wrapper() + return asyncio.ensure_future(wrapper()) def write_ping(self, data: bytes) -> None: """Send ping frame.""" assert isinstance(data, bytes) self._write_frame(True, 0x9, data) - @gen.coroutine - def _receive_frame_loop(self) -> Generator[Any, Any, None]: + async def _receive_frame_loop(self) -> None: try: while not self.client_terminated: - yield self._receive_frame() + await self._receive_frame() except StreamClosedError: self._abort() self.handler.on_ws_connection_close(self.close_code, self.close_reason) - def _read_bytes(self, n: int) -> Awaitable[bytes]: + async def _read_bytes(self, n: int) -> bytes: + data = await self.stream.read_bytes(n) self._wire_bytes_in += n - return self.stream.read_bytes(n) + return data - @gen.coroutine - def _receive_frame(self) -> Generator[Any, Any, None]: + async def _receive_frame(self) -> None: # Read the frame header. - data = yield self._read_bytes(2) + data = await self._read_bytes(2) header, mask_payloadlen = struct.unpack("BB", data) is_final_frame = header & self.FIN reserved_bits = header & self.RSV_MASK @@ -1149,10 +1146,10 @@ class WebSocketProtocol13(WebSocketProtocol): if payloadlen < 126: self._frame_length = payloadlen elif payloadlen == 126: - data = yield self._read_bytes(2) + data = await self._read_bytes(2) payloadlen = struct.unpack("!H", data)[0] elif payloadlen == 127: - data = yield self._read_bytes(8) + data = await self._read_bytes(8) payloadlen = struct.unpack("!Q", data)[0] new_len = payloadlen if self._fragmented_message_buffer is not None: @@ -1164,8 +1161,8 @@ class WebSocketProtocol13(WebSocketProtocol): # Read the payload, unmasking if necessary. if is_masked: - self._frame_mask = yield self._read_bytes(4) - data = yield self._read_bytes(payloadlen) + self._frame_mask = await self._read_bytes(4) + data = await self._read_bytes(payloadlen) if is_masked: assert self._frame_mask is not None data = _websocket_mask(self._frame_mask, data) @@ -1201,7 +1198,7 @@ class WebSocketProtocol13(WebSocketProtocol): if is_final_frame: handled_future = self._handle_message(opcode, data) if handled_future is not None: - yield handled_future + await handled_future def _handle_message(self, opcode: int, data: bytes) -> Optional["Future[None]"]: """Execute on_message, returning its Future if it is a coroutine.""" -- 2.47.2