from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
+from tornado.locks import Event
from tornado.log import gen_log, app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
self.write_message("subprotocol=%s" % self.selected_subprotocol)
+class OpenCoroutineHandler(TestWebSocketHandler):
+ def initialize(self, test, **kwargs):
+ super(OpenCoroutineHandler, self).initialize(**kwargs)
+ self.test = test
+ self.open_finished = False
+
+ @gen.coroutine
+ def open(self):
+ yield self.test.message_sent.wait()
+ yield gen.sleep(0.010)
+ self.open_finished = True
+
+ def on_message(self, message):
+ if not self.open_finished:
+ raise Exception('on_message called before open finished')
+ self.write_message('ok')
+
+
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, **kwargs):
dict(close_future=self.close_future)),
('/subprotocol', SubprotocolHandler,
dict(close_future=self.close_future)),
+ ('/open_coroutine', OpenCoroutineHandler,
+ dict(close_future=self.close_future, test=self)),
], template_loader=DictLoader({
'message.html': '<b>{{ message }}</b>',
}))
self.assertEqual(res, 'subprotocol=None')
yield self.close(ws)
+ @gen_test
+ def test_open_coroutine(self):
+ self.message_sent = Event()
+ ws = yield self.ws_connect('/open_coroutine')
+ yield ws.write_message('hello')
+ self.message_sent.set()
+ res = yield ws.read_message()
+ self.assertEqual(res, 'ok')
+ yield self.close(ws)
+
if sys.version_info >= (3, 5):
NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
The arguments to `open` are extracted from the `tornado.web.URLSpec`
regular expression, just like the arguments to
`tornado.web.RequestHandler.get`.
+
+ `open` may be a coroutine. `on_message` will not be called until
+ `open` has returned.
+
+ .. versionchanged:: 5.1
+
+ ``open`` may be a coroutine.
"""
pass
return WebSocketProtocol13.compute_accept_value(
self.request.headers.get("Sec-Websocket-Key"))
+ @gen.coroutine
def _accept_connection(self):
subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
self.stream = self.handler.stream
self.start_pinging()
- self._run_callback(self.handler.open, *self.handler.open_args,
- **self.handler.open_kwargs)
- IOLoop.current().add_callback(self._receive_frame_loop)
+ open_result = self._run_callback(self.handler.open, *self.handler.open_args,
+ **self.handler.open_kwargs)
+ if open_result is not None:
+ yield open_result
+ yield self._receive_frame_loop()
def _parse_extensions_header(self, headers):
extensions = headers.get("Sec-WebSocket-Extensions", '')