]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Issue #2147: avoid copies on non-small writes (#2169)
authorAntoine Pitrou <pitrou@free.fr>
Thu, 16 Nov 2017 03:18:06 +0000 (04:18 +0100)
committerBen Darnell <ben@bendarnell.com>
Thu, 16 Nov 2017 03:18:06 +0000 (22:18 -0500)
tornado/iostream.py
tornado/test/iostream_test.py

index fa0a27736b0460c0fa3fbbbf6a156025a6e1a45f..bfd8dae789d61ffeba33adcf5d6cf8f614ee45ee 100644 (file)
@@ -117,6 +117,96 @@ class StreamBufferFullError(Exception):
     """
 
 
+class _StreamBuffer(object):
+    """
+    A specialized buffer that tries to avoid copies when large pieces
+    of data are encountered.
+    """
+
+    def __init__(self):
+        # A sequence of (False, bytearray) and (True, memoryview) objects
+        self._buffers = collections.deque()
+        # Position in the first buffer
+        self._first_pos = 0
+        self._size = 0
+
+    def __len__(self):
+        return self._size
+
+    # Data above this size will be appended separately instead
+    # of extending an existing bytearray
+    _large_buf_threshold = 2048
+
+    def append(self, data):
+        """
+        Append the given piece of data (should be a buffer-compatible object).
+        """
+        size = len(data)
+        if size > self._large_buf_threshold:
+            if not isinstance(data, memoryview):
+                data = memoryview(data)
+            self._buffers.append((True, data))
+        elif size > 0:
+            if self._buffers:
+                is_memview, b = self._buffers[-1]
+                new_buf = is_memview or len(b) >= self._large_buf_threshold
+            else:
+                new_buf = True
+            if new_buf:
+                self._buffers.append((False, bytearray(data)))
+            else:
+                b += data
+
+        self._size += size
+
+    def peek(self, size):
+        """
+        Get a view over at most ``size`` bytes (possibly fewer) at the
+        current buffer position.
+        """
+        assert size > 0
+        try:
+            is_memview, b = self._buffers[0]
+        except IndexError:
+            return memoryview(b'')
+
+        pos = self._first_pos
+        if is_memview:
+            return b[pos:pos + size]
+        else:
+            return memoryview(b)[pos:pos + size]
+
+    def advance(self, size):
+        """
+        Advance the current buffer position by ``size`` bytes.
+        """
+        assert 0 < size <= self._size
+        self._size -= size
+        pos = self._first_pos
+
+        buffers = self._buffers
+        while buffers and size > 0:
+            is_large, b = buffers[0]
+            b_remain = len(b) - size - pos
+            if b_remain <= 0:
+                buffers.popleft()
+                size -= len(b) - pos
+                pos = 0
+            elif is_large:
+                pos += size
+                size = 0
+            else:
+                # Amortized O(1) shrink for Python 2
+                pos += size
+                if len(b) <= 2 * pos:
+                    del b[:pos]
+                    pos = 0
+                size = 0
+
+        assert size == 0
+        self._first_pos = pos
+
+
 class BaseIOStream(object):
     """A utility class to write to and read from a non-blocking file or socket.
 
@@ -164,9 +254,7 @@ class BaseIOStream(object):
         self._read_buffer = bytearray()
         self._read_buffer_pos = 0
         self._read_buffer_size = 0
-        self._write_buffer = bytearray()
-        self._write_buffer_pos = 0
-        self._write_buffer_size = 0
+        self._write_buffer = _StreamBuffer()
         self._total_write_index = 0
         self._total_write_done_index = 0
         self._read_delimiter = None
@@ -386,10 +474,9 @@ class BaseIOStream(object):
         self._check_closed()
         if data:
             if (self.max_write_buffer_size is not None and
-                    self._write_buffer_size + len(data) > self.max_write_buffer_size):
+                    len(self._write_buffer) + len(data) > self.max_write_buffer_size):
                 raise StreamBufferFullError("Reached maximum write buffer size")
-            self._write_buffer += data
-            self._write_buffer_size += len(data)
+            self._write_buffer.append(data)
             self._total_write_index += len(data)
         if callback is not None:
             self._write_callback = stack_context.wrap(callback)
@@ -400,7 +487,7 @@ class BaseIOStream(object):
             self._write_futures.append((self._total_write_index, future))
         if not self._connecting:
             self._handle_write()
-            if self._write_buffer_size:
+            if self._write_buffer:
                 self._add_io_state(self.io_loop.WRITE)
             self._maybe_add_error_listener()
         return future
@@ -474,7 +561,6 @@ class BaseIOStream(object):
             # if the IOStream object is kept alive by a reference cycle.
             # TODO: Clear the read buffer too; it currently breaks some tests.
             self._write_buffer = None
-            self._write_buffer_size = 0
 
     def reading(self):
         """Returns true if we are currently reading from the stream."""
@@ -482,7 +568,7 @@ class BaseIOStream(object):
 
     def writing(self):
         """Returns true if we are currently writing to the stream."""
-        return self._write_buffer_size > 0
+        return bool(self._write_buffer)
 
     def closed(self):
         """Returns true if the stream has been closed."""
@@ -829,10 +915,12 @@ class BaseIOStream(object):
                     delimiter, self._read_max_bytes))
 
     def _handle_write(self):
-        while self._write_buffer_size:
-            assert self._write_buffer_size >= 0
+        while True:
+            size = len(self._write_buffer)
+            if not size:
+                break
+            assert size > 0
             try:
-                start = self._write_buffer_pos
                 if _WINDOWS:
                     # On windows, socket.send blows up if given a
                     # write buffer that's too large, instead of just
@@ -840,20 +928,11 @@ class BaseIOStream(object):
                     # process.  Therefore we must not call socket.send
                     # with more than 128KB at a time.
                     size = 128 * 1024
-                else:
-                    size = self._write_buffer_size
-                num_bytes = self.write_to_fd(
-                    memoryview(self._write_buffer)[start:start + size])
+
+                num_bytes = self.write_to_fd(self._write_buffer.peek(size))
                 if num_bytes == 0:
                     break
-                self._write_buffer_pos += num_bytes
-                self._write_buffer_size -= num_bytes
-                # Amortized O(1) shrink
-                # (this heuristic is implemented natively in Python 3.4+
-                #  but is replicated here for Python 2)
-                if self._write_buffer_pos > self._write_buffer_size:
-                    del self._write_buffer[:self._write_buffer_pos]
-                    self._write_buffer_pos = 0
+                self._write_buffer.advance(num_bytes)
                 self._total_write_done_index += num_bytes
             except (socket.error, IOError, OSError) as e:
                 if e.args[0] in _ERRNO_WOULDBLOCK:
@@ -875,7 +954,7 @@ class BaseIOStream(object):
             self._write_futures.popleft()
             future.set_result(None)
 
-        if not self._write_buffer_size:
+        if not len(self._write_buffer):
             if self._write_callback:
                 callback = self._write_callback
                 self._write_callback = None
index 4df23dd2bd5bd0a301ef541a61e3634000e8a680..9f3273ab56ef6b890e71fa8cfaf5609c3fa9a868 100644 (file)
@@ -2,7 +2,7 @@ from __future__ import absolute_import, division, print_function
 from tornado.concurrent import Future
 from tornado import gen
 from tornado import netutil
-from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
+from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError, _StreamBuffer
 from tornado.httputil import HTTPHeaders
 from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_wrap_socket
@@ -12,9 +12,11 @@ from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase
 from tornado.test.util import unittest, skipIfNonUnix, refusing_port, skipPypy3V58
 from tornado.web import RequestHandler, Application
 import errno
+import io
 import logging
 import os
 import platform
+import random
 import socket
 import ssl
 import sys
@@ -1136,3 +1138,117 @@ class TestPipeIOStream(AsyncTestCase):
 
         ws.close()
         rs.close()
+
+
+class TestStreamBuffer(unittest.TestCase):
+    """
+    Unit tests for the private _StreamBuffer class.
+    """
+
+    def setUp(self):
+        self.random = random.Random(42)
+
+    def to_bytes(self, b):
+        if isinstance(b, (bytes, bytearray)):
+            return bytes(b)
+        elif isinstance(b, memoryview):
+            return b.tobytes()   # For py2
+        else:
+            raise TypeError(b)
+
+    def make_streambuffer(self, large_buf_threshold=10):
+        buf = _StreamBuffer()
+        assert buf._large_buf_threshold
+        buf._large_buf_threshold = large_buf_threshold
+        return buf
+
+    def check_peek(self, buf, expected):
+        size = 1
+        while size < 2 * len(expected):
+            got = self.to_bytes(buf.peek(size))
+            self.assertTrue(got)   # Not empty
+            self.assertLessEqual(len(got), size)
+            self.assertTrue(expected.startswith(got), (expected, got))
+            size = (size * 3 + 1) // 2
+
+    def check_append_all_then_skip_all(self, buf, objs, input_type):
+        self.assertEqual(len(buf), 0)
+
+        total_expected = b''.join(objs)
+        expected = b''
+
+        for o in objs:
+            expected += o
+            buf.append(input_type(o))
+            self.assertEqual(len(buf), len(expected))
+            self.check_peek(buf, expected)
+
+        total_size = len(expected)
+
+        while expected:
+            n = self.random.randrange(1, len(expected) + 1)
+            expected = expected[n:]
+            buf.advance(n)
+            self.assertEqual(len(buf), len(expected))
+            self.check_peek(buf, expected)
+
+        self.assertEqual(len(buf), 0)
+
+    def test_small(self):
+        objs = [b'12', b'345', b'67', b'89a', b'bcde', b'fgh', b'ijklmn']
+
+        buf = self.make_streambuffer()
+        self.check_append_all_then_skip_all(buf, objs, bytes)
+
+        buf = self.make_streambuffer()
+        self.check_append_all_then_skip_all(buf, objs, bytearray)
+
+        buf = self.make_streambuffer()
+        self.check_append_all_then_skip_all(buf, objs, memoryview)
+
+        # Test internal algorithm
+        buf = self.make_streambuffer(10)
+        for i in range(9):
+            buf.append(b'x')
+        self.assertEqual(len(buf._buffers), 1)
+        for i in range(9):
+            buf.append(b'x')
+        self.assertEqual(len(buf._buffers), 2)
+        buf.advance(10)
+        self.assertEqual(len(buf._buffers), 1)
+        buf.advance(8)
+        self.assertEqual(len(buf._buffers), 0)
+        self.assertEqual(len(buf), 0)
+
+    def test_large(self):
+        objs = [b'12' * 5,
+                b'345' * 2,
+                b'67' * 20,
+                b'89a' * 12,
+                b'bcde' * 1,
+                b'fgh' * 7,
+                b'ijklmn' * 2]
+
+        buf = self.make_streambuffer()
+        self.check_append_all_then_skip_all(buf, objs, bytes)
+
+        buf = self.make_streambuffer()
+        self.check_append_all_then_skip_all(buf, objs, bytearray)
+
+        buf = self.make_streambuffer()
+        self.check_append_all_then_skip_all(buf, objs, memoryview)
+
+        # Test internal algorithm
+        buf = self.make_streambuffer(10)
+        for i in range(3):
+            buf.append(b'x' * 11)
+        self.assertEqual(len(buf._buffers), 3)
+        buf.append(b'y')
+        self.assertEqual(len(buf._buffers), 4)
+        buf.append(b'z')
+        self.assertEqual(len(buf._buffers), 4)
+        buf.advance(33)
+        self.assertEqual(len(buf._buffers), 1)
+        buf.advance(2)
+        self.assertEqual(len(buf._buffers), 0)
+        self.assertEqual(len(buf), 0)