]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Return a Future from IOStream methods.
authorBen Darnell <ben@bendarnell.com>
Mon, 20 Jan 2014 03:59:37 +0000 (22:59 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 20 Jan 2014 21:27:56 +0000 (16:27 -0500)
This makes it easier to use IOStreams directly from coroutines.

Closes #953.

tornado/iostream.py
tornado/test/iostream_test.py
tornado/test/util_test.py
tornado/util.py

index 1554ddf6950145faf4a0f777e00771331955f50c..6baea4a1364b06131afae529ddd54eda075bcf9f 100644 (file)
@@ -28,6 +28,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import collections
 import errno
+import functools
 import numbers
 import os
 import socket
@@ -35,11 +36,12 @@ import ssl
 import sys
 import re
 
+from tornado.concurrent import TracebackFuture
 from tornado import ioloop
 from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
 from tornado import stack_context
-from tornado.util import bytes_type
+from tornado.util import bytes_type, ArgReplacer
 
 try:
     from tornado.platform.posix import _set_nonblocking
@@ -66,6 +68,37 @@ class StreamClosedError(IOError):
     pass
 
 
+def _iostream_return_future(f):
+    """Similar to tornado.concurrent.return_future, but the Future will
+    also raise a StreamClosedError if the stream is closed before
+    it resolves.
+
+    Unlike return_future (and _auth_return_future), no Future will be
+    returned if a callback is given.
+    """
+    replacer = ArgReplacer(f, 'callback')
+
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        if replacer.get_old_value(args, kwargs) is not None:
+            # If a callaback is present, just call in to the decorated
+            # function.  This is a slight optimization (by not creating a
+            # Future that is unlikely to be used), but mainly avoids the
+            # complexity of running the callback in the expected way.
+            return f(*args, **kwargs)
+        future = TracebackFuture()
+        callback, args, kwargs = replacer.replace(
+            lambda value=None: future.set_result(value),
+            args, kwargs)
+        f(*args, **kwargs)
+        stream = args[0]
+        stream._pending_futures.add(future)
+        future.add_done_callback(
+            lambda fut: stream._pending_futures.discard(fut))
+        return future
+    return wrapper
+
+
 class BaseIOStream(object):
     """A utility class to write to and read from a non-blocking file or socket.
 
@@ -102,6 +135,7 @@ class BaseIOStream(object):
         self._state = None
         self._pending_callbacks = 0
         self._closed = False
+        self._pending_futures = set()
 
     def fileno(self):
         """Returns the file descriptor for this stream."""
@@ -142,6 +176,7 @@ class BaseIOStream(object):
         """
         return None
 
+    @_iostream_return_future
     def read_until_regex(self, regex, callback):
         """Run ``callback`` when we read the given regex pattern.
 
@@ -152,6 +187,7 @@ class BaseIOStream(object):
         self._read_regex = re.compile(regex)
         self._try_inline_read()
 
+    @_iostream_return_future
     def read_until(self, delimiter, callback):
         """Run ``callback`` when we read the given delimiter.
 
@@ -162,6 +198,7 @@ class BaseIOStream(object):
         self._read_delimiter = delimiter
         self._try_inline_read()
 
+    @_iostream_return_future
     def read_bytes(self, num_bytes, callback, streaming_callback=None):
         """Run callback when we read the given number of bytes.
 
@@ -176,6 +213,7 @@ class BaseIOStream(object):
         self._streaming_callback = stack_context.wrap(streaming_callback)
         self._try_inline_read()
 
+    @_iostream_return_future
     def read_until_close(self, callback, streaming_callback=None):
         """Reads all data from the socket until it is closed.
 
@@ -202,6 +240,7 @@ class BaseIOStream(object):
         self._streaming_callback = stack_context.wrap(streaming_callback)
         self._try_inline_read()
 
+    @_iostream_return_future
     def write(self, data, callback=None):
         """Write the given data to this stream.
 
@@ -266,6 +305,10 @@ class BaseIOStream(object):
         # If there are pending callbacks, don't run the close callback
         # until they're done (see _maybe_add_error_handler)
         if self.closed() and self._pending_callbacks == 0:
+            # Copy the _pending_futures set because each will remove itself
+            # from the set as it is closed.
+            for fut in list(self._pending_futures):
+                fut.set_exception(StreamClosedError())
             if self._close_callback is not None:
                 cb = self._close_callback
                 self._close_callback = None
@@ -704,6 +747,7 @@ class IOStream(BaseIOStream):
     def write_to_fd(self, data):
         return self.socket.send(data)
 
+    @_iostream_return_future
     def connect(self, address, callback=None, server_hostname=None):
         """Connects the socket to a remote address without blocking.
 
@@ -904,7 +948,10 @@ class SSLIOStream(IOStream):
         # has completed.
         self._ssl_connect_callback = stack_context.wrap(callback)
         self._server_hostname = server_hostname
-        super(SSLIOStream, self).connect(address, callback=None)
+        # Note: Since we don't pass our callback argument along to
+        # super.connect(), this will always return a Future.
+        # This is harmless, but a bit less efficient than it could be.
+        return super(SSLIOStream, self).connect(address, callback=None)
 
     def _handle_connect(self):
         # When the connection is complete, wrap the socket for SSL
index 0675c4f7a39c7c14d5cb0319de1b4727bb8fef45..35a4bd857dcce35813e6d2a281f772cf37adccc4 100644 (file)
@@ -1,11 +1,12 @@
 from __future__ import absolute_import, division, print_function, with_statement
 from tornado import netutil
 from tornado.ioloop import IOLoop
-from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
+from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
+from tornado.httputil import HTTPHeaders
 from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_wrap_socket
 from tornado.stack_context import NullContext
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
 from tornado.test.util import unittest, skipIfNonUnix
 from tornado.web import RequestHandler, Application
 import errno
@@ -106,6 +107,46 @@ class TestIOStreamWebMixin(object):
 
         stream.close()
 
+    @gen_test
+    def test_future_interface(self):
+        """Basic test of IOStream's ability to return Futures."""
+        stream = self._make_client_iostream()
+        yield stream.connect(("localhost", self.get_http_port()))
+        yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
+        first_line = yield stream.read_until(b"\r\n")
+        self.assertEqual(first_line, b"HTTP/1.0 200 OK\r\n")
+        # callback=None is equivalent to no callback.
+        header_data = yield stream.read_until(b"\r\n\r\n", callback=None)
+        headers = HTTPHeaders.parse(header_data.decode('latin1'))
+        content_length = int(headers['Content-Length'])
+        body = yield stream.read_bytes(content_length)
+        self.assertEqual(body, b'Hello')
+        stream.close()
+
+    @gen_test
+    def test_future_close_while_reading(self):
+        stream = self._make_client_iostream()
+        yield stream.connect(("localhost", self.get_http_port()))
+        yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
+        with self.assertRaises(StreamClosedError):
+            yield stream.read_bytes(1024 * 1024)
+        stream.close()
+
+    @gen_test
+    def test_future_read_until_close(self):
+        # Ensure that the data comes through before the StreamClosedError.
+        stream = self._make_client_iostream()
+        yield stream.connect(("localhost", self.get_http_port()))
+        yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n")
+        yield stream.read_until(b"\r\n\r\n")
+        body = yield stream.read_until_close()
+        self.assertEqual(body, b"Hello")
+
+        # Nothing else to read; the error comes immediately without waiting
+        # for yield.
+        with self.assertRaises(StreamClosedError):
+            stream.read_bytes(1)
+
 
 class TestIOStreamMixin(object):
     def _make_server_iostream(self, connection, **kwargs):
@@ -298,6 +339,25 @@ class TestIOStreamMixin(object):
             server.close()
             client.close()
 
+    def test_future_delayed_close_callback(self):
+        # Same as test_delayed_close_callback, but with the future interface.
+        server, client = self.make_iostream_pair()
+        # We can't call make_iostream_pair inside a gen_test function
+        # because the ioloop is not reentrant.
+        @gen_test
+        def f(self):
+            server.write(b"12")
+            chunks = []
+            chunks.append((yield client.read_bytes(1)))
+            server.close()
+            chunks.append((yield client.read_bytes(1)))
+            self.assertEqual(chunks, [b"1", b"2"])
+        try:
+            f(self)
+        finally:
+            server.close()
+            client.close()
+
     def test_close_buffered_data(self):
         # Similar to the previous test, but with data stored in the OS's
         # socket buffers instead of the IOStream's read buffer.  Out-of-band
index 5df54f5e53c0bf3a4e47bf257a99a0a12cf8b941..41ccbb9a584018294988bcdd09c26b583f0a2ea3 100644 (file)
@@ -151,14 +151,22 @@ class ArgReplacerTest(unittest.TestCase):
         self.replacer = ArgReplacer(function, 'callback')
 
     def test_omitted(self):
-        self.assertEqual(self.replacer.replace('new', (1, 2), dict()),
+        args = (1, 2)
+        kwargs = dict()
+        self.assertIs(self.replacer.get_old_value(args, kwargs), None)
+        self.assertEqual(self.replacer.replace('new', args, kwargs),
                          (None, (1, 2), dict(callback='new')))
 
     def test_position(self):
-        self.assertEqual(self.replacer.replace('new', (1, 2, 'old', 3), dict()),
+        args = (1, 2, 'old', 3)
+        kwargs = dict()
+        self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
+        self.assertEqual(self.replacer.replace('new', args, kwargs),
                          ('old', [1, 2, 'new', 3], dict()))
 
     def test_keyword(self):
-        self.assertEqual(self.replacer.replace('new', (1,),
-                                               dict(y=2, callback='old', z=3)),
+        args = (1,)
+        kwargs = dict(y=2, callback='old', z=3)
+        self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
+        self.assertEqual(self.replacer.replace('new', args, kwargs),
                          ('old', (1,), dict(y=2, callback='new', z=3)))
index a2fba779ca4880683255d711a2432084d3506e4c..cc53222296e782d63981e5dc93f05b52bee23749 100644 (file)
@@ -243,6 +243,16 @@ class ArgReplacer(object):
             # Not a positional parameter
             self.arg_pos = None
 
+    def get_old_value(self, args, kwargs, default=None):
+        """Returns the old value of the named argument without replacing it.
+
+        Returns ``default`` if the argument is not present.
+        """
+        if self.arg_pos is not None and len(args) > self.arg_pos:
+            return args[self.arg_pos]
+        else:
+            return kwargs.get(self.name, default)
+
     def replace(self, new_value, args, kwargs):
         """Replace the named argument in ``args, kwargs`` with ``new_value``.