]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] gh-115514: Fix incomplete writes after close while using ssl in asyncio(GH...
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Sun, 2 Feb 2025 15:47:37 +0000 (16:47 +0100)
committerGitHub <noreply@github.com>
Sun, 2 Feb 2025 15:47:37 +0000 (15:47 +0000)
gh-115514: Fix incomplete writes after close while using ssl in asyncio(GH-128037)

(cherry picked from commit 4e38eeafe2ff3bfc686514731d6281fed34a435e)

Co-authored-by: Vojtěch Boček <vbocek@gmail.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_ssl.py
Misc/ACKS
Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst [new file with mode: 0644]

index dd79ad18df3b18a8fc8af40ca642dd5387ac3ec2..160ed6ca13eb4314c3ad58482d6361c367a4b34d 100644 (file)
@@ -1189,10 +1189,13 @@ class _SelectorSocketTransport(_SelectorTransport):
         return True
 
     def _call_connection_lost(self, exc):
-        super()._call_connection_lost(exc)
-        if self._empty_waiter is not None:
-            self._empty_waiter.set_exception(
-                ConnectionError("Connection is closed by peer"))
+        try:
+            super()._call_connection_lost(exc)
+        finally:
+            self._write_ready = None
+            if self._empty_waiter is not None:
+                self._empty_waiter.set_exception(
+                    ConnectionError("Connection is closed by peer"))
 
     def _make_empty_waiter(self):
         if self._empty_waiter is not None:
@@ -1207,7 +1210,6 @@ class _SelectorSocketTransport(_SelectorTransport):
 
     def close(self):
         self._read_ready_cb = None
-        self._write_ready = None
         super().close()
 
 
index 736c19796ef3fc7f33bd2d5cf0ed59805bb09903..0d35a1f87f3edd967b98d6a37ca282a43e548acc 100644 (file)
@@ -1026,6 +1026,48 @@ class SelectorSocketTransportTests(test_utils.TestCase):
         transport.close()
         remove_writer.assert_called_with(self.sock_fd)
 
+    def test_write_buffer_after_close(self):
+        # gh-115514: If the transport is closed while:
+        #  * Transport write buffer is not empty
+        #  * Transport is paused
+        #  * Protocol has data in its buffer, like SSLProtocol in self._outgoing
+        # The data is still written out.
+
+        # Also tested with real SSL transport in
+        # test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data
+
+        data = memoryview(b'data')
+        self.sock.send.return_value = 2
+        self.sock.send.fileno.return_value = 7
+
+        def _resume_writing():
+            transport.write(b"data")
+            self.protocol.resume_writing.side_effect = None
+
+        self.protocol.resume_writing.side_effect = _resume_writing
+
+        transport = self.socket_transport()
+        transport._high_water = 1
+
+        transport.write(data)
+
+        self.assertTrue(transport._protocol_paused)
+        self.assertTrue(self.sock.send.called)
+        self.loop.assert_writer(7, transport._write_ready)
+
+        transport.close()
+
+        # not called, we still have data in write buffer
+        self.assertFalse(self.protocol.connection_lost.called)
+
+        self.loop.writers[7]._run()
+        # during this ^ run, the _resume_writing mock above was called and added more data
+
+        self.assertEqual(transport.get_write_buffer_size(), 2)
+        self.loop.writers[7]._run()
+
+        self.assertEqual(transport.get_write_buffer_size(), 0)
+        self.assertTrue(self.protocol.connection_lost.called)
 
 class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
 
index e072ede29ee3c7622334ddef22fccc9a620b240f..e4ab5a9024c95679fa25340938b8b578468447f1 100644 (file)
@@ -12,6 +12,7 @@ import sys
 import tempfile
 import threading
 import time
+import unittest.mock
 import weakref
 import unittest
 
@@ -1431,6 +1432,166 @@ class TestSSL(test_utils.TestCase):
         with self.tcp_server(run(eof_server)) as srv:
             self.loop.run_until_complete(client(srv.addr))
 
+    def test_remote_shutdown_receives_trailing_data_on_slow_socket(self):
+        # This test is the same as test_remote_shutdown_receives_trailing_data,
+        # except it simulates a socket that is not able to write data in time,
+        # thus triggering different code path in _SelectorSocketTransport.
+        # This triggers bug gh-115514, also tested using mocks in
+        # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
+        # The slow path is triggered here by setting SO_SNDBUF, see code and comment below.
+
+        CHUNK = 1024 * 128
+        SIZE = 32
+
+        sslctx = self._create_server_ssl_context(
+            test_utils.ONLYCERT,
+            test_utils.ONLYKEY
+        )
+        client_sslctx = self._create_client_ssl_context()
+        future = None
+
+        def server(sock):
+            incoming = ssl.MemoryBIO()
+            outgoing = ssl.MemoryBIO()
+            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
+
+            while True:
+                try:
+                    sslobj.do_handshake()
+                except ssl.SSLWantReadError:
+                    if outgoing.pending:
+                        sock.send(outgoing.read())
+                    incoming.write(sock.recv(16384))
+                else:
+                    if outgoing.pending:
+                        sock.send(outgoing.read())
+                    break
+
+            while True:
+                try:
+                    data = sslobj.read(4)
+                except ssl.SSLWantReadError:
+                    incoming.write(sock.recv(16384))
+                else:
+                    break
+
+            self.assertEqual(data, b'ping')
+            sslobj.write(b'pong')
+            sock.send(outgoing.read())
+
+            time.sleep(0.2)  # wait for the peer to fill its backlog
+
+            # send close_notify but don't wait for response
+            with self.assertRaises(ssl.SSLWantReadError):
+                sslobj.unwrap()
+            sock.send(outgoing.read())
+
+            # should receive all data
+            data_len = 0
+            while True:
+                try:
+                    chunk = len(sslobj.read(16384))
+                    data_len += chunk
+                except ssl.SSLWantReadError:
+                    incoming.write(sock.recv(16384))
+                except ssl.SSLZeroReturnError:
+                    break
+
+            self.assertEqual(data_len, CHUNK * SIZE*2)
+
+            # verify that close_notify is received
+            sslobj.unwrap()
+
+            sock.close()
+
+        def eof_server(sock):
+            sock.starttls(sslctx, server_side=True)
+            self.assertEqual(sock.recv_all(4), b'ping')
+            sock.send(b'pong')
+
+            time.sleep(0.2)  # wait for the peer to fill its backlog
+
+            # send EOF
+            sock.shutdown(socket.SHUT_WR)
+
+            # should receive all data
+            data = sock.recv_all(CHUNK * SIZE)
+            self.assertEqual(len(data), CHUNK * SIZE)
+
+            sock.close()
+
+        async def client(addr):
+            nonlocal future
+            future = self.loop.create_future()
+
+            reader, writer = await asyncio.open_connection(
+                *addr,
+                ssl=client_sslctx,
+                server_hostname='')
+            writer.write(b'ping')
+            data = await reader.readexactly(4)
+            self.assertEqual(data, b'pong')
+
+            # fill write backlog in a hacky way - renegotiation won't help
+            for _ in range(SIZE*2):
+                writer.transport._test__append_write_backlog(b'x' * CHUNK)
+
+            try:
+                data = await reader.read()
+                self.assertEqual(data, b'')
+            except (BrokenPipeError, ConnectionResetError):
+                pass
+
+            # Make sure _SelectorSocketTransport enters the delayed write
+            # path in its `write` method by wrapping socket in a fake class
+            # that acts as if there is not enough space in socket buffer.
+            # This triggers bug gh-115514, also tested using mocks in
+            # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
+            socket_transport = writer.transport._ssl_protocol._transport
+
+            class SocketWrapper:
+                def __init__(self, sock) -> None:
+                    self.sock = sock
+
+                def __getattr__(self, name):
+                    return getattr(self.sock, name)
+
+                def send(self, data):
+                    # Fake that our write buffer is full, send only half
+                    to_send = len(data)//2
+                    return self.sock.send(data[:to_send])
+
+            def _fake_full_write_buffer(data):
+                if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper):
+                    socket_transport._sock = SocketWrapper(socket_transport._sock)
+                return unittest.mock.DEFAULT
+
+            with unittest.mock.patch.object(
+                socket_transport, "write",
+                wraps=socket_transport.write,
+                side_effect=_fake_full_write_buffer
+            ):
+                await future
+
+                writer.close()
+                await self.wait_closed(writer)
+
+        def run(meth):
+            def wrapper(sock):
+                try:
+                    meth(sock)
+                except Exception as ex:
+                    self.loop.call_soon_threadsafe(future.set_exception, ex)
+                else:
+                    self.loop.call_soon_threadsafe(future.set_result, None)
+            return wrapper
+
+        with self.tcp_server(run(server)) as srv:
+            self.loop.run_until_complete(client(srv.addr))
+
+        with self.tcp_server(run(eof_server)) as srv:
+            self.loop.run_until_complete(client(srv.addr))
+
     def test_connect_timeout_warning(self):
         s = socket.socket(socket.AF_INET)
         s.bind(('127.0.0.1', 0))
index 14450325647590cbd3bd001e4599af89fd2150d2..febc2f7b365346968d5eec534c600e41dab22afa 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -188,6 +188,7 @@ Stéphane Blondon
 Eric Blossom
 Sergey Bobrov
 Finn Bock
+Vojtěch Boček
 Paul Boddie
 Matthew Boedicker
 Robin Boerdijk
diff --git a/Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst b/Misc/NEWS.d/next/Library/2024-12-17-16-48-02.gh-issue-115514.1yOJ7T.rst
new file mode 100644 (file)
index 0000000..24e836a
--- /dev/null
@@ -0,0 +1,2 @@
+Fix exceptions and incomplete writes after :class:`!asyncio._SelectorTransport`
+is closed before writes are completed.