]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: WebSocketHandler.on_message allows coroutines 1909/head
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 9 Jan 2017 02:43:42 +0000 (21:43 -0500)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Sun, 26 Feb 2017 16:22:23 +0000 (11:22 -0500)
Fixes #1650

tornado/test/websocket_test.py
tornado/websocket.py

index 91f6692a9b304d60fe769a50aa0206e2f67c562a..48390e6c3bab3e0a7887fb78e4561886a5183801 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
+import sys
 import traceback
 
 from tornado.concurrent import Future
@@ -7,7 +8,7 @@ from tornado import gen
 from tornado.httpclient import HTTPError, HTTPRequest
 from tornado.log import gen_log, app_log
 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
-from tornado.test.util import unittest
+from tornado.test.util import unittest, skipBefore35, exec_test
 from tornado.web import Application, RequestHandler
 
 try:
@@ -92,6 +93,22 @@ class PathArgsHandler(TestWebSocketHandler):
         self.write_message(arg)
 
 
+class CoroutineOnMessageHandler(TestWebSocketHandler):
+    def initialize(self, close_future, compression_options=None):
+        super(CoroutineOnMessageHandler, self).initialize(close_future,
+                                                          compression_options)
+        self.sleeping = 0
+
+    @gen.coroutine
+    def on_message(self, message):
+        if self.sleeping > 0:
+            self.write_message('another coroutine is already sleeping')
+        self.sleeping += 1
+        yield gen.sleep(0.01)
+        self.sleeping -= 1
+        self.write_message(message)
+
+
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
     @gen.coroutine
     def ws_connect(self, path, **kwargs):
@@ -126,6 +143,8 @@ class WebSocketTest(WebSocketBaseTestCase):
              dict(close_future=self.close_future)),
             ('/path_args/(.*)', PathArgsHandler,
              dict(close_future=self.close_future)),
+            ('/coroutine', CoroutineOnMessageHandler,
+             dict(close_future=self.close_future)),
         ])
 
     def test_http_request(self):
@@ -259,6 +278,17 @@ class WebSocketTest(WebSocketBaseTestCase):
         res = yield ws.read_message()
         self.assertEqual(res, 'hello')
 
+    @gen_test
+    def test_coroutine(self):
+        ws = yield self.ws_connect('/coroutine')
+        # Send both messages immediately, coroutine must process one at a time.
+        yield ws.write_message('hello1')
+        yield ws.write_message('hello2')
+        res = yield ws.read_message()
+        self.assertEqual(res, 'hello1')
+        res = yield ws.read_message()
+        self.assertEqual(res, 'hello2')
+
     @gen_test
     def test_check_origin_valid_no_path(self):
         port = self.get_http_port()
@@ -330,6 +360,42 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(cm.exception.code, 403)
 
 
+if sys.version_info >= (3, 5):
+    NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
+class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
+    def initialize(self, close_future, compression_options=None):
+        super().initialize(close_future, compression_options)
+        self.sleeping = 0
+
+    async def on_message(self, message):
+        if self.sleeping > 0:
+            self.write_message('another coroutine is already sleeping')
+        self.sleeping += 1
+        await gen.sleep(0.01)
+        self.sleeping -= 1
+        self.write_message(message)""")['NativeCoroutineOnMessageHandler']
+
+
+class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
+    def get_app(self):
+        self.close_future = Future()
+        return Application([
+            ('/native', NativeCoroutineOnMessageHandler,
+             dict(close_future=self.close_future))])
+
+    @skipBefore35
+    @gen_test
+    def test_native_coroutine(self):
+        ws = yield self.ws_connect('/native')
+        # Send both messages immediately, coroutine must process one at a time.
+        yield ws.write_message('hello1')
+        yield ws.write_message('hello2')
+        res = yield ws.read_message()
+        self.assertEqual(res, 'hello1')
+        res = yield ws.read_message()
+        self.assertEqual(res, 'hello2')
+
+
 class CompressionTestMixin(object):
     MESSAGE = 'Hello world. Testing 123 123'
 
index 74358c90c164f27924309092eec2af60b77e9daf..754fca5cdc1540141e6578b51fb0d7afe056c7e5 100644 (file)
@@ -30,7 +30,7 @@ import zlib
 
 from tornado.concurrent import TracebackFuture
 from tornado.escape import utf8, native_str, to_unicode
-from tornado import httpclient, httputil
+from tornado import gen, httpclient, httputil
 from tornado.ioloop import IOLoop, PeriodicCallback
 from tornado.iostream import StreamClosedError
 from tornado.log import gen_log, app_log
@@ -269,6 +269,10 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """Handle incoming messages on the WebSocket
 
         This method must be overridden.
+
+        .. versionchanged:: 4.5
+
+           ``on_message`` can be a coroutine.
         """
         raise NotImplementedError
 
@@ -442,14 +446,21 @@ class WebSocketProtocol(object):
     def _run_callback(self, callback, *args, **kwargs):
         """Runs the given callback with exception handling.
 
-        On error, aborts the websocket connection and returns False.
+        If the callback is a coroutine, returns its Future. On error, aborts the
+        websocket connection and returns None.
         """
         try:
-            callback(*args, **kwargs)
+            result = callback(*args, **kwargs)
         except Exception:
             app_log.error("Uncaught exception in %s",
                           getattr(self.request, 'path', None), exc_info=True)
             self._abort()
+        else:
+            if result is not None:
+                self.stream.io_loop.add_future(gen.convert_yielded(result),
+                                               lambda f: f.result())
+
+            return result
 
     def on_connection_close(self):
         self._abort()
@@ -810,6 +821,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._on_frame_data(_websocket_mask(self._frame_mask, data))
 
     def _on_frame_data(self, data):
+        handled_future = None
+
         self._wire_bytes_in += len(data)
         if self._frame_opcode_is_control:
             # control frames may be interleaved with a series of fragmented
@@ -842,12 +855,18 @@ class WebSocketProtocol13(WebSocketProtocol):
                 self._fragmented_message_buffer = data
 
         if self._final_frame:
-            self._handle_message(opcode, data)
+            handled_future = self._handle_message(opcode, data)
 
         if not self.client_terminated:
-            self._receive_frame()
+            if handled_future:
+                # on_message is a coroutine, process more frames once it's done.
+                gen.convert_yielded(handled_future).add_done_callback(
+                    lambda future: self._receive_frame())
+            else:
+                self._receive_frame()
 
     def _handle_message(self, opcode, data):
+        """Execute on_message, returning its Future if it is a coroutine."""
         if self.client_terminated:
             return
 
@@ -862,11 +881,11 @@ class WebSocketProtocol13(WebSocketProtocol):
             except UnicodeDecodeError:
                 self._abort()
                 return
-            self._run_callback(self.handler.on_message, decoded)
+            return self._run_callback(self.handler.on_message, decoded)
         elif opcode == 0x2:
             # Binary data
             self._message_bytes_in += len(data)
-            self._run_callback(self.handler.on_message, data)
+            return self._run_callback(self.handler.on_message, data)
         elif opcode == 0x8:
             # Close
             self.client_terminated = True
@@ -883,7 +902,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         elif opcode == 0xA:
             # Pong
             self.last_pong = IOLoop.current().time()
-            self._run_callback(self.handler.on_pong, data)
+            return self._run_callback(self.handler.on_pong, data)
         else:
             self._abort()