self._read_buffer = bytearray()
self._read_buffer_pos = 0
self._read_buffer_size = 0
+ self._user_read_buffer = False
+ self._after_user_read_buffer = None
self._write_buffer = _StreamBuffer()
self._total_write_index = 0
self._total_write_done_index = 0
raise
return future
    def read_into(self, buf, callback=None, partial=False):
        """Asynchronously read a number of bytes.

        ``buf`` must be a writable buffer into which data will be read.
        If a callback is given, it will be run with the number of read
        bytes as an argument; if not, this method returns a `.Future`.

        If ``partial`` is true, the callback is run as soon as any bytes
        have been read. Otherwise, it is run when the ``buf`` has been
        entirely filled with read data.

        .. versionadded:: 5.0
        """
        future = self._set_read_callback(callback)

        # First copy data already in read buffer
        available_bytes = self._read_buffer_size
        n = len(buf)
        if available_bytes >= n:
            # The internal buffer alone can satisfy the read: copy the
            # first n unconsumed bytes into the caller's buffer, drop
            # them (together with any already-consumed prefix before
            # _read_buffer_pos) from the internal buffer, and stash the
            # remainder so _run_read_callback can restore it afterwards.
            end = self._read_buffer_pos + n
            buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end]
            del self._read_buffer[:end]
            self._after_user_read_buffer = self._read_buffer
        elif available_bytes > 0:
            # Only part of the request is buffered: move everything we
            # have into the front of the caller's buffer; the rest will
            # be read from the fd directly into ``buf``.
            buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:]

        # Set up the supplied buffer as our temporary read buffer.
        # The original (if it had any data remaining) has been
        # saved for later.
        self._user_read_buffer = True
        self._read_buffer = buf
        self._read_buffer_pos = 0
        # NOTE(review): when available_bytes > n this records a size
        # larger than len(buf); it appears harmless because the read
        # then completes inline before more data is appended — confirm
        # against _find_read_pos/_try_inline_read.
        self._read_buffer_size = available_bytes
        self._read_bytes = n
        self._read_partial = partial

        try:
            self._try_inline_read()
        except:
            if future is not None:
                # The exception will be re-raised to the caller here, so
                # mark it as retrieved on the Future to silence the
                # "exception was never retrieved" warning.
                future.add_done_callback(lambda f: f.exception())
            raise
        return future
+
def read_until_close(self, callback=None, streaming_callback=None):
"""Asynchronously reads all data from the socket until it is closed.
return self._read_future
def _run_read_callback(self, size, streaming):
+ if self._user_read_buffer:
+ self._read_buffer = self._after_user_read_buffer or bytearray()
+ self._after_user_read_buffer = None
+ self._read_buffer_pos = 0
+ self._read_buffer_size = len(self._read_buffer)
+ self._user_read_buffer = False
+ result = size
+ else:
+ result = self._consume(size)
if streaming:
callback = self._streaming_callback
else:
assert callback is None
future = self._read_future
self._read_future = None
- future.set_result(self._consume(size))
+
+ future.set_result(result)
if callback is not None:
assert (self._read_future is None) or streaming
- self._run_callback(callback, self._consume(size))
+ self._run_callback(callback, result)
else:
# If we scheduled a callback, we will add the error listener
# afterwards. If we didn't, we have to do it now.
try:
while True:
try:
- buf = bytearray(self.read_chunk_size)
+ if self._user_read_buffer:
+ buf = memoryview(self._read_buffer)[self._read_buffer_size:]
+ else:
+ buf = bytearray(self.read_chunk_size)
bytes_read = self.read_from_fd(buf)
except (socket.error, IOError, OSError) as e:
if errno_from_exception(e) == errno.EINTR:
elif bytes_read == 0:
self.close()
return 0
- self._read_buffer += memoryview(buf)[:bytes_read]
+ if not self._user_read_buffer:
+ self._read_buffer += memoryview(buf)[:bytes_read]
self._read_buffer_size += bytes_read
finally:
# Break the reference to buf so we don't waste a chunk's worth of
from tornado.test.util import unittest, skipIfNonUnix, refusing_port, skipPypy3V58
from tornado.web import RequestHandler, Application
import errno
+import hashlib
import logging
import os
import platform
rs.close()
ws.close()
+ def test_read_into(self):
+ rs, ws = self.make_iostream_pair()
+
+ def sleep_some():
+ self.io_loop.run_sync(lambda: gen.sleep(0.05))
+ try:
+ buf = bytearray(10)
+ rs.read_into(buf, callback=self.stop)
+ ws.write(b"hello")
+ sleep_some()
+ self.assertTrue(rs.reading())
+ ws.write(b"world!!")
+ data = self.wait()
+ self.assertFalse(rs.reading())
+ self.assertEqual(data, 10)
+ self.assertEqual(bytes(buf), b"helloworld")
+
+ # Existing buffer is fed into user buffer
+ rs.read_into(buf, callback=self.stop)
+ sleep_some()
+ self.assertTrue(rs.reading())
+ ws.write(b"1234567890")
+ data = self.wait()
+ self.assertFalse(rs.reading())
+ self.assertEqual(data, 10)
+ self.assertEqual(bytes(buf), b"!!12345678")
+
+ # Existing buffer can satisfy read immediately
+ buf = bytearray(4)
+ ws.write(b"abcdefghi")
+ rs.read_into(buf, callback=self.stop)
+ data = self.wait()
+ self.assertEqual(data, 4)
+ self.assertEqual(bytes(buf), b"90ab")
+
+ rs.read_bytes(7, self.stop)
+ data = self.wait()
+ self.assertEqual(data, b"cdefghi")
+ finally:
+ ws.close()
+ rs.close()
+
+ def test_read_into_partial(self):
+ rs, ws = self.make_iostream_pair()
+
+ def sleep_some():
+ self.io_loop.run_sync(lambda: gen.sleep(0.05))
+ try:
+ # Partial read
+ buf = bytearray(10)
+ rs.read_into(buf, callback=self.stop, partial=True)
+ ws.write(b"hello")
+ data = self.wait()
+ self.assertFalse(rs.reading())
+ self.assertEqual(data, 5)
+ self.assertEqual(bytes(buf), b"hello\0\0\0\0\0")
+
+ # Full read despite partial=True
+ ws.write(b"world!1234567890")
+ rs.read_into(buf, callback=self.stop, partial=True)
+ data = self.wait()
+ self.assertEqual(data, 10)
+ self.assertEqual(bytes(buf), b"world!1234")
+
+ # Existing buffer can satisfy read immediately
+ rs.read_into(buf, callback=self.stop, partial=True)
+ data = self.wait()
+ self.assertEqual(data, 6)
+ self.assertEqual(bytes(buf), b"5678901234")
+
+ finally:
+ ws.close()
+ rs.close()
+
+ def test_read_into_zero_bytes(self):
+ rs, ws = self.make_iostream_pair()
+ try:
+ buf = bytearray()
+ fut = rs.read_into(buf)
+ self.assertEqual(fut.result(), 0)
+ finally:
+ ws.close()
+ rs.close()
+
+ def test_many_mixed_reads(self):
+ # Stress buffer handling when going back and forth between
+ # read_bytes() (using an internal buffer) and read_into()
+ # (using a user-allocated buffer).
+ r = random.Random(42)
+ nbytes = 1000000
+ rs, ws = self.make_iostream_pair()
+
+ produce_hash = hashlib.sha1()
+ consume_hash = hashlib.sha1()
+
+ @gen.coroutine
+ def produce():
+ remaining = nbytes
+ while remaining > 0:
+ size = r.randint(1, min(1000, remaining))
+ data = os.urandom(size)
+ produce_hash.update(data)
+ yield ws.write(data)
+ remaining -= size
+ assert remaining == 0
+
+ @gen.coroutine
+ def consume():
+ remaining = nbytes
+ while remaining > 0:
+ if r.random() > 0.5:
+ # read_bytes()
+ size = r.randint(1, min(1000, remaining))
+ data = yield rs.read_bytes(size)
+ consume_hash.update(data)
+ remaining -= size
+ else:
+ # read_into()
+ size = r.randint(1, min(1000, remaining))
+ buf = bytearray(size)
+ n = yield rs.read_into(buf)
+ assert n == size
+ consume_hash.update(buf)
+ remaining -= size
+ assert remaining == 0
+
+ @gen.coroutine
+ def main():
+ yield [produce(), consume()]
+ assert produce_hash.hexdigest() == consume_hash.hexdigest()
+
+ try:
+ self.io_loop.run_sync(main)
+ finally:
+ ws.close()
+ rs.close()
+
class TestIOStreamMixin(TestReadWriteMixin):
def _make_server_iostream(self, connection, **kwargs):