From: Ben Darnell Date: Sun, 21 Oct 2018 18:26:43 +0000 (-0400) Subject: websocket: Merge close detection into receive_frame_loop X-Git-Tag: v6.0.0b1~12^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a9b9a1485731649e3cac597d37003017eadbda9c;p=thirdparty%2Ftornado.git websocket: Merge close detection into receive_frame_loop This avoids races between message handling and the close callback. --- diff --git a/tornado/websocket.py b/tornado/websocket.py index 33629bd4c..efede9b51 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -31,7 +31,7 @@ from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.escape import utf8, native_str, to_unicode from tornado import gen, httpclient, httputil from tornado.ioloop import IOLoop, PeriodicCallback -from tornado.iostream import StreamClosedError +from tornado.iostream import StreamClosedError, IOStream from tornado.log import gen_log, app_log from tornado import simple_httpclient from tornado.queues import Queue @@ -55,7 +55,6 @@ from typing import ( from types import TracebackType if TYPE_CHECKING: - from tornado.iostream import IOStream # noqa: F401 from typing_extensions import Protocol # The zlib compressor types aren't actually exposed anywhere @@ -77,10 +76,6 @@ if TYPE_CHECKING: # The common base interface implemented by WebSocketHandler on # the server side and WebSocketClientConnection on the client # side. - @property - def stream(self) -> Optional[IOStream]: - pass - @property def ping_interval(self) -> Optional[float]: pass @@ -109,6 +104,9 @@ if TYPE_CHECKING: def close_reason(self, value: Optional[str]) -> None: pass + def on_connection_close(self) -> None: + pass + def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]: pass @@ -612,9 +610,7 @@ class WebSocketHandler(tornado.web.RequestHandler): ) return None - def _attach_stream(self) -> None: - self.stream = self.detach() - self.stream.set_close_callback(self.on_connection_close) + def _detach_stream(self) -> IOStream: # disable non-WS methods for method in [ "write", @@ -626,6 +622,7 @@ class WebSocketHandler(tornado.web.RequestHandler): "finish", ]: setattr(self, method, _raise_not_supported_for_websockets) + return self.detach() def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None: @@ -638,7 +635,7 @@ class WebSocketProtocol(abc.ABC): def __init__(self, handler: "_WebSocketConnection") -> None: self.handler = handler - self.stream = handler.stream + self.stream = None # type: Optional[IOStream] self.client_terminated = False self.server_terminated = False @@ -947,9 +944,7 @@ class WebSocketProtocol13(WebSocketProtocol): handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler)) handler.finish() - handler._attach_stream() - assert handler.stream is not None - self.stream = handler.stream + self.stream = handler._detach_stream() self.start_pinging() open_result = self._run_callback( @@ -1116,6 +1111,7 @@ class WebSocketProtocol13(WebSocketProtocol): yield self._receive_frame() except StreamClosedError: self._abort() + self.handler.on_connection_close() def _read_bytes(self, n: int) -> Awaitable[bytes]: self._wire_bytes_in += n @@ -1456,18 +1452,18 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): ) return + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None + self.headers = headers self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) - self.protocol.start_pinging() - IOLoop.current().add_callback(self.protocol._receive_frame_loop) + self.protocol.stream = self.connection.detach() - if self._timeout is not None: - self.io_loop.remove_timeout(self._timeout) - self._timeout = None + IOLoop.current().add_callback(self.protocol._receive_frame_loop) + self.protocol.start_pinging() - self.stream = self.connection.detach() - self.stream.set_close_callback(self.on_connection_close) # Once we've taken over the connection, clear the final callback # we set on the http request. This deactivates the error handling # in simple_httpclient that would otherwise interfere with our