From: Ben Darnell Date: Sat, 12 May 2018 19:13:15 +0000 (-0400) Subject: websocket: Allow open to be a coroutine X-Git-Tag: v5.1.0b1~18^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6d0026cf51d7d0390d71ec5069e28a9ca4356ec4;p=thirdparty%2Ftornado.git websocket: Allow open to be a coroutine Fixes #2358 --- diff --git a/docs/releases/v5.1.0.rst b/docs/releases/v5.1.0.rst index be2c244df..7dd73fead 100644 --- a/docs/releases/v5.1.0.rst +++ b/docs/releases/v5.1.0.rst @@ -141,6 +141,7 @@ Deprecation notice - The `.WebSocketHandler.select_subprotocol` method is now called with an empty list instead of a list containing an empty string if no subprotocols were requested by the client. +- `.WebSocketHandler.open` may now be a coroutine. - The ``data`` argument to `.WebSocketHandler.ping` is now optional. - Client-side websocket connections no longer buffer more than one message in memory at a time. diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index ecb7123f9..a6439b9fb 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -7,6 +7,7 @@ import traceback from tornado.concurrent import Future from tornado import gen from tornado.httpclient import HTTPError, HTTPRequest +from tornado.locks import Event from tornado.log import gen_log, app_log from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.template import DictLoader @@ -162,6 +163,24 @@ class SubprotocolHandler(TestWebSocketHandler): self.write_message("subprotocol=%s" % self.selected_subprotocol) +class OpenCoroutineHandler(TestWebSocketHandler): + def initialize(self, test, **kwargs): + super(OpenCoroutineHandler, self).initialize(**kwargs) + self.test = test + self.open_finished = False + + @gen.coroutine + def open(self): + yield self.test.message_sent.wait() + yield gen.sleep(0.010) + self.open_finished = True + + def on_message(self, message): + if not self.open_finished: + raise Exception('on_message called before open finished') + self.write_message('ok') + + class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): @@ -204,6 +223,8 @@ class WebSocketTest(WebSocketBaseTestCase): dict(close_future=self.close_future)), ('/subprotocol', SubprotocolHandler, dict(close_future=self.close_future)), + ('/open_coroutine', OpenCoroutineHandler, + dict(close_future=self.close_future, test=self)), ], template_loader=DictLoader({ 'message.html': '{{ message }}', })) @@ -480,6 +501,16 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(res, 'subprotocol=None') yield self.close(ws) + @gen_test + def test_open_coroutine(self): + self.message_sent = Event() + ws = yield self.ws_connect('/open_coroutine') + yield ws.write_message('hello') + self.message_sent.set() + res = yield ws.read_message() + self.assertEqual(res, 'ok') + yield self.close(ws) + if sys.version_info >= (3, 5): NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """ diff --git a/tornado/websocket.py b/tornado/websocket.py index e77e6623d..87119bdc7 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -318,6 +318,13 @@ class WebSocketHandler(tornado.web.RequestHandler): The arguments to `open` are extracted from the `tornado.web.URLSpec` regular expression, just like the arguments to `tornado.web.RequestHandler.get`. + + `open` may be a coroutine. `on_message` will not be called until + `open` has returned. + + .. versionchanged:: 5.1 + + ``open`` may be a coroutine. """ pass @@ -694,6 +701,7 @@ class WebSocketProtocol13(WebSocketProtocol): return WebSocketProtocol13.compute_accept_value( self.request.headers.get("Sec-Websocket-Key")) + @gen.coroutine def _accept_connection(self): subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol") if subprotocol_header: @@ -733,9 +741,11 @@ class WebSocketProtocol13(WebSocketProtocol): self.stream = self.handler.stream self.start_pinging() - self._run_callback(self.handler.open, *self.handler.open_args, - **self.handler.open_kwargs) - IOLoop.current().add_callback(self._receive_frame_loop) + open_result = self._run_callback(self.handler.open, *self.handler.open_args, + **self.handler.open_kwargs) + if open_result is not None: + yield open_result + yield self._receive_frame_loop() def _parse_extensions_header(self, headers): extensions = headers.get("Sec-WebSocket-Extensions", '')