]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
tcpserver: handle_stream can be a native coroutine
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 5 Dec 2016 11:37:19 +0000 (06:37 -0500)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 5 Dec 2016 11:40:02 +0000 (06:40 -0500)
tornado/tcpserver.py
tornado/test/tcpserver_test.py

index 54837f7a65352a664b0c34fc7d655581d7c22445..583b127dddf42865b9a9b83db3757a94692cc7d0 100644 (file)
@@ -21,6 +21,7 @@ import errno
 import os
 import socket
 
+from tornado import gen
 from tornado.log import app_log
 from tornado.ioloop import IOLoop
 from tornado.iostream import IOStream, SSLIOStream
@@ -285,8 +286,10 @@ class TCPServer(object):
                 stream = IOStream(connection, io_loop=self.io_loop,
                                   max_buffer_size=self.max_buffer_size,
                                   read_chunk_size=self.read_chunk_size)
+
             future = self.handle_stream(stream, address)
             if future is not None:
-                self.io_loop.add_future(future, lambda f: f.result())
+                self.io_loop.add_future(gen.convert_yielded(future),
+                                        lambda f: f.result())
         except Exception:
             app_log.error("Error in connection callback", exc_info=True)
index c01c04ddfb2baf903a76fbf5dfa182c3c2d21172..ba43c76c6bc5697b9f95bdaa145bce68c86045c5 100644 (file)
@@ -6,6 +6,7 @@ from tornado.iostream import IOStream
 from tornado.log import app_log
 from tornado.stack_context import NullContext
 from tornado.tcpserver import TCPServer
+from tornado.test.util import skipBefore35, exec_test
 from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
 
 
@@ -37,3 +38,25 @@ class TCPServerTest(AsyncTestCase):
                 server.stop()
             if client is not None:
                 client.close()
+
+    @skipBefore35
+    @gen_test
+    def test_handle_stream_native_coroutine(self):
+        # handle_stream may be a native coroutine.
+
+        namespace = exec_test(globals(), locals(), """
+        class TestServer(TCPServer):
+            async def handle_stream(self, stream, address):
+                stream.write(b'data')
+                stream.close()
+        """)
+
+        sock, port = bind_unused_port()
+        server = namespace['TestServer']()
+        server.add_socket(sock)
+        client = IOStream(socket.socket())
+        yield client.connect(('localhost', port))
+        result = yield client.read_until_close()
+        self.assertEqual(result, b'data')
+        server.stop()
+        client.close()