]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Remove _iostream_return_future from write and connect as well.
authorBen Darnell <ben@bendarnell.com>
Mon, 17 Feb 2014 03:05:55 +0000 (22:05 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 17 Feb 2014 03:05:55 +0000 (22:05 -0500)
tornado/iostream.py

index 11092adb7dff0c8c51af56a062d3627e10354e23..faf657a4c6e44a830b621f7897bff17b8d480212 100644 (file)
@@ -28,7 +28,6 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import collections
 import errno
-import functools
 import numbers
 import os
 import socket
@@ -41,7 +40,7 @@ from tornado import ioloop
 from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
 from tornado import stack_context
-from tornado.util import bytes_type, ArgReplacer
+from tornado.util import bytes_type
 
 try:
     from tornado.platform.posix import _set_nonblocking
@@ -68,37 +67,6 @@ class StreamClosedError(IOError):
     pass
 
 
-def _iostream_return_future(f):
-    """Similar to tornado.concurrent.return_future, but the Future will
-    also raise a StreamClosedError if the stream is closed before
-    it resolves.
-
-    Unlike return_future (and _auth_return_future), no Future will be
-    returned if a callback is given.
-    """
-    replacer = ArgReplacer(f, 'callback')
-
-    @functools.wraps(f)
-    def wrapper(*args, **kwargs):
-        if replacer.get_old_value(args, kwargs) is not None:
-            # If a callaback is present, just call in to the decorated
-            # function.  This is a slight optimization (by not creating a
-            # Future that is unlikely to be used), but mainly avoids the
-            # complexity of running the callback in the expected way.
-            return f(*args, **kwargs)
-        future = TracebackFuture()
-        callback, args, kwargs = replacer.replace(
-            lambda value=None: future.set_result(value),
-            args, kwargs)
-        f(*args, **kwargs)
-        stream = args[0]
-        stream._pending_futures.add(future)
-        future.add_done_callback(
-            lambda fut: stream._pending_futures.discard(fut))
-        return future
-    return wrapper
-
-
 class BaseIOStream(object):
     """A utility class to write to and read from a non-blocking file or socket.
 
@@ -130,13 +98,14 @@ class BaseIOStream(object):
         self._read_future = None
         self._streaming_callback = None
         self._write_callback = None
+        self._write_future = None
         self._close_callback = None
         self._connect_callback = None
+        self._connect_future = None
         self._connecting = False
         self._state = None
         self._pending_callbacks = 0
         self._closed = False
-        self._pending_futures = set()
 
     def fileno(self):
         """Returns the file descriptor for this stream."""
@@ -238,7 +207,6 @@ class BaseIOStream(object):
         self._try_inline_read()
         return future
 
-    @_iostream_return_future
     def write(self, data, callback=None):
         """Write the given data to this stream.
 
@@ -258,12 +226,17 @@ class BaseIOStream(object):
             WRITE_BUFFER_CHUNK_SIZE = 128 * 1024
             for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
                 self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
-        self._write_callback = stack_context.wrap(callback)
+        if callback is not None:
+            self._write_callback = stack_context.wrap(callback)
+            future = None
+        else:
+            future = self._write_future = TracebackFuture()
         if not self._connecting:
             self._handle_write()
             if self._write_buffer:
                 self._add_io_state(self.io_loop.WRITE)
             self._maybe_add_error_listener()
+        return future
 
     def set_close_callback(self, callback):
         """Call the given callback when the stream is closed."""
@@ -300,13 +273,18 @@ class BaseIOStream(object):
         # If there are pending callbacks, don't run the close callback
         # until they're done (see _maybe_add_error_handler)
         if self.closed() and self._pending_callbacks == 0:
-            # Copy the _pending_futures set because each will remove itself
-            # from the set as it is closed.
-            for fut in list(self._pending_futures):
-                fut.set_exception(StreamClosedError())
+            futures = []
             if self._read_future is not None:
-                self._read_future.set_exception(StreamClosedError())
+                futures.append(self._read_future)
                 self._read_future = None
+            if self._write_future is not None:
+                futures.append(self._write_future)
+                self._write_future = None
+            if self._connect_future is not None:
+                futures.append(self._connect_future)
+                self._connect_future = None
+            for future in futures:
+                future.set_exception(StreamClosedError())
             if self._close_callback is not None:
                 cb = self._close_callback
                 self._close_callback = None
@@ -464,14 +442,14 @@ class BaseIOStream(object):
     def _run_read_callback(self, data):
         self._streaming_callback = None
         if self._read_future is not None:
-            self._read_future.set_result(data)
+            future = self._read_future
             self._read_future = None
+            future.set_result(data)
         if self._read_callback is not None:
             callback = self._read_callback
             self._read_callback = None
             self._run_callback(callback, data)
 
-
     def _try_inline_read(self):
         """Attempt to complete the current read operation from buffered data.
 
@@ -621,10 +599,15 @@ class BaseIOStream(object):
                                         self.fileno(), e)
                     self.close(exc_info=True)
                     return
-        if not self._write_buffer and self._write_callback:
-            callback = self._write_callback
-            self._write_callback = None
-            self._run_callback(callback)
+        if not self._write_buffer:
+            if self._write_callback:
+                callback = self._write_callback
+                self._write_callback = None
+                self._run_callback(callback)
+            if self._write_future:
+                future = self._write_future
+                self._write_future = None
+                future.set_result(None)
 
     def _consume(self, loc):
         if loc == 0:
@@ -752,7 +735,6 @@ class IOStream(BaseIOStream):
     def write_to_fd(self, data):
         return self.socket.send(data)
 
-    @_iostream_return_future
     def connect(self, address, callback=None, server_hostname=None):
         """Connects the socket to a remote address without blocking.
 
@@ -790,8 +772,13 @@ class IOStream(BaseIOStream):
                                 self.socket.fileno(), e)
                 self.close(exc_info=True)
                 return
-        self._connect_callback = stack_context.wrap(callback)
+        if callback is not None:
+            self._connect_callback = stack_context.wrap(callback)
+            future = None
+        else:
+            future = self._connect_future = TracebackFuture()
         self._add_io_state(self.io_loop.WRITE)
+        return future
 
     def _handle_connect(self):
         err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
@@ -809,6 +796,10 @@ class IOStream(BaseIOStream):
             callback = self._connect_callback
             self._connect_callback = None
             self._run_callback(callback)
+        if self._connect_future is not None:
+            future = self._connect_future
+            self._connect_future = None
+            future.set_result(None)
         self._connecting = False
 
     def set_nodelay(self, value):