From: A. Jesse Jiryu Davis Date: Mon, 9 Jan 2017 02:43:42 +0000 (-0500) Subject: websocket: WebSocketHandler.on_message allows coroutines X-Git-Tag: v4.5.0~29^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1909%2Fhead;p=thirdparty%2Ftornado.git websocket: WebSocketHandler.on_message allows coroutines Fixes #1650 --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 91f6692a9..48390e6c3 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function, with_statement +import sys import traceback from tornado.concurrent import Future @@ -7,7 +8,7 @@ from tornado import gen from tornado.httpclient import HTTPError, HTTPRequest from tornado.log import gen_log, app_log from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog -from tornado.test.util import unittest +from tornado.test.util import unittest, skipBefore35, exec_test from tornado.web import Application, RequestHandler try: @@ -92,6 +93,22 @@ class PathArgsHandler(TestWebSocketHandler): self.write_message(arg) +class CoroutineOnMessageHandler(TestWebSocketHandler): + def initialize(self, close_future, compression_options=None): + super(CoroutineOnMessageHandler, self).initialize(close_future, + compression_options) + self.sleeping = 0 + + @gen.coroutine + def on_message(self, message): + if self.sleeping > 0: + self.write_message('another coroutine is already sleeping') + self.sleeping += 1 + yield gen.sleep(0.01) + self.sleeping -= 1 + self.write_message(message) + + class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): @@ -126,6 +143,8 @@ class WebSocketTest(WebSocketBaseTestCase): dict(close_future=self.close_future)), ('/path_args/(.*)', PathArgsHandler, dict(close_future=self.close_future)), + ('/coroutine', CoroutineOnMessageHandler, + dict(close_future=self.close_future)), ]) def test_http_request(self): @@ -259,6 +278,17 @@ class WebSocketTest(WebSocketBaseTestCase): res = yield ws.read_message() self.assertEqual(res, 'hello') + @gen_test + def test_coroutine(self): + ws = yield self.ws_connect('/coroutine') + # Send both messages immediately, coroutine must process one at a time. + yield ws.write_message('hello1') + yield ws.write_message('hello2') + res = yield ws.read_message() + self.assertEqual(res, 'hello1') + res = yield ws.read_message() + self.assertEqual(res, 'hello2') + @gen_test def test_check_origin_valid_no_path(self): port = self.get_http_port() @@ -330,6 +360,42 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(cm.exception.code, 403) +if sys.version_info >= (3, 5): + NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """ +class NativeCoroutineOnMessageHandler(TestWebSocketHandler): + def initialize(self, close_future, compression_options=None): + super().initialize(close_future, compression_options) + self.sleeping = 0 + + async def on_message(self, message): + if self.sleeping > 0: + self.write_message('another coroutine is already sleeping') + self.sleeping += 1 + await gen.sleep(0.01) + self.sleeping -= 1 + self.write_message(message)""")['NativeCoroutineOnMessageHandler'] + + +class WebSocketNativeCoroutineTest(WebSocketBaseTestCase): + def get_app(self): + self.close_future = Future() + return Application([ + ('/native', NativeCoroutineOnMessageHandler, + dict(close_future=self.close_future))]) + + @skipBefore35 + @gen_test + def test_native_coroutine(self): + ws = yield self.ws_connect('/native') + # Send both messages immediately, coroutine must process one at a time. + yield ws.write_message('hello1') + yield ws.write_message('hello2') + res = yield ws.read_message() + self.assertEqual(res, 'hello1') + res = yield ws.read_message() + self.assertEqual(res, 'hello2') + + class CompressionTestMixin(object): MESSAGE = 'Hello world. Testing 123 123' diff --git a/tornado/websocket.py b/tornado/websocket.py index 74358c90c..754fca5cd 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -30,7 +30,7 @@ import zlib from tornado.concurrent import TracebackFuture from tornado.escape import utf8, native_str, to_unicode -from tornado import httpclient, httputil +from tornado import gen, httpclient, httputil from tornado.ioloop import IOLoop, PeriodicCallback from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log @@ -269,6 +269,10 @@ class WebSocketHandler(tornado.web.RequestHandler): """Handle incoming messages on the WebSocket This method must be overridden. + + .. versionchanged:: 4.5 + + ``on_message`` can be a coroutine. """ raise NotImplementedError @@ -442,14 +446,21 @@ class WebSocketProtocol(object): def _run_callback(self, callback, *args, **kwargs): """Runs the given callback with exception handling. - On error, aborts the websocket connection and returns False. + If the callback is a coroutine, returns its Future. On error, aborts the + websocket connection and returns None. """ try: - callback(*args, **kwargs) + result = callback(*args, **kwargs) except Exception: app_log.error("Uncaught exception in %s", getattr(self.request, 'path', None), exc_info=True) self._abort() + else: + if result is not None: + self.stream.io_loop.add_future(gen.convert_yielded(result), + lambda f: f.result()) + + return result def on_connection_close(self): self._abort() @@ -810,6 +821,8 @@ class WebSocketProtocol13(WebSocketProtocol): self._on_frame_data(_websocket_mask(self._frame_mask, data)) def _on_frame_data(self, data): + handled_future = None + self._wire_bytes_in += len(data) if self._frame_opcode_is_control: # control frames may be interleaved with a series of fragmented @@ -842,12 +855,18 @@ class WebSocketProtocol13(WebSocketProtocol): self._fragmented_message_buffer = data if self._final_frame: - self._handle_message(opcode, data) + handled_future = self._handle_message(opcode, data) if not self.client_terminated: - self._receive_frame() + if handled_future: + # on_message is a coroutine, process more frames once it's done. + gen.convert_yielded(handled_future).add_done_callback( + lambda future: self._receive_frame()) + else: + self._receive_frame() def _handle_message(self, opcode, data): + """Execute on_message, returning its Future if it is a coroutine.""" if self.client_terminated: return @@ -862,11 +881,11 @@ class WebSocketProtocol13(WebSocketProtocol): except UnicodeDecodeError: self._abort() return - self._run_callback(self.handler.on_message, decoded) + return self._run_callback(self.handler.on_message, decoded) elif opcode == 0x2: # Binary data self._message_bytes_in += len(data) - self._run_callback(self.handler.on_message, data) + return self._run_callback(self.handler.on_message, data) elif opcode == 0x8: # Close self.client_terminated = True @@ -883,7 +902,7 @@ class WebSocketProtocol13(WebSocketProtocol): elif opcode == 0xA: # Pong self.last_pong = IOLoop.current().time() - self._run_callback(self.handler.on_pong, data) + return self._run_callback(self.handler.on_pong, data) else: self._abort()