This makes it easier to use IOStreams directly from coroutines.
Closes #953.
import collections
import errno
+import functools
import numbers
import os
import socket
import sys
import re
+from tornado.concurrent import TracebackFuture
from tornado import ioloop
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
from tornado import stack_context
-from tornado.util import bytes_type
+from tornado.util import bytes_type, ArgReplacer
try:
from tornado.platform.posix import _set_nonblocking
pass
+def _iostream_return_future(f):
+ """Similar to tornado.concurrent.return_future, but the Future will
+ also raise a StreamClosedError if the stream is closed before
+ it resolves.
+
+ Unlike return_future (and _auth_return_future), no Future will be
+ returned if a callback is given.
+ """
+ replacer = ArgReplacer(f, 'callback')
+
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ if replacer.get_old_value(args, kwargs) is not None:
+ # If a callaback is present, just call in to the decorated
+ # function. This is a slight optimization (by not creating a
+ # Future that is unlikely to be used), but mainly avoids the
+ # complexity of running the callback in the expected way.
+ return f(*args, **kwargs)
+ future = TracebackFuture()
+ callback, args, kwargs = replacer.replace(
+ lambda value=None: future.set_result(value),
+ args, kwargs)
+ f(*args, **kwargs)
+ stream = args[0]
+ stream._pending_futures.add(future)
+ future.add_done_callback(
+ lambda fut: stream._pending_futures.discard(fut))
+ return future
+ return wrapper
+
+
class BaseIOStream(object):
"""A utility class to write to and read from a non-blocking file or socket.
self._state = None
self._pending_callbacks = 0
self._closed = False
+ self._pending_futures = set()
def fileno(self):
"""Returns the file descriptor for this stream."""
"""
return None
+ @_iostream_return_future
def read_until_regex(self, regex, callback):
"""Run ``callback`` when we read the given regex pattern.
self._read_regex = re.compile(regex)
self._try_inline_read()
+ @_iostream_return_future
def read_until(self, delimiter, callback):
"""Run ``callback`` when we read the given delimiter.
self._read_delimiter = delimiter
self._try_inline_read()
+ @_iostream_return_future
def read_bytes(self, num_bytes, callback, streaming_callback=None):
"""Run callback when we read the given number of bytes.
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()
+ @_iostream_return_future
def read_until_close(self, callback, streaming_callback=None):
"""Reads all data from the socket until it is closed.
self._streaming_callback = stack_context.wrap(streaming_callback)
self._try_inline_read()
+ @_iostream_return_future
def write(self, data, callback=None):
"""Write the given data to this stream.
# If there are pending callbacks, don't run the close callback
# until they're done (see _maybe_add_error_handler)
if self.closed() and self._pending_callbacks == 0:
+ # Copy the _pending_futures set because each will remove itself
+ # from the set as it is closed.
+ for fut in list(self._pending_futures):
+ fut.set_exception(StreamClosedError())
if self._close_callback is not None:
cb = self._close_callback
self._close_callback = None
def write_to_fd(self, data):
return self.socket.send(data)
+ @_iostream_return_future
def connect(self, address, callback=None, server_hostname=None):
"""Connects the socket to a remote address without blocking.
# has completed.
self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
- super(SSLIOStream, self).connect(address, callback=None)
+ # Note: Since we don't pass our callback argument along to
+ # super.connect(), this will always return a Future.
+ # This is harmless, but a bit less efficient than it could be.
+ return super(SSLIOStream, self).connect(address, callback=None)
def _handle_connect(self):
# When the connection is complete, wrap the socket for SSL
from __future__ import absolute_import, division, print_function, with_statement
from tornado import netutil
from tornado.ioloop import IOLoop
-from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
+from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
+from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
import errno
stream.close()
+ @gen_test
+ def test_future_interface(self):
+ """Basic test of IOStream's ability to return Futures."""
+ stream = self._make_client_iostream()
+ yield stream.connect(("localhost", self.get_http_port()))
+ yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
+ first_line = yield stream.read_until(b"\r\n")
+ self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
+ # callback=None is equivalent to no callback.
+ header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
+ headers = HTTPHeaders.parse(header_data.decode('latin1'))
+ content_length = int(headers['Content-Length'])
+ body = yield stream.read_bytes(content_length)
+ self.assertEqual(body, b'Hello')
+ stream.close()
+
+ @gen_test
+ def test_future_close_while_reading(self):
+ stream = self._make_client_iostream()
+ yield stream.connect(("localhost", self.get_http_port()))
+ yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
+ with self.assertRaises(StreamClosedError):
+ yield stream.read_bytes(1024 * 1024)
+ stream.close()
+
+ @gen_test
+ def test_future_read_until_close(self):
+ # Ensure that the data comes through before the StreamClosedError.
+ stream = self._make_client_iostream()
+ yield stream.connect(("localhost", self.get_http_port()))
+ yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
+ yield stream.read_until(b"\r\n\r\n")
+ body = yield stream.read_until_close()
+ self.assertEqual(body, b"Hello")
+
+ # Nothing else to read; the error comes immediately without waiting
+ # for yield.
+ with self.assertRaises(StreamClosedError):
+ stream.read_bytes(1)
+
class TestIOStreamMixin(object):
def _make_server_iostream(self, connection, **kwargs):
server.close()
client.close()
+ def test_future_delayed_close_callback(self):
+ # Same as test_delayed_close_callback, but with the future interface.
+ server, client = self.make_iostream_pair()
+ # We can't call make_iostream_pair inside a gen_test function
+ # because the ioloop is not reentrant.
+ @gen_test
+ def f(self):
+ server.write(b"12")
+ chunks = []
+ chunks.append((yield client.read_bytes(1)))
+ server.close()
+ chunks.append((yield client.read_bytes(1)))
+ self.assertEqual(chunks, [b"1", b"2"])
+ try:
+ f(self)
+ finally:
+ server.close()
+ client.close()
+
def test_close_buffered_data(self):
# Similar to the previous test, but with data stored in the OS's
# socket buffers instead of the IOStream's read buffer. Out-of-band
self.replacer = ArgReplacer(function, 'callback')
def test_omitted(self):
- self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
+ args = (1, 2)
+ kwargs = dict()
+ self.assertIs(self.replacer.get_old_value(args, kwargs), None)
+ self.assertEqual(self.replacer.replace('new', args, kwargs),
(None, (1, 2), dict(callback='new')))
def test_position(self):
- self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
+ args = (1, 2, 'old', 3)
+ kwargs = dict()
+ self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
+ self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', [1, 2, 'new', 3], dict()))
def test_keyword(self):
- self.assertEqual(self.replacer.replace('new', (1,),
- dict(y=2, callback='old', z=3)),
+ args = (1,)
+ kwargs = dict(y=2, callback='old', z=3)
+ self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
+ self.assertEqual(self.replacer.replace('new', args, kwargs),
('old', (1,), dict(y=2, callback='new', z=3)))
# Not a positional parameter
self.arg_pos = None
+ def get_old_value(self, args, kwargs, default=None):
+ """Returns the old value of the named argument without replacing it.
+
+ Returns ``default`` if the argument is not present.
+ """
+ if self.arg_pos is not None and len(args) > self.arg_pos:
+ return args[self.arg_pos]
+ else:
+ return kwargs.get(self.name, default)
+
def replace(self, new_value, args, kwargs):
"""Replace the named argument in ``args, kwargs`` with ``new_value``.