]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Merge close detection into receive_frame_loop
authorBen Darnell <ben@bendarnell.com>
Sun, 21 Oct 2018 18:26:43 +0000 (14:26 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Dec 2018 03:17:57 +0000 (22:17 -0500)
This avoids races between message handling and the close callback.

tornado/websocket.py

index 33629bd4c2b35bd4b397fe1babe202504056d677..efede9b51ddcd6270d46761d4878232986341f26 100644 (file)
@@ -31,7 +31,7 @@ from tornado.concurrent import Future, future_set_result_unless_cancelled
 from tornado.escape import utf8, native_str, to_unicode
 from tornado import gen, httpclient, httputil
 from tornado.ioloop import IOLoop, PeriodicCallback
-from tornado.iostream import StreamClosedError
+from tornado.iostream import StreamClosedError, IOStream
 from tornado.log import gen_log, app_log
 from tornado import simple_httpclient
 from tornado.queues import Queue
@@ -55,7 +55,6 @@ from typing import (
 from types import TracebackType
 
 if TYPE_CHECKING:
-    from tornado.iostream import IOStream  # noqa: F401
     from typing_extensions import Protocol
 
     # The zlib compressor types aren't actually exposed anywhere
@@ -77,10 +76,6 @@ if TYPE_CHECKING:
         # The common base interface implemented by WebSocketHandler on
         # the server side and WebSocketClientConnection on the client
         # side.
-        @property
-        def stream(self) -> Optional[IOStream]:
-            pass
-
         @property
         def ping_interval(self) -> Optional[float]:
             pass
@@ -109,6 +104,9 @@ if TYPE_CHECKING:
         def close_reason(self, value: Optional[str]) -> None:
             pass
 
+        def on_connection_close(self) -> None:
+            pass
+
         def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
             pass
 
@@ -612,9 +610,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
             )
         return None
 
-    def _attach_stream(self) -> None:
-        self.stream = self.detach()
-        self.stream.set_close_callback(self.on_connection_close)
+    def _detach_stream(self) -> IOStream:
         # disable non-WS methods
         for method in [
             "write",
@@ -626,6 +622,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
             "finish",
         ]:
             setattr(self, method, _raise_not_supported_for_websockets)
+        return self.detach()
 
 
 def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
@@ -638,7 +635,7 @@ class WebSocketProtocol(abc.ABC):
 
     def __init__(self, handler: "_WebSocketConnection") -> None:
         self.handler = handler
-        self.stream = handler.stream
+        self.stream = None  # type: Optional[IOStream]
         self.client_terminated = False
         self.server_terminated = False
 
@@ -947,9 +944,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler))
         handler.finish()
 
-        handler._attach_stream()
-        assert handler.stream is not None
-        self.stream = handler.stream
+        self.stream = handler._detach_stream()
 
         self.start_pinging()
         open_result = self._run_callback(
@@ -1116,6 +1111,7 @@ class WebSocketProtocol13(WebSocketProtocol):
                 yield self._receive_frame()
         except StreamClosedError:
             self._abort()
+        self.handler.on_connection_close()
 
     def _read_bytes(self, n: int) -> Awaitable[bytes]:
         self._wire_bytes_in += n
@@ -1456,18 +1452,18 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             )
             return
 
+        if self._timeout is not None:
+            self.io_loop.remove_timeout(self._timeout)
+            self._timeout = None
+
         self.headers = headers
         self.protocol = self.get_websocket_protocol()
         self.protocol._process_server_headers(self.key, self.headers)
-        self.protocol.start_pinging()
-        IOLoop.current().add_callback(self.protocol._receive_frame_loop)
+        self.protocol.stream = self.connection.detach()
 
-        if self._timeout is not None:
-            self.io_loop.remove_timeout(self._timeout)
-            self._timeout = None
+        IOLoop.current().add_callback(self.protocol._receive_frame_loop)
+        self.protocol.start_pinging()
 
-        self.stream = self.connection.detach()
-        self.stream.set_close_callback(self.on_connection_close)
         # Once we've taken over the connection, clear the final callback
         # we set on the http request.  This deactivates the error handling
         # in simple_httpclient that would otherwise interfere with our