From 7c40bedd82e804c3e25c051bf3b3bd2526909e49 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 14 Nov 2016 19:38:25 +0100 Subject: [PATCH] Fix IOStream.write() to never orphan Future The current behaviour is error-prone and makes it difficult to use the Future-returning variant of IOStream.write(). This change makes sure the returned Future is triggered as soon as the corresponding write is issued. --- tornado/iostream.py | 28 ++++++++++++++++++---------- tornado/test/iostream_test.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/tornado/iostream.py b/tornado/iostream.py index 0746e1d51..0f9d62856 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -167,6 +167,8 @@ class BaseIOStream(object): self._write_buffer_pos = 0 self._write_buffer_size = 0 self._write_buffer_frozen = False + self._total_write_index = 0 + self._total_write_done_index = 0 self._pending_writes_while_frozen = [] self._read_delimiter = None self._read_regex = None @@ -178,7 +180,7 @@ class BaseIOStream(object): self._read_future = None self._streaming_callback = None self._write_callback = None - self._write_future = None + self._write_futures = collections.deque() self._close_callback = None self._connect_callback = None self._connect_future = None @@ -388,12 +390,14 @@ class BaseIOStream(object): else: self._write_buffer += data self._write_buffer_size += len(data) + self._total_write_index += len(data) if callback is not None: self._write_callback = stack_context.wrap(callback) future = None else: - future = self._write_future = TracebackFuture() + future = TracebackFuture() future.add_done_callback(lambda f: f.exception()) + self._write_futures.append((self._total_write_index, future)) if not self._connecting: self._handle_write() if self._write_buffer_size: @@ -445,9 +449,8 @@ class BaseIOStream(object): if self._read_future is not None: futures.append(self._read_future) self._read_future = None - if self._write_future is not None: - futures.append(self._write_future) - self._write_future = None + futures += [future for _, future in self._write_futures] + self._write_futures.clear() if self._connect_future is not None: futures.append(self._connect_future) self._connect_future = None @@ -866,6 +869,7 @@ class BaseIOStream(object): self._write_buffer_pos = 0 if self._write_buffer_frozen: self._unfreeze_write_buffer() + self._total_write_done_index += num_bytes except (socket.error, IOError, OSError) as e: if e.args[0] in _ERRNO_WOULDBLOCK: self._got_empty_write(size) @@ -879,15 +883,19 @@ class BaseIOStream(object): self.fileno(), e) self.close(exc_info=True) return + + while self._write_futures: + index, future = self._write_futures[0] + if index > self._total_write_done_index: + break + self._write_futures.popleft() + future.set_result(None) + if not self._write_buffer_size: if self._write_callback: callback = self._write_callback self._write_callback = None self._run_callback(callback) - if self._write_future: - future = self._write_future - self._write_future = None - future.set_result(None) def _consume(self, loc): # Consume loc bytes from the read buffer and return them @@ -1152,7 +1160,7 @@ class IOStream(BaseIOStream): suitably-configured `ssl.SSLContext` to disable. """ if (self._read_callback or self._read_future or - self._write_callback or self._write_future or + self._write_callback or self._write_futures or self._connect_callback or self._connect_future or self._pending_callbacks or self._closed or self._read_buffer or self._write_buffer): diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index f62b0f85f..91bc7bf6a 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -808,6 +808,40 @@ class TestIOStreamMixin(object): server.close() client.close() + def test_future_write(self): + """ + Test that write() Futures are never orphaned. + """ + # Run concurrent writers that will write enough bytes so as to + # clog the socket buffer and accumulate bytes in our write buffer. + m, n = 10000, 1000 + nproducers = 10 + total_bytes = m * n * nproducers + server, client = self.make_iostream_pair(max_buffer_size=total_bytes) + + @gen.coroutine + def produce(): + data = b'x' * m + for i in range(n): + yield server.write(data) + + @gen.coroutine + def consume(): + nread = 0 + while nread < total_bytes: + res = yield client.read_bytes(m) + nread += len(res) + + @gen.coroutine + def main(): + yield [produce() for i in range(nproducers)] + [consume()] + + try: + self.io_loop.run_sync(main) + finally: + server.close() + client.close() + class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): def _make_client_iostream(self): -- 2.47.2