from __future__ import absolute_import, division, print_function, with_statement
+import sys
import traceback
from tornado.concurrent import Future
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
-from tornado.test.util import unittest
+from tornado.test.util import unittest, skipBefore35, exec_test
from tornado.web import Application, RequestHandler
try:
self.write_message(arg)
+class CoroutineOnMessageHandler(TestWebSocketHandler):
+ def initialize(self, close_future, compression_options=None):
+ super(CoroutineOnMessageHandler, self).initialize(close_future,
+ compression_options)
+ self.sleeping = 0
+
+ @gen.coroutine
+ def on_message(self, message):
+ if self.sleeping > 0:
+ self.write_message('another coroutine is already sleeping')
+ self.sleeping += 1
+ yield gen.sleep(0.01)
+ self.sleeping -= 1
+ self.write_message(message)
+
+
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, **kwargs):
dict(close_future=self.close_future)),
('/path_args/(.*)', PathArgsHandler,
dict(close_future=self.close_future)),
+ ('/coroutine', CoroutineOnMessageHandler,
+ dict(close_future=self.close_future)),
])
def test_http_request(self):
res = yield ws.read_message()
self.assertEqual(res, 'hello')
+ @gen_test
+ def test_coroutine(self):
+ ws = yield self.ws_connect('/coroutine')
+ # Send both messages immediately, coroutine must process one at a time.
+ yield ws.write_message('hello1')
+ yield ws.write_message('hello2')
+ res = yield ws.read_message()
+ self.assertEqual(res, 'hello1')
+ res = yield ws.read_message()
+ self.assertEqual(res, 'hello2')
+
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
self.assertEqual(cm.exception.code, 403)
+if sys.version_info >= (3, 5):
+ NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
+class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
+ def initialize(self, close_future, compression_options=None):
+ super().initialize(close_future, compression_options)
+ self.sleeping = 0
+
+ async def on_message(self, message):
+ if self.sleeping > 0:
+ self.write_message('another coroutine is already sleeping')
+ self.sleeping += 1
+ await gen.sleep(0.01)
+ self.sleeping -= 1
+ self.write_message(message)""")['NativeCoroutineOnMessageHandler']
+
+
+class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
+ def get_app(self):
+ self.close_future = Future()
+ return Application([
+ ('/native', NativeCoroutineOnMessageHandler,
+ dict(close_future=self.close_future))])
+
+ @skipBefore35
+ @gen_test
+ def test_native_coroutine(self):
+ ws = yield self.ws_connect('/native')
+ # Send both messages immediately, coroutine must process one at a time.
+ yield ws.write_message('hello1')
+ yield ws.write_message('hello2')
+ res = yield ws.read_message()
+ self.assertEqual(res, 'hello1')
+ res = yield ws.read_message()
+ self.assertEqual(res, 'hello2')
+
+
class CompressionTestMixin(object):
MESSAGE = 'Hello world. Testing 123 123'
from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str, to_unicode
-from tornado import httpclient, httputil
+from tornado import gen, httpclient, httputil
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
"""Handle incoming messages on the WebSocket
This method must be overridden.
+
+ .. versionchanged:: 4.5
+
+ ``on_message`` can be a coroutine.
"""
raise NotImplementedError
def _run_callback(self, callback, *args, **kwargs):
"""Runs the given callback with exception handling.
- On error, aborts the websocket connection and returns False.
+ If the callback is a coroutine, returns its Future. On error, aborts the
+ websocket connection and returns None.
"""
try:
- callback(*args, **kwargs)
+ result = callback(*args, **kwargs)
except Exception:
app_log.error("Uncaught exception in %s",
getattr(self.request, 'path', None), exc_info=True)
self._abort()
+ else:
+ if result is not None:
+ self.stream.io_loop.add_future(gen.convert_yielded(result),
+ lambda f: f.result())
+
+ return result
def on_connection_close(self):
self._abort()
self._on_frame_data(_websocket_mask(self._frame_mask, data))
def _on_frame_data(self, data):
+ handled_future = None
+
self._wire_bytes_in += len(data)
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
self._fragmented_message_buffer = data
if self._final_frame:
- self._handle_message(opcode, data)
+ handled_future = self._handle_message(opcode, data)
if not self.client_terminated:
- self._receive_frame()
+ if handled_future:
+ # on_message is a coroutine, process more frames once it's done.
+ gen.convert_yielded(handled_future).add_done_callback(
+ lambda future: self._receive_frame())
+ else:
+ self._receive_frame()
def _handle_message(self, opcode, data):
+ """Execute on_message, returning its Future if it is a coroutine."""
if self.client_terminated:
return
except UnicodeDecodeError:
self._abort()
return
- self._run_callback(self.handler.on_message, decoded)
+ return self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
self._message_bytes_in += len(data)
- self._run_callback(self.handler.on_message, data)
+ return self._run_callback(self.handler.on_message, data)
elif opcode == 0x8:
# Close
self.client_terminated = True
elif opcode == 0xA:
# Pong
self.last_pong = IOLoop.current().time()
- self._run_callback(self.handler.on_pong, data)
+ return self._run_callback(self.handler.on_pong, data)
else:
self._abort()