git.ipfire.org Git - thirdparty/tornado.git/commitdiff
iostream: Add read_into method. 2251/head
authorBen Darnell <ben@bendarnell.com>
Sun, 14 Jan 2018 20:45:48 +0000 (15:45 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 21 Jan 2018 03:19:20 +0000 (22:19 -0500)
Tests come from Antoine Pitrou's #2193

Fixes #2176

tornado/iostream.py
tornado/test/iostream_test.py

index 56b4002cba877ece34e4262b84e7b092f0018a21..ad50be552f8c8eec8cb4e43df85e091d4ba5e3aa 100644 (file)
@@ -254,6 +254,8 @@ class BaseIOStream(object):
         self._read_buffer = bytearray()
         self._read_buffer_pos = 0
         self._read_buffer_size = 0
+        self._user_read_buffer = False
+        self._after_user_read_buffer = None
         self._write_buffer = _StreamBuffer()
         self._total_write_index = 0
         self._total_write_done_index = 0
@@ -420,6 +422,50 @@ class BaseIOStream(object):
             raise
         return future
 
+    def read_into(self, buf, callback=None, partial=False):
+        """Asynchronously read a number of bytes.
+
+        ``buf`` must be a writable buffer into which data will be read.
+        If a callback is given, it will be run with the number of read
+        bytes as an argument; if not, this method returns a `.Future`.
+
+        If ``partial`` is true, the callback is run as soon as any bytes
+        have been read.  Otherwise, it is run when the ``buf`` has been
+        entirely filled with read data.
+
+        .. versionadded:: 5.0
+        """
+        future = self._set_read_callback(callback)
+
+        # First copy data already in read buffer
+        available_bytes = self._read_buffer_size
+        n = len(buf)
+        if available_bytes >= n:
+            end = self._read_buffer_pos + n
+            buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end]
+            del self._read_buffer[:end]
+            self._after_user_read_buffer = self._read_buffer
+        elif available_bytes > 0:
+            buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:]
+
+        # Set up the supplied buffer as our temporary read buffer.
+        # The original (if it had any data remaining) has been
+        # saved for later.
+        self._user_read_buffer = True
+        self._read_buffer = buf
+        self._read_buffer_pos = 0
+        self._read_buffer_size = available_bytes
+        self._read_bytes = n
+        self._read_partial = partial
+
+        try:
+            self._try_inline_read()
+        except:
+            if future is not None:
+                future.add_done_callback(lambda f: f.exception())
+            raise
+        return future
+
     def read_until_close(self, callback=None, streaming_callback=None):
         """Asynchronously reads all data from the socket until it is closed.
 
@@ -767,6 +813,15 @@ class BaseIOStream(object):
         return self._read_future
 
     def _run_read_callback(self, size, streaming):
+        if self._user_read_buffer:
+            self._read_buffer = self._after_user_read_buffer or bytearray()
+            self._after_user_read_buffer = None
+            self._read_buffer_pos = 0
+            self._read_buffer_size = len(self._read_buffer)
+            self._user_read_buffer = False
+            result = size
+        else:
+            result = self._consume(size)
         if streaming:
             callback = self._streaming_callback
         else:
@@ -776,10 +831,11 @@ class BaseIOStream(object):
                 assert callback is None
                 future = self._read_future
                 self._read_future = None
-                future.set_result(self._consume(size))
+
+                future.set_result(result)
         if callback is not None:
             assert (self._read_future is None) or streaming
-            self._run_callback(callback, self._consume(size))
+            self._run_callback(callback, result)
         else:
             # If we scheduled a callback, we will add the error listener
             # afterwards.  If we didn't, we have to do it now.
@@ -828,7 +884,10 @@ class BaseIOStream(object):
         try:
             while True:
                 try:
-                    buf = bytearray(self.read_chunk_size)
+                    if self._user_read_buffer:
+                        buf = memoryview(self._read_buffer)[self._read_buffer_size:]
+                    else:
+                        buf = bytearray(self.read_chunk_size)
                     bytes_read = self.read_from_fd(buf)
                 except (socket.error, IOError, OSError) as e:
                     if errno_from_exception(e) == errno.EINTR:
@@ -848,7 +907,8 @@ class BaseIOStream(object):
             elif bytes_read == 0:
                 self.close()
                 return 0
-            self._read_buffer += memoryview(buf)[:bytes_read]
+            if not self._user_read_buffer:
+                self._read_buffer += memoryview(buf)[:bytes_read]
             self._read_buffer_size += bytes_read
         finally:
             # Break the reference to buf so we don't waste a chunk's worth of
index 1bfb0b34cb689fa194303ddf7214dd2d58c8a8c5..45799db2ff7e1ddfbe9511c888b4d60c82f1d024 100644 (file)
@@ -12,6 +12,7 @@ from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase
 from tornado.test.util import unittest, skipIfNonUnix, refusing_port, skipPypy3V58
 from tornado.web import RequestHandler, Application
 import errno
+import hashlib
 import logging
 import os
 import platform
@@ -674,6 +675,143 @@ class TestReadWriteMixin(object):
             rs.close()
             ws.close()
 
+    def test_read_into(self):
+        rs, ws = self.make_iostream_pair()
+
+        def sleep_some():
+            self.io_loop.run_sync(lambda: gen.sleep(0.05))
+        try:
+            buf = bytearray(10)
+            rs.read_into(buf, callback=self.stop)
+            ws.write(b"hello")
+            sleep_some()
+            self.assertTrue(rs.reading())
+            ws.write(b"world!!")
+            data = self.wait()
+            self.assertFalse(rs.reading())
+            self.assertEqual(data, 10)
+            self.assertEqual(bytes(buf), b"helloworld")
+
+            # Existing buffer is fed into user buffer
+            rs.read_into(buf, callback=self.stop)
+            sleep_some()
+            self.assertTrue(rs.reading())
+            ws.write(b"1234567890")
+            data = self.wait()
+            self.assertFalse(rs.reading())
+            self.assertEqual(data, 10)
+            self.assertEqual(bytes(buf), b"!!12345678")
+
+            # Existing buffer can satisfy read immediately
+            buf = bytearray(4)
+            ws.write(b"abcdefghi")
+            rs.read_into(buf, callback=self.stop)
+            data = self.wait()
+            self.assertEqual(data, 4)
+            self.assertEqual(bytes(buf), b"90ab")
+
+            rs.read_bytes(7, self.stop)
+            data = self.wait()
+            self.assertEqual(data, b"cdefghi")
+        finally:
+            ws.close()
+            rs.close()
+
+    def test_read_into_partial(self):
+        rs, ws = self.make_iostream_pair()
+
+        def sleep_some():
+            self.io_loop.run_sync(lambda: gen.sleep(0.05))
+        try:
+            # Partial read
+            buf = bytearray(10)
+            rs.read_into(buf, callback=self.stop, partial=True)
+            ws.write(b"hello")
+            data = self.wait()
+            self.assertFalse(rs.reading())
+            self.assertEqual(data, 5)
+            self.assertEqual(bytes(buf), b"hello\0\0\0\0\0")
+
+            # Full read despite partial=True
+            ws.write(b"world!1234567890")
+            rs.read_into(buf, callback=self.stop, partial=True)
+            data = self.wait()
+            self.assertEqual(data, 10)
+            self.assertEqual(bytes(buf), b"world!1234")
+
+            # Existing buffer can satisfy read immediately
+            rs.read_into(buf, callback=self.stop, partial=True)
+            data = self.wait()
+            self.assertEqual(data, 6)
+            self.assertEqual(bytes(buf), b"5678901234")
+
+        finally:
+            ws.close()
+            rs.close()
+
+    def test_read_into_zero_bytes(self):
+        rs, ws = self.make_iostream_pair()
+        try:
+            buf = bytearray()
+            fut = rs.read_into(buf)
+            self.assertEqual(fut.result(), 0)
+        finally:
+            ws.close()
+            rs.close()
+
+    def test_many_mixed_reads(self):
+        # Stress buffer handling when going back and forth between
+        # read_bytes() (using an internal buffer) and read_into()
+        # (using a user-allocated buffer).
+        r = random.Random(42)
+        nbytes = 1000000
+        rs, ws = self.make_iostream_pair()
+
+        produce_hash = hashlib.sha1()
+        consume_hash = hashlib.sha1()
+
+        @gen.coroutine
+        def produce():
+            remaining = nbytes
+            while remaining > 0:
+                size = r.randint(1, min(1000, remaining))
+                data = os.urandom(size)
+                produce_hash.update(data)
+                yield ws.write(data)
+                remaining -= size
+            assert remaining == 0
+
+        @gen.coroutine
+        def consume():
+            remaining = nbytes
+            while remaining > 0:
+                if r.random() > 0.5:
+                    # read_bytes()
+                    size = r.randint(1, min(1000, remaining))
+                    data = yield rs.read_bytes(size)
+                    consume_hash.update(data)
+                    remaining -= size
+                else:
+                    # read_into()
+                    size = r.randint(1, min(1000, remaining))
+                    buf = bytearray(size)
+                    n = yield rs.read_into(buf)
+                    assert n == size
+                    consume_hash.update(buf)
+                    remaining -= size
+            assert remaining == 0
+
+        @gen.coroutine
+        def main():
+            yield [produce(), consume()]
+            assert produce_hash.hexdigest() == consume_hash.hexdigest()
+
+        try:
+            self.io_loop.run_sync(main)
+        finally:
+            ws.close()
+            rs.close()
+
 
 class TestIOStreamMixin(TestReadWriteMixin):
     def _make_server_iostream(self, connection, **kwargs):