From: A. Jesse Jiryu Davis Date: Mon, 5 Dec 2016 11:37:19 +0000 (-0500) Subject: tcpserver: handle_stream can be a native coroutine X-Git-Tag: v4.5.0~49^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e09a6f82f8c7b42352a55a384f7f0f90dbe436d0;p=thirdparty%2Ftornado.git tcpserver: handle_stream can be a native coroutine --- diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 54837f7a6..583b127dd 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -21,6 +21,7 @@ import errno import os import socket +from tornado import gen from tornado.log import app_log from tornado.ioloop import IOLoop from tornado.iostream import IOStream, SSLIOStream @@ -285,8 +286,10 @@ class TCPServer(object): stream = IOStream(connection, io_loop=self.io_loop, max_buffer_size=self.max_buffer_size, read_chunk_size=self.read_chunk_size) + future = self.handle_stream(stream, address) if future is not None: - self.io_loop.add_future(future, lambda f: f.result()) + self.io_loop.add_future(gen.convert_yielded(future), + lambda f: f.result()) except Exception: app_log.error("Error in connection callback", exc_info=True) diff --git a/tornado/test/tcpserver_test.py b/tornado/test/tcpserver_test.py index c01c04ddf..ba43c76c6 100644 --- a/tornado/test/tcpserver_test.py +++ b/tornado/test/tcpserver_test.py @@ -6,6 +6,7 @@ from tornado.iostream import IOStream from tornado.log import app_log from tornado.stack_context import NullContext from tornado.tcpserver import TCPServer +from tornado.test.util import skipBefore35, exec_test from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test @@ -37,3 +38,25 @@ class TCPServerTest(AsyncTestCase): server.stop() if client is not None: client.close() + + @skipBefore35 + @gen_test + def test_handle_stream_native_coroutine(self): + # handle_stream may be a native coroutine. + + namespace = exec_test(globals(), locals(), """ + class TestServer(TCPServer): + async def handle_stream(self, stream, address): + stream.write(b'data') + stream.close() + """) + + sock, port = bind_unused_port() + server = namespace['TestServer']() + server.add_socket(sock) + client = IOStream(socket.socket()) + yield client.connect(('localhost', port)) + result = yield client.read_until_close() + self.assertEqual(result, b'data') + server.stop() + client.close()