From 66a647d2438341e33353e87c42c49f57df931d9f Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 19 Jan 2014 22:59:37 -0500 Subject: [PATCH] Return a Future from IOStream methods. This makes it easier to use IOStreams directly from coroutines. Closes #953. --- tornado/iostream.py | 51 ++++++++++++++++++++++++++-- tornado/test/iostream_test.py | 64 +++++++++++++++++++++++++++++++++-- tornado/test/util_test.py | 16 ++++++--- tornado/util.py | 10 ++++++ 4 files changed, 133 insertions(+), 8 deletions(-) diff --git a/tornado/iostream.py b/tornado/iostream.py index 1554ddf69..6baea4a13 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -28,6 +28,7 @@ from __future__ import absolute_import, division, print_function, with_statement import collections import errno +import functools import numbers import os import socket @@ -35,11 +36,12 @@ import ssl import sys import re +from tornado.concurrent import TracebackFuture from tornado import ioloop from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError from tornado import stack_context -from tornado.util import bytes_type +from tornado.util import bytes_type, ArgReplacer try: from tornado.platform.posix import _set_nonblocking @@ -66,6 +68,37 @@ class StreamClosedError(IOError): pass +def _iostream_return_future(f): + """Similar to tornado.concurrent.return_future, but the Future will + also raise a StreamClosedError if the stream is closed before + it resolves. + + Unlike return_future (and _auth_return_future), no Future will be + returned if a callback is given. + """ + replacer = ArgReplacer(f, 'callback') + + @functools.wraps(f) + def wrapper(*args, **kwargs): + if replacer.get_old_value(args, kwargs) is not None: + # If a callaback is present, just call in to the decorated + # function. This is a slight optimization (by not creating a + # Future that is unlikely to be used), but mainly avoids the + # complexity of running the callback in the expected way. + return f(*args, **kwargs) + future = TracebackFuture() + callback, args, kwargs = replacer.replace( + lambda value=None: future.set_result(value), + args, kwargs) + f(*args, **kwargs) + stream = args[0] + stream._pending_futures.add(future) + future.add_done_callback( + lambda fut: stream._pending_futures.discard(fut)) + return future + return wrapper + + class BaseIOStream(object): """A utility class to write to and read from a non-blocking file or socket. @@ -102,6 +135,7 @@ class BaseIOStream(object): self._state = None self._pending_callbacks = 0 self._closed = False + self._pending_futures = set() def fileno(self): """Returns the file descriptor for this stream.""" @@ -142,6 +176,7 @@ class BaseIOStream(object): """ return None + @_iostream_return_future def read_until_regex(self, regex, callback): """Run ``callback`` when we read the given regex pattern. @@ -152,6 +187,7 @@ class BaseIOStream(object): self._read_regex = re.compile(regex) self._try_inline_read() + @_iostream_return_future def read_until(self, delimiter, callback): """Run ``callback`` when we read the given delimiter. @@ -162,6 +198,7 @@ class BaseIOStream(object): self._read_delimiter = delimiter self._try_inline_read() + @_iostream_return_future def read_bytes(self, num_bytes, callback, streaming_callback=None): """Run callback when we read the given number of bytes. @@ -176,6 +213,7 @@ class BaseIOStream(object): self._streaming_callback = stack_context.wrap(streaming_callback) self._try_inline_read() + @_iostream_return_future def read_until_close(self, callback, streaming_callback=None): """Reads all data from the socket until it is closed. @@ -202,6 +240,7 @@ class BaseIOStream(object): self._streaming_callback = stack_context.wrap(streaming_callback) self._try_inline_read() + @_iostream_return_future def write(self, data, callback=None): """Write the given data to this stream. @@ -266,6 +305,10 @@ class BaseIOStream(object): # If there are pending callbacks, don't run the close callback # until they're done (see _maybe_add_error_handler) if self.closed() and self._pending_callbacks == 0: + # Copy the _pending_futures set because each will remove itself + # from the set as it is closed. + for fut in list(self._pending_futures): + fut.set_exception(StreamClosedError()) if self._close_callback is not None: cb = self._close_callback self._close_callback = None @@ -704,6 +747,7 @@ class IOStream(BaseIOStream): def write_to_fd(self, data): return self.socket.send(data) + @_iostream_return_future def connect(self, address, callback=None, server_hostname=None): """Connects the socket to a remote address without blocking. @@ -904,7 +948,10 @@ class SSLIOStream(IOStream): # has completed. self._ssl_connect_callback = stack_context.wrap(callback) self._server_hostname = server_hostname - super(SSLIOStream, self).connect(address, callback=None) + # Note: Since we don't pass our callback argument along to + # super.connect(), this will always return a Future. + # This is harmless, but a bit less efficient than it could be. + return super(SSLIOStream, self).connect(address, callback=None) def _handle_connect(self): # When the connection is complete, wrap the socket for SSL diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index 0675c4f7a..35a4bd857 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -1,11 +1,12 @@ from __future__ import absolute_import, division, print_function, with_statement from tornado import netutil from tornado.ioloop import IOLoop -from tornado.iostream import IOStream, SSLIOStream, PipeIOStream +from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError +from tornado.httputil import HTTPHeaders from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket from tornado.stack_context import NullContext -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog +from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test from tornado.test.util import unittest, skipIfNonUnix from tornado.web import RequestHandler, Application import errno @@ -106,6 +107,46 @@ class TestIOStreamWebMixin(object): stream.close() + @gen_test + def test_future_interface(self): + """Basic test of IOStream's ability to return Futures.""" + stream = self._make_client_iostream() + yield stream.connect(("localhost", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\n\r\n") + first_line = yield stream.read_until(b"\r\n") + self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n") + # callback=None is equivalent to no callback. + header_data = yield stream.read_until(b"\r\n\r\n", callback=None) + headers = HTTPHeaders.parse(header_data.decode('latin1')) + content_length = int(headers['Content-Length']) + body = yield stream.read_bytes(content_length) + self.assertEqual(body, b'Hello') + stream.close() + + @gen_test + def test_future_close_while_reading(self): + stream = self._make_client_iostream() + yield stream.connect(("localhost", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\n\r\n") + with self.assertRaises(StreamClosedError): + yield stream.read_bytes(1024 * 1024) + stream.close() + + @gen_test + def test_future_read_until_close(self): + # Ensure that the data comes through before the StreamClosedError. + stream = self._make_client_iostream() + yield stream.connect(("localhost", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n") + yield stream.read_until(b"\r\n\r\n") + body = yield stream.read_until_close() + self.assertEqual(body, b"Hello") + + # Nothing else to read; the error comes immediately without waiting + # for yield. + with self.assertRaises(StreamClosedError): + stream.read_bytes(1) + class TestIOStreamMixin(object): def _make_server_iostream(self, connection, **kwargs): @@ -298,6 +339,25 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_future_delayed_close_callback(self): + # Same as test_delayed_close_callback, but with the future interface. + server, client = self.make_iostream_pair() + # We can't call make_iostream_pair inside a gen_test function + # because the ioloop is not reentrant. + @gen_test + def f(self): + server.write(b"12") + chunks = [] + chunks.append((yield client.read_bytes(1))) + server.close() + chunks.append((yield client.read_bytes(1))) + self.assertEqual(chunks, [b"1", b"2"]) + try: + f(self) + finally: + server.close() + client.close() + def test_close_buffered_data(self): # Similar to the previous test, but with data stored in the OS's # socket buffers instead of the IOStream's read buffer. Out-of-band diff --git a/tornado/test/util_test.py b/tornado/test/util_test.py index 5df54f5e5..41ccbb9a5 100644 --- a/tornado/test/util_test.py +++ b/tornado/test/util_test.py @@ -151,14 +151,22 @@ class ArgReplacerTest(unittest.TestCase): self.replacer = ArgReplacer(function, 'callback') def test_omitted(self): - self.assertEqual(self.replacer.replace('new', (1, 2), dict()), + args = (1, 2) + kwargs = dict() + self.assertIs(self.replacer.get_old_value(args, kwargs), None) + self.assertEqual(self.replacer.replace('new', args, kwargs), (None, (1, 2), dict(callback='new'))) def test_position(self): - self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()), + args = (1, 2, 'old', 3) + kwargs = dict() + self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old') + self.assertEqual(self.replacer.replace('new', args, kwargs), ('old', [1, 2, 'new', 3], dict())) def test_keyword(self): - self.assertEqual(self.replacer.replace('new', (1,), - dict(y=2, callback='old', z=3)), + args = (1,) + kwargs = dict(y=2, callback='old', z=3) + self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old') + self.assertEqual(self.replacer.replace('new', args, kwargs), ('old', (1,), dict(y=2, callback='new', z=3))) diff --git a/tornado/util.py b/tornado/util.py index a2fba779c..cc5322229 100644 --- a/tornado/util.py +++ b/tornado/util.py @@ -243,6 +243,16 @@ class ArgReplacer(object): # Not a positional parameter self.arg_pos = None + def get_old_value(self, args, kwargs, default=None): + """Returns the old value of the named argument without replacing it. + + Returns ``default`` if the argument is not present. + """ + if self.arg_pos is not None and len(args) > self.arg_pos: + return args[self.arg_pos] + else: + return kwargs.get(self.name, default) + def replace(self, new_value, args, kwargs): """Replace the named argument in ``args, kwargs`` with ``new_value``. -- 2.47.2