]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Make WSH.get a coroutine
authorBen Darnell <ben@bendarnell.com>
Mon, 24 Dec 2018 16:57:33 +0000 (11:57 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Dec 2018 03:17:57 +0000 (22:17 -0500)
This is necessary to convert accept_connection to native coroutines -
the handshake no longer completes within a single IOLoop iteration
with this change due to coroutine scheduling.

This has the side effect of keeping the HTTP1Connection open for the
lifetime of the websocket connection. That's not great for memory, but
might help streamline close handling. Either way, it'll be refactored
in a future change.

tornado/websocket.py

index e72b2b456f2beddb63db9dc884c23c087dc6784a..8848ca8099162e170b8c45bb68bdcef35dde4588 100644 (file)
@@ -17,6 +17,7 @@ the protocol (known as "draft 76") and are not compatible with this module.
 """
 
 import abc
+import asyncio
 import base64
 import hashlib
 import os
@@ -228,7 +229,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         self.stream = None  # type: Optional[IOStream]
         self._on_close_called = False
 
-    def get(self, *args: Any, **kwargs: Any) -> None:
+    async def get(self, *args: Any, **kwargs: Any) -> None:
         self.open_args = args
         self.open_kwargs = kwargs
 
@@ -275,11 +276,10 @@ class WebSocketHandler(tornado.web.RequestHandler):
 
         self.ws_connection = self.get_websocket_protocol()
         if self.ws_connection:
-            self.ws_connection.accept_connection(self)
+            await self.ws_connection.accept_connection(self)
         else:
             self.set_status(426, "Upgrade Required")
             self.set_header("Sec-WebSocket-Version", "7, 8, 13")
-            self.finish()
 
     stream = None
 
@@ -679,7 +679,7 @@ class WebSocketProtocol(abc.ABC):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def accept_connection(self, handler: WebSocketHandler) -> None:
+    async def accept_connection(self, handler: WebSocketHandler) -> None:
         raise NotImplementedError()
 
     @abc.abstractmethod
@@ -833,7 +833,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._masked_frame = None
         self._frame_mask = None  # type: Optional[bytes]
         self._frame_length = None
-        self._fragmented_message_buffer = None
+        self._fragmented_message_buffer = None  # type: Optional[bytes]
         self._fragmented_message_opcode = None
         self._waiting = None  # type: object
         self._compression_options = params.compression_options
@@ -864,7 +864,7 @@ class WebSocketProtocol13(WebSocketProtocol):
     def selected_subprotocol(self, value: Optional[str]) -> None:
         self._selected_subprotocol = value
 
-    def accept_connection(self, handler: WebSocketHandler) -> None:
+    async def accept_connection(self, handler: WebSocketHandler) -> None:
         try:
             self._handle_websocket_headers(handler)
         except ValueError:
@@ -875,7 +875,10 @@ class WebSocketProtocol13(WebSocketProtocol):
             return
 
         try:
-            self._accept_connection(handler)
+            await self._accept_connection(handler)
+        except asyncio.CancelledError:
+            self._abort()
+            return
         except ValueError:
             gen_log.debug("Malformed WebSocket request received", exc_info=True)
             self._abort()
@@ -906,10 +909,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             cast(str, handler.request.headers.get("Sec-Websocket-Key"))
         )
 
-    @gen.coroutine
-    def _accept_connection(
-        self, handler: WebSocketHandler
-    ) -> Generator[Any, Any, None]:
+    async def _accept_connection(self, handler: WebSocketHandler) -> None:
         subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
         if subprotocol_header:
             subprotocols = [s.strip() for s in subprotocol_header.split(",")]
@@ -953,8 +953,8 @@ class WebSocketProtocol13(WebSocketProtocol):
             handler.open, *handler.open_args, **handler.open_kwargs
         )
         if open_result is not None:
-            yield open_result
-        yield self._receive_frame_loop()
+            await open_result
+        await self._receive_frame_loop()
 
     def _parse_extensions_header(
         self, headers: httputil.HTTPHeaders