]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Fix IOStream.write() to never orphan Future 1987/head
authorAntoine Pitrou <antoine@python.org>
Mon, 14 Nov 2016 18:38:25 +0000 (19:38 +0100)
committerBen Darnell <ben@bendarnell.com>
Sun, 26 Mar 2017 16:25:19 +0000 (12:25 -0400)
The current behaviour is error-prone and makes it difficult to use the
Future-returning variant of IOStream.write().  This change makes sure
the returned Future is triggered as soon as the corresponding write
is issued.

tornado/iostream.py
tornado/test/iostream_test.py

index 0746e1d51e9063d7849ee0dd0fc4bf0d1889f635..0f9d62856a341487d885c550d7106414c62e5996 100644 (file)
@@ -167,6 +167,8 @@ class BaseIOStream(object):
         self._write_buffer_pos = 0
         self._write_buffer_size = 0
         self._write_buffer_frozen = False
+        self._total_write_index = 0
+        self._total_write_done_index = 0
         self._pending_writes_while_frozen = []
         self._read_delimiter = None
         self._read_regex = None
@@ -178,7 +180,7 @@ class BaseIOStream(object):
         self._read_future = None
         self._streaming_callback = None
         self._write_callback = None
-        self._write_future = None
+        self._write_futures = collections.deque()
         self._close_callback = None
         self._connect_callback = None
         self._connect_future = None
@@ -388,12 +390,14 @@ class BaseIOStream(object):
             else:
                 self._write_buffer += data
                 self._write_buffer_size += len(data)
+            self._total_write_index += len(data)
         if callback is not None:
             self._write_callback = stack_context.wrap(callback)
             future = None
         else:
-            future = self._write_future = TracebackFuture()
+            future = TracebackFuture()
             future.add_done_callback(lambda f: f.exception())
+            self._write_futures.append((self._total_write_index, future))
         if not self._connecting:
             self._handle_write()
             if self._write_buffer_size:
@@ -445,9 +449,8 @@ class BaseIOStream(object):
             if self._read_future is not None:
                 futures.append(self._read_future)
                 self._read_future = None
-            if self._write_future is not None:
-                futures.append(self._write_future)
-                self._write_future = None
+            futures += [future for _, future in self._write_futures]
+            self._write_futures.clear()
             if self._connect_future is not None:
                 futures.append(self._connect_future)
                 self._connect_future = None
@@ -866,6 +869,7 @@ class BaseIOStream(object):
                     self._write_buffer_pos = 0
                 if self._write_buffer_frozen:
                     self._unfreeze_write_buffer()
+                self._total_write_done_index += num_bytes
             except (socket.error, IOError, OSError) as e:
                 if e.args[0] in _ERRNO_WOULDBLOCK:
                     self._got_empty_write(size)
@@ -879,15 +883,19 @@ class BaseIOStream(object):
                                         self.fileno(), e)
                     self.close(exc_info=True)
                     return
+
+        while self._write_futures:
+            index, future = self._write_futures[0]
+            if index > self._total_write_done_index:
+                break
+            self._write_futures.popleft()
+            future.set_result(None)
+
         if not self._write_buffer_size:
             if self._write_callback:
                 callback = self._write_callback
                 self._write_callback = None
                 self._run_callback(callback)
-            if self._write_future:
-                future = self._write_future
-                self._write_future = None
-                future.set_result(None)
 
     def _consume(self, loc):
         # Consume loc bytes from the read buffer and return them
@@ -1152,7 +1160,7 @@ class IOStream(BaseIOStream):
            suitably-configured `ssl.SSLContext` to disable.
         """
         if (self._read_callback or self._read_future or
-                self._write_callback or self._write_future or
+                self._write_callback or self._write_futures or
                 self._connect_callback or self._connect_future or
                 self._pending_callbacks or self._closed or
                 self._read_buffer or self._write_buffer):
index f62b0f85f3dc86216b97a09c7f88c9d06d3495ae..91bc7bf6add6bcd948834a86053ca225ed918e97 100644 (file)
@@ -808,6 +808,40 @@ class TestIOStreamMixin(object):
             server.close()
             client.close()
 
+    def test_future_write(self):
+        """
+        Test that write() Futures are never orphaned.
+        """
+        # Run concurrent writers that will write enough bytes so as to
+        # clog the socket buffer and accumulate bytes in our write buffer.
+        m, n = 10000, 1000
+        nproducers = 10
+        total_bytes = m * n * nproducers
+        server, client = self.make_iostream_pair(max_buffer_size=total_bytes)
+
+        @gen.coroutine
+        def produce():
+            data = b'x' * m
+            for i in range(n):
+                yield server.write(data)
+
+        @gen.coroutine
+        def consume():
+            nread = 0
+            while nread < total_bytes:
+                res = yield client.read_bytes(m)
+                nread += len(res)
+
+        @gen.coroutine
+        def main():
+            yield [produce() for i in range(nproducers)] + [consume()]
+
+        try:
+            self.io_loop.run_sync(main)
+        finally:
+            server.close()
+            client.close()
+
 
 class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase):
     def _make_client_iostream(self):