]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-91166: Implement zero copy writes for `SelectorSocketTransport` in asyncio (#31871)
authorKumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
Sat, 24 Dec 2022 05:51:11 +0000 (11:21 +0530)
committerGitHub <noreply@github.com>
Sat, 24 Dec 2022 05:51:11 +0000 (11:21 +0530)
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_selector_events.py
Misc/NEWS.d/next/Library/2022-10-24-07-31-11.gh-issue-91166.-IG06R.rst [new file with mode: 0644]

index 74f289f0e6f81165a6f18c1aeb22cd8874e0f0d5..de5076a96218e0f4f0e60635258a8be04a51f1c0 100644 (file)
@@ -9,6 +9,8 @@ __all__ = 'BaseSelectorEventLoop',
 import collections
 import errno
 import functools
+import itertools
+import os
 import selectors
 import socket
 import warnings
@@ -28,6 +30,14 @@ from . import transports
 from . import trsock
 from .log import logger
 
+_HAS_SENDMSG = hasattr(socket.socket, 'sendmsg')
+
+if _HAS_SENDMSG:
+    try:
+        SC_IOV_MAX = os.sysconf('SC_IOV_MAX')
+    except (OSError, ValueError):
+        # Name unknown (ValueError) or unsupported (OSError): fall back to send
+        _HAS_SENDMSG = False
 
 def _test_selector_event(selector, fd, event):
     # Test if the selector is monitoring 'event' events
@@ -757,8 +767,6 @@ class _SelectorTransport(transports._FlowControlMixin,
 
     max_size = 256 * 1024  # Buffer size passed to recv().
 
-    _buffer_factory = bytearray  # Constructs initial value for self._buffer.
-
     # Attribute used in the destructor: it must be set even if the constructor
     # is not called (see _SelectorSslTransport which may start by raising an
     # exception)
@@ -783,7 +791,7 @@ class _SelectorTransport(transports._FlowControlMixin,
         self.set_protocol(protocol)
 
         self._server = server
-        self._buffer = self._buffer_factory()
+        self._buffer = collections.deque()
         self._conn_lost = 0  # Set when call to connection_lost scheduled.
         self._closing = False  # Set when close() called.
         if self._server is not None:
@@ -887,7 +895,7 @@ class _SelectorTransport(transports._FlowControlMixin,
                 self._server = None
 
     def get_write_buffer_size(self):
-        return len(self._buffer)
+        return sum(map(len, self._buffer))
 
     def _add_reader(self, fd, callback, *args):
         if self._closing:
@@ -909,7 +917,10 @@ class _SelectorSocketTransport(_SelectorTransport):
         self._eof = False
         self._paused = False
         self._empty_waiter = None
-
+        if _HAS_SENDMSG:
+            self._write_ready = self._write_sendmsg
+        else:
+            self._write_ready = self._write_send
         # Disable the Nagle algorithm -- small writes will be
         # sent without waiting for the TCP ACK.  This generally
         # decreases the latency (in some cases significantly.)
@@ -1066,23 +1077,68 @@ class _SelectorSocketTransport(_SelectorTransport):
                 self._fatal_error(exc, 'Fatal write error on socket transport')
                 return
             else:
-                data = data[n:]
+                data = memoryview(data)[n:]
                 if not data:
                     return
             # Not all was written; register write handler.
             self._loop._add_writer(self._sock_fd, self._write_ready)
 
         # Add it to the buffer.
-        self._buffer.extend(data)
+        self._buffer.append(data)
         self._maybe_pause_protocol()
 
-    def _write_ready(self):
+    def _get_sendmsg_buffer(self):
+        return itertools.islice(self._buffer, SC_IOV_MAX)
+
+    def _write_sendmsg(self):
         assert self._buffer, 'Data should not be empty'
+        if self._conn_lost:
+            return
+        try:
+            nbytes = self._sock.sendmsg(self._get_sendmsg_buffer())
+            self._adjust_leftover_buffer(nbytes)
+        except (BlockingIOError, InterruptedError):
+            pass
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except BaseException as exc:
+            self._loop._remove_writer(self._sock_fd)
+            self._buffer.clear()
+            self._fatal_error(exc, 'Fatal write error on socket transport')
+            if self._empty_waiter is not None:
+                self._empty_waiter.set_exception(exc)
+        else:
+            self._maybe_resume_protocol()  # May append to buffer.
+            if not self._buffer:
+                self._loop._remove_writer(self._sock_fd)
+                if self._empty_waiter is not None:
+                    self._empty_waiter.set_result(None)
+                if self._closing:
+                    self._call_connection_lost(None)
+                elif self._eof:
+                    self._sock.shutdown(socket.SHUT_WR)
 
+    def _adjust_leftover_buffer(self, nbytes: int) -> None:
+        buffer = self._buffer
+        while nbytes:
+            b = buffer.popleft()
+            b_len = len(b)
+            if b_len <= nbytes:
+                nbytes -= b_len
+            else:
+                buffer.appendleft(b[nbytes:])
+                break
+
+    def _write_send(self):
+        assert self._buffer, 'Data should not be empty'
         if self._conn_lost:
             return
         try:
-            n = self._sock.send(self._buffer)
+            buffer = self._buffer.popleft()
+            n = self._sock.send(buffer)
+            if n != len(buffer):
+                # Not all data was written
+                self._buffer.appendleft(buffer[n:])
         except (BlockingIOError, InterruptedError):
             pass
         except (SystemExit, KeyboardInterrupt):
@@ -1094,8 +1150,6 @@ class _SelectorSocketTransport(_SelectorTransport):
             if self._empty_waiter is not None:
                 self._empty_waiter.set_exception(exc)
         else:
-            if n:
-                del self._buffer[:n]
             self._maybe_resume_protocol()  # May append to buffer.
             if not self._buffer:
                 self._loop._remove_writer(self._sock_fd)
@@ -1113,6 +1167,25 @@ class _SelectorSocketTransport(_SelectorTransport):
         if not self._buffer:
             self._sock.shutdown(socket.SHUT_WR)
 
+    def writelines(self, list_of_data):
+        """Queue every chunk of list_of_data and try to send it right away.
+
+        Raises RuntimeError after write_eof() or while _empty_waiter is set.
+        """
+        if self._eof:
+            raise RuntimeError('Cannot call writelines() after write_eof()')
+        if self._empty_waiter is not None:
+            raise RuntimeError('unable to writelines; sendfile is in progress')
+        if not list_of_data:
+            return
+        self._buffer.extend([memoryview(data) for data in list_of_data])
+        self._write_ready()
+        # A partial send leaves data queued: register the write handler so
+        # the remainder gets flushed, and apply flow control to the protocol.
+        if self._buffer:
+            self._loop._add_writer(self._sock_fd, self._write_ready)
+            self._maybe_pause_protocol()
+
     def can_write_eof(self):
         return True
 
index ca555387dd2493d28bb7fcbdf686b1fb0a9887fa..921c98a2702d76576a4a634174f85ed4ca718bbb 100644 (file)
@@ -1,23 +1,25 @@
 """Tests for selector_events.py"""
 
-import sys
+import collections
 import selectors
 import socket
+import sys
 import unittest
+from asyncio import selector_events
 from unittest import mock
+
 try:
     import ssl
 except ImportError:
     ssl = None
 
 import asyncio
-from asyncio.selector_events import BaseSelectorEventLoop
-from asyncio.selector_events import _SelectorTransport
-from asyncio.selector_events import _SelectorSocketTransport
-from asyncio.selector_events import _SelectorDatagramTransport
+from asyncio.selector_events import (BaseSelectorEventLoop,
+                                     _SelectorDatagramTransport,
+                                     _SelectorSocketTransport,
+                                     _SelectorTransport)
 from test.test_asyncio import utils as test_utils
 
-
 MOCK_ANY = mock.ANY
 
 
@@ -37,7 +39,10 @@ class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
 
 
 def list_to_buffer(l=()):
-    return bytearray().join(l)
+    buffer = collections.deque()
+    buffer.extend((memoryview(i) for i in l))
+    return buffer
+
 
 
 def close_transport(transport):
@@ -493,9 +498,13 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         self.sock = mock.Mock(socket.socket)
         self.sock_fd = self.sock.fileno.return_value = 7
 
-    def socket_transport(self, waiter=None):
+    def socket_transport(self, waiter=None, sendmsg=False):
         transport = _SelectorSocketTransport(self.loop, self.sock,
                                              self.protocol, waiter=waiter)
+        if sendmsg:
+            transport._write_ready = transport._write_sendmsg
+        else:
+            transport._write_ready = transport._write_send
         self.addCleanup(close_transport, transport)
         return transport
 
@@ -664,14 +673,14 @@ class SelectorSocketTransportTests(test_utils.TestCase):
 
     def test_write_no_data(self):
         transport = self.socket_transport()
-        transport._buffer.extend(b'data')
+        transport._buffer.append(memoryview(b'data'))
         transport.write(b'')
         self.assertFalse(self.sock.send.called)
         self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     def test_write_buffer(self):
         transport = self.socket_transport()
-        transport._buffer.extend(b'data1')
+        transport._buffer.append(b'data1')
         transport.write(b'data2')
         self.assertFalse(self.sock.send.called)
         self.assertEqual(list_to_buffer([b'data1', b'data2']),
@@ -729,6 +738,77 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         self.loop.assert_writer(7, transport._write_ready)
         self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
+    def test_write_sendmsg_no_data(self):
+        self.sock.sendmsg = mock.Mock()
+        self.sock.sendmsg.return_value = 0
+        transport = self.socket_transport(sendmsg=True)
+        transport._buffer.append(memoryview(b'data'))
+        transport.write(b'')
+        self.assertFalse(self.sock.sendmsg.called)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+    @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
+    def test_write_sendmsg_full(self):
+        data = memoryview(b'data')
+        self.sock.sendmsg = mock.Mock()
+        self.sock.sendmsg.return_value = len(data)
+
+        transport = self.socket_transport(sendmsg=True)
+        transport._buffer.append(data)
+        self.loop._add_writer(7, transport._write_ready)
+        transport._write_ready()
+        self.assertTrue(self.sock.sendmsg.called)
+        self.assertFalse(self.loop.writers)
+
+    @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
+    def test_write_sendmsg_partial(self):
+
+        data = memoryview(b'data')
+        self.sock.sendmsg = mock.Mock()
+        # Sent partial data
+        self.sock.sendmsg.return_value = 2
+
+        transport = self.socket_transport(sendmsg=True)
+        transport._buffer.append(data)
+        self.loop._add_writer(7, transport._write_ready)
+        transport._write_ready()
+        self.assertTrue(self.sock.sendmsg.called)
+        self.assertTrue(self.loop.writers)
+        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+
+    @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
+    def test_write_sendmsg_half_buffer(self):
+        data = [memoryview(b'data1'), memoryview(b'data2')]
+        self.sock.sendmsg = mock.Mock()
+        # Sent partial data
+        self.sock.sendmsg.return_value = 2
+
+        transport = self.socket_transport(sendmsg=True)
+        transport._buffer.extend(data)
+        self.loop._add_writer(7, transport._write_ready)
+        transport._write_ready()
+        self.assertTrue(self.sock.sendmsg.called)
+        self.assertTrue(self.loop.writers)
+        self.assertEqual(list_to_buffer([b'ta1', b'data2']), transport._buffer)
+
+    @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
+    def test_write_sendmsg_OSError(self):
+        data = memoryview(b'data')
+        self.sock.sendmsg = mock.Mock()
+        err = self.sock.sendmsg.side_effect = OSError()
+
+        transport = self.socket_transport(sendmsg=True)
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append(data)
+        # Calls _fatal_error and clears the buffer
+        transport._write_ready()
+        self.assertTrue(self.sock.sendmsg.called)
+        self.assertFalse(self.loop.writers)
+        self.assertEqual(list_to_buffer([]), transport._buffer)
+        transport._fatal_error.assert_called_with(
+                                   err,
+                                   'Fatal write error on socket transport')
+
     @mock.patch('asyncio.selector_events.logger')
     def test_write_exception(self, m_log):
         err = self.sock.send.side_effect = OSError()
@@ -768,19 +848,19 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         self.sock.send.return_value = len(data)
 
         transport = self.socket_transport()
-        transport._buffer.extend(data)
+        transport._buffer.append(data)
         self.loop._add_writer(7, transport._write_ready)
         transport._write_ready()
         self.assertTrue(self.sock.send.called)
         self.assertFalse(self.loop.writers)
 
     def test_write_ready_closing(self):
-        data = b'data'
+        data = memoryview(b'data')
         self.sock.send.return_value = len(data)
 
         transport = self.socket_transport()
         transport._closing = True
-        transport._buffer.extend(data)
+        transport._buffer.append(data)
         self.loop._add_writer(7, transport._write_ready)
         transport._write_ready()
         self.assertTrue(self.sock.send.called)
@@ -795,11 +875,11 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         self.assertRaises(AssertionError, transport._write_ready)
 
     def test_write_ready_partial(self):
-        data = b'data'
+        data = memoryview(b'data')
         self.sock.send.return_value = 2
 
         transport = self.socket_transport()
-        transport._buffer.extend(data)
+        transport._buffer.append(data)
         self.loop._add_writer(7, transport._write_ready)
         transport._write_ready()
         self.loop.assert_writer(7, transport._write_ready)
@@ -810,7 +890,7 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         self.sock.send.return_value = 0
 
         transport = self.socket_transport()
-        transport._buffer.extend(data)
+        transport._buffer.append(data)
         self.loop._add_writer(7, transport._write_ready)
         transport._write_ready()
         self.loop.assert_writer(7, transport._write_ready)
@@ -820,12 +900,13 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         self.sock.send.side_effect = BlockingIOError
 
         transport = self.socket_transport()
-        transport._buffer = list_to_buffer([b'data1', b'data2'])
+        buffer = list_to_buffer([b'data1', b'data2'])
+        transport._buffer = buffer
         self.loop._add_writer(7, transport._write_ready)
         transport._write_ready()
 
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
+        self.assertEqual(buffer, transport._buffer)
 
     def test_write_ready_exception(self):
         err = self.sock.send.side_effect = OSError()
diff --git a/Misc/NEWS.d/next/Library/2022-10-24-07-31-11.gh-issue-91166.-IG06R.rst b/Misc/NEWS.d/next/Library/2022-10-24-07-31-11.gh-issue-91166.-IG06R.rst
new file mode 100644 (file)
index 0000000..5ee08ec
--- /dev/null
@@ -0,0 +1 @@
+:mod:`asyncio` is optimized to avoid excessive copying when writing to a socket and uses :meth:`~socket.socket.sendmsg` if the platform supports it. Patch by Kumar Aditya.