From: Ben Darnell Date: Mon, 24 Dec 2018 16:57:33 +0000 (-0500) Subject: websocket: Make WSH.get a coroutine X-Git-Tag: v6.0.0b1~12^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e69becbf8c994be5bfa6584ef72df594b27934b5;p=thirdparty%2Ftornado.git websocket: Make WSH.get a coroutine This is necessary to convert accept_connection to native coroutines - the handshake no longer completes within a single IOLoop iteration with this change due to coroutine scheduling. This has the side effect of keeping the HTTP1Connection open for the lifetime of the websocket connection. That's not great for memory, but might help streamline close handling. Either way, it'll be refactored in a future change. --- diff --git a/tornado/websocket.py b/tornado/websocket.py index e72b2b456..8848ca809 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -17,6 +17,7 @@ the protocol (known as "draft 76") and are not compatible with this module. """ import abc +import asyncio import base64 import hashlib import os @@ -228,7 +229,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.stream = None # type: Optional[IOStream] self._on_close_called = False - def get(self, *args: Any, **kwargs: Any) -> None: + async def get(self, *args: Any, **kwargs: Any) -> None: self.open_args = args self.open_kwargs = kwargs @@ -275,11 +276,10 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection = self.get_websocket_protocol() if self.ws_connection: - self.ws_connection.accept_connection(self) + await self.ws_connection.accept_connection(self) else: self.set_status(426, "Upgrade Required") self.set_header("Sec-WebSocket-Version", "7, 8, 13") - self.finish() stream = None @@ -679,7 +679,7 @@ class WebSocketProtocol(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def accept_connection(self, handler: WebSocketHandler) -> None: + async def accept_connection(self, handler: WebSocketHandler) -> None: raise NotImplementedError() @abc.abstractmethod @@ -833,7 +833,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._masked_frame = None self._frame_mask = None # type: Optional[bytes] self._frame_length = None - self._fragmented_message_buffer = None + self._fragmented_message_buffer = None # type: Optional[bytes] self._fragmented_message_opcode = None self._waiting = None # type: object self._compression_options = params.compression_options @@ -864,7 +864,7 @@ class WebSocketProtocol13(WebSocketProtocol): def selected_subprotocol(self, value: Optional[str]) -> None: self._selected_subprotocol = value - def accept_connection(self, handler: WebSocketHandler) -> None: + async def accept_connection(self, handler: WebSocketHandler) -> None: try: self._handle_websocket_headers(handler) except ValueError: @@ -875,7 +875,10 @@ class WebSocketProtocol13(WebSocketProtocol): return try: - self._accept_connection(handler) + await self._accept_connection(handler) + except asyncio.CancelledError: + self._abort() + return except ValueError: gen_log.debug("Malformed WebSocket request received", exc_info=True) self._abort() @@ -906,10 +909,7 @@ class WebSocketProtocol13(WebSocketProtocol): cast(str, handler.request.headers.get("Sec-Websocket-Key")) ) - @gen.coroutine - def _accept_connection( - self, handler: WebSocketHandler - ) -> Generator[Any, Any, None]: + async def _accept_connection(self, handler: WebSocketHandler) -> None: subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") if subprotocol_header: subprotocols = [s.strip() for s in subprotocol_header.split(",")] @@ -953,8 +953,8 @@ class WebSocketProtocol13(WebSocketProtocol): handler.open, *handler.open_args, **handler.open_kwargs ) if open_result is not None: - yield open_result - yield self._receive_frame_loop() + await open_result + await self._receive_frame_loop() def _parse_extensions_header( self, headers: httputil.HTTPHeaders