From: Ben Darnell
Date: Sun, 14 Jan 2018 20:45:48 +0000 (-0500)
Subject: iostream: Add read_into method.
X-Git-Tag: v5.0.0~14^2
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F2251%2Fhead;p=thirdparty%2Ftornado.git

iostream: Add read_into method.

Tests come from Antoine Pitrou's #2193

Fixes #2176
---

diff --git a/tornado/iostream.py b/tornado/iostream.py
index 56b4002cb..ad50be552 100644
--- a/tornado/iostream.py
+++ b/tornado/iostream.py
@@ -254,6 +254,8 @@ class BaseIOStream(object):
         self._read_buffer = bytearray()
         self._read_buffer_pos = 0
         self._read_buffer_size = 0
+        self._user_read_buffer = False
+        self._after_user_read_buffer = None
         self._write_buffer = _StreamBuffer()
         self._total_write_index = 0
         self._total_write_done_index = 0
@@ -420,6 +422,50 @@ class BaseIOStream(object):
             raise
         return future
 
+    def read_into(self, buf, callback=None, partial=False):
+        """Asynchronously read a number of bytes.
+
+        ``buf`` must be a writable buffer into which data will be read.
+        If a callback is given, it will be run with the number of read
+        bytes as an argument; if not, this method returns a `.Future`.
+
+        If ``partial`` is true, the callback is run as soon as any bytes
+        have been read.  Otherwise, it is run when the ``buf`` has been
+        entirely filled with read data.
+
+        .. versionadded:: 5.0
+        """
+        future = self._set_read_callback(callback)
+
+        # First copy data already in read buffer
+        available_bytes = self._read_buffer_size
+        n = len(buf)
+        if available_bytes >= n:
+            end = self._read_buffer_pos + n
+            buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end]
+            del self._read_buffer[:end]
+            self._after_user_read_buffer = self._read_buffer
+        elif available_bytes > 0:
+            buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:]
+
+        # Set up the supplied buffer as our temporary read buffer.
+        # The original (if it had any data remaining) has been
+        # saved for later.
+        self._user_read_buffer = True
+        self._read_buffer = buf
+        self._read_buffer_pos = 0
+        self._read_buffer_size = available_bytes
+        self._read_bytes = n
+        self._read_partial = partial
+
+        try:
+            self._try_inline_read()
+        except:
+            if future is not None:
+                future.add_done_callback(lambda f: f.exception())
+            raise
+        return future
+
     def read_until_close(self, callback=None, streaming_callback=None):
         """Asynchronously reads all data from the socket until it is closed.
 
@@ -767,6 +813,15 @@ class BaseIOStream(object):
         return self._read_future
 
     def _run_read_callback(self, size, streaming):
+        if self._user_read_buffer:
+            self._read_buffer = self._after_user_read_buffer or bytearray()
+            self._after_user_read_buffer = None
+            self._read_buffer_pos = 0
+            self._read_buffer_size = len(self._read_buffer)
+            self._user_read_buffer = False
+            result = size
+        else:
+            result = self._consume(size)
         if streaming:
             callback = self._streaming_callback
         else:
@@ -776,10 +831,11 @@
                 assert callback is None
                 future = self._read_future
                 self._read_future = None
-                future.set_result(self._consume(size))
+
+                future.set_result(result)
         if callback is not None:
             assert (self._read_future is None) or streaming
-            self._run_callback(callback, self._consume(size))
+            self._run_callback(callback, result)
         else:
             # If we scheduled a callback, we will add the error listener
             # afterwards. If we didn't, we have to do it now.
@@ -828,7 +884,10 @@ class BaseIOStream(object):
         try:
             while True:
                 try:
-                    buf = bytearray(self.read_chunk_size)
+                    if self._user_read_buffer:
+                        buf = memoryview(self._read_buffer)[self._read_buffer_size:]
+                    else:
+                        buf = bytearray(self.read_chunk_size)
                     bytes_read = self.read_from_fd(buf)
                 except (socket.error, IOError, OSError) as e:
                     if errno_from_exception(e) == errno.EINTR:
@@ -848,7 +907,8 @@ class BaseIOStream(object):
             elif bytes_read == 0:
                 self.close()
                 return 0
-            self._read_buffer += memoryview(buf)[:bytes_read]
+            if not self._user_read_buffer:
+                self._read_buffer += memoryview(buf)[:bytes_read]
             self._read_buffer_size += bytes_read
         finally:
             # Break the reference to buf so we don't waste a chunk's worth of
diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py
index 1bfb0b34c..45799db2f 100644
--- a/tornado/test/iostream_test.py
+++ b/tornado/test/iostream_test.py
@@ -12,6 +12,7 @@ from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase
 from tornado.test.util import unittest, skipIfNonUnix, refusing_port, skipPypy3V58
 from tornado.web import RequestHandler, Application
 import errno
+import hashlib
 import logging
 import os
 import platform
@@ -674,6 +675,143 @@ class TestReadWriteMixin(object):
         rs.close()
         ws.close()
 
+    def test_read_into(self):
+        rs, ws = self.make_iostream_pair()
+
+        def sleep_some():
+            self.io_loop.run_sync(lambda: gen.sleep(0.05))
+        try:
+            buf = bytearray(10)
+            rs.read_into(buf, callback=self.stop)
+            ws.write(b"hello")
+            sleep_some()
+            self.assertTrue(rs.reading())
+            ws.write(b"world!!")
+            data = self.wait()
+            self.assertFalse(rs.reading())
+            self.assertEqual(data, 10)
+            self.assertEqual(bytes(buf), b"helloworld")
+
+            # Existing buffer is fed into user buffer
+            rs.read_into(buf, callback=self.stop)
+            sleep_some()
+            self.assertTrue(rs.reading())
+            ws.write(b"1234567890")
+            data = self.wait()
+            self.assertFalse(rs.reading())
+            self.assertEqual(data, 10)
+            self.assertEqual(bytes(buf), b"!!12345678")
+
+            # Existing buffer can satisfy read immediately
+            buf = bytearray(4)
+            ws.write(b"abcdefghi")
+            rs.read_into(buf, callback=self.stop)
+            data = self.wait()
+            self.assertEqual(data, 4)
+            self.assertEqual(bytes(buf), b"90ab")
+
+            rs.read_bytes(7, self.stop)
+            data = self.wait()
+            self.assertEqual(data, b"cdefghi")
+        finally:
+            ws.close()
+            rs.close()
+
+    def test_read_into_partial(self):
+        rs, ws = self.make_iostream_pair()
+
+        def sleep_some():
+            self.io_loop.run_sync(lambda: gen.sleep(0.05))
+        try:
+            # Partial read
+            buf = bytearray(10)
+            rs.read_into(buf, callback=self.stop, partial=True)
+            ws.write(b"hello")
+            data = self.wait()
+            self.assertFalse(rs.reading())
+            self.assertEqual(data, 5)
+            self.assertEqual(bytes(buf), b"hello\0\0\0\0\0")
+
+            # Full read despite partial=True
+            ws.write(b"world!1234567890")
+            rs.read_into(buf, callback=self.stop, partial=True)
+            data = self.wait()
+            self.assertEqual(data, 10)
+            self.assertEqual(bytes(buf), b"world!1234")
+
+            # Existing buffer can satisfy read immediately
+            rs.read_into(buf, callback=self.stop, partial=True)
+            data = self.wait()
+            self.assertEqual(data, 6)
+            self.assertEqual(bytes(buf), b"5678901234")
+
+        finally:
+            ws.close()
+            rs.close()
+
+    def test_read_into_zero_bytes(self):
+        rs, ws = self.make_iostream_pair()
+        try:
+            buf = bytearray()
+            fut = rs.read_into(buf)
+            self.assertEqual(fut.result(), 0)
+        finally:
+            ws.close()
+            rs.close()
+
+    def test_many_mixed_reads(self):
+        # Stress buffer handling when going back and forth between
+        # read_bytes() (using an internal buffer) and read_into()
+        # (using a user-allocated buffer).
+        r = random.Random(42)
+        nbytes = 1000000
+        rs, ws = self.make_iostream_pair()
+
+        produce_hash = hashlib.sha1()
+        consume_hash = hashlib.sha1()
+
+        @gen.coroutine
+        def produce():
+            remaining = nbytes
+            while remaining > 0:
+                size = r.randint(1, min(1000, remaining))
+                data = os.urandom(size)
+                produce_hash.update(data)
+                yield ws.write(data)
+                remaining -= size
+            assert remaining == 0
+
+        @gen.coroutine
+        def consume():
+            remaining = nbytes
+            while remaining > 0:
+                if r.random() > 0.5:
+                    # read_bytes()
+                    size = r.randint(1, min(1000, remaining))
+                    data = yield rs.read_bytes(size)
+                    consume_hash.update(data)
+                    remaining -= size
+                else:
+                    # read_into()
+                    size = r.randint(1, min(1000, remaining))
+                    buf = bytearray(size)
+                    n = yield rs.read_into(buf)
+                    assert n == size
+                    consume_hash.update(buf)
+                    remaining -= size
+            assert remaining == 0
+
+        @gen.coroutine
+        def main():
+            yield [produce(), consume()]
+            assert produce_hash.hexdigest() == consume_hash.hexdigest()
+
+        try:
+            self.io_loop.run_sync(main)
+        finally:
+            ws.close()
+            rs.close()
+
 
 class TestIOStreamMixin(TestReadWriteMixin):
     def _make_server_iostream(self, connection, **kwargs):
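
Usage sketch for the new API (illustrative only, not part of the commit): the
snippet below reads a fixed-size header straight into a caller-owned bytearray
from a coroutine. The localhost:8888 endpoint and the 16-byte header size are
arbitrary assumptions for the example.

    from tornado import gen
    from tornado.ioloop import IOLoop
    from tornado.tcpclient import TCPClient

    @gen.coroutine
    def read_header():
        # Connect and read exactly len(header) bytes into a reusable buffer.
        stream = yield TCPClient().connect("localhost", 8888)
        try:
            header = bytearray(16)
            # With partial=False (the default), the future resolves only once
            # the buffer is completely filled; its result is the byte count.
            n = yield stream.read_into(header)
            assert n == len(header)
            raise gen.Return(bytes(header))
        finally:
            stream.close()

    if __name__ == "__main__":
        print(IOLoop.current().run_sync(read_header))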