]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Allow open to be a coroutine 2385/head
authorBen Darnell <ben@bendarnell.com>
Sat, 12 May 2018 19:13:15 +0000 (15:13 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 12 May 2018 19:13:15 +0000 (15:13 -0400)
Fixes #2358

docs/releases/v5.1.0.rst
tornado/test/websocket_test.py
tornado/websocket.py

index be2c244dfe54a0a289986e69ab07be0bff958fe6..7dd73feade8fafd6b70179456c9ad2cea666eeda 100644 (file)
@@ -141,6 +141,7 @@ Deprecation notice
 - The `.WebSocketHandler.select_subprotocol` method is now called with
   an empty list instead of a list containing an empty string if no
   subprotocols were requested by the client.
+- `.WebSocketHandler.open` may now be a coroutine.
 - The ``data`` argument to `.WebSocketHandler.ping` is now optional.
 - Client-side websocket connections no longer buffer more than one
   message in memory at a time.
index ecb7123f977ac5e481b0b32f12a733f22c4e5f31..a6439b9fb7ffcfdc0ee11c962e8c72db7b78e1e3 100644 (file)
@@ -7,6 +7,7 @@ import traceback
 from tornado.concurrent import Future
 from tornado import gen
 from tornado.httpclient import HTTPError, HTTPRequest
+from tornado.locks import Event
 from tornado.log import gen_log, app_log
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.template import DictLoader
@@ -162,6 +163,24 @@ class SubprotocolHandler(TestWebSocketHandler):
         self.write_message("subprotocol=%s" % self.selected_subprotocol)
 
 
+class OpenCoroutineHandler(TestWebSocketHandler):
+    def initialize(self, test, **kwargs):
+        super(OpenCoroutineHandler, self).initialize(**kwargs)
+        self.test = test
+        self.open_finished = False
+
+    @gen.coroutine
+    def open(self):
+        yield self.test.message_sent.wait()
+        yield gen.sleep(0.010)
+        self.open_finished = True
+
+    def on_message(self, message):
+        if not self.open_finished:
+            raise Exception('on_message called before open finished')
+        self.write_message('ok')
+
+
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
     @gen.coroutine
     def ws_connect(self, path, **kwargs):
@@ -204,6 +223,8 @@ class WebSocketTest(WebSocketBaseTestCase):
              dict(close_future=self.close_future)),
             ('/subprotocol', SubprotocolHandler,
              dict(close_future=self.close_future)),
+            ('/open_coroutine', OpenCoroutineHandler,
+             dict(close_future=self.close_future, test=self)),
         ], template_loader=DictLoader({
             'message.html': '<b>{{ message }}</b>',
         }))
@@ -480,6 +501,16 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(res, 'subprotocol=None')
         yield self.close(ws)
 
+    @gen_test
+    def test_open_coroutine(self):
+        self.message_sent = Event()
+        ws = yield self.ws_connect('/open_coroutine')
+        yield ws.write_message('hello')
+        self.message_sent.set()
+        res = yield ws.read_message()
+        self.assertEqual(res, 'ok')
+        yield self.close(ws)
+
 
 if sys.version_info >= (3, 5):
     NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
index e77e6623dc36f4f2c937a51e2559ba00234e3bc9..87119bdc732b5da1736de7bec8c4e2f65108270c 100644 (file)
@@ -318,6 +318,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
         The arguments to `open` are extracted from the `tornado.web.URLSpec`
         regular expression, just like the arguments to
         `tornado.web.RequestHandler.get`.
+
+        `open` may be a coroutine. `on_message` will not be called until
+        `open` has returned.
+
+        .. versionchanged:: 5.1
+
+           ``open`` may be a coroutine.
         """
         pass
 
@@ -694,6 +701,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         return WebSocketProtocol13.compute_accept_value(
             self.request.headers.get("Sec-Websocket-Key"))
 
+    @gen.coroutine
     def _accept_connection(self):
         subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol")
         if subprotocol_header:
@@ -733,9 +741,11 @@ class WebSocketProtocol13(WebSocketProtocol):
         self.stream = self.handler.stream
 
         self.start_pinging()
-        self._run_callback(self.handler.open, *self.handler.open_args,
-                           **self.handler.open_kwargs)
-        IOLoop.current().add_callback(self._receive_frame_loop)
+        open_result = self._run_callback(self.handler.open, *self.handler.open_args,
+                                         **self.handler.open_kwargs)
+        if open_result is not None:
+            yield open_result
+        yield self._receive_frame_loop()
 
     def _parse_extensions_header(self, headers):
         extensions = headers.get("Sec-WebSocket-Extensions", '')