transport.close()
remove_writer.assert_called_with(self.sock_fd)
+ def test_write_buffer_after_close(self):
+ # gh-115514: If the transport is closed while:
+ # * Transport write buffer is not empty
+ # * Transport is paused
+ # * Protocol has data in its buffer, like SSLProtocol in self._outgoing
+ # The data is still written out.
+
+ # Also tested with real SSL transport in
+ # test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data
+
+ data = memoryview(b'data')
+ self.sock.send.return_value = 2
+ self.sock.send.fileno.return_value = 7
+
+ def _resume_writing():
+ transport.write(b"data")
+ self.protocol.resume_writing.side_effect = None
+
+ self.protocol.resume_writing.side_effect = _resume_writing
+
+ transport = self.socket_transport()
+ transport._high_water = 1
+
+ transport.write(data)
+
+ self.assertTrue(transport._protocol_paused)
+ self.assertTrue(self.sock.send.called)
+ self.loop.assert_writer(7, transport._write_ready)
+
+ transport.close()
+
+ # not called, we still have data in write buffer
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ self.loop.writers[7]._run()
+ # during this ^ run, the _resume_writing mock above was called and added more data
+
+ self.assertEqual(transport.get_write_buffer_size(), 2)
+ self.loop.writers[7]._run()
+
+ self.assertEqual(transport.get_write_buffer_size(), 0)
+ self.assertTrue(self.protocol.connection_lost.called)
class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
import tempfile
import threading
import time
+import unittest.mock
import weakref
import unittest
with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))
+ def test_remote_shutdown_receives_trailing_data_on_slow_socket(self):
+ # This test is the same as test_remote_shutdown_receives_trailing_data,
+ # except it simulates a socket that is not able to write data in time,
+ # thus triggering different code path in _SelectorSocketTransport.
+ # This triggers bug gh-115514, also tested using mocks in
+ # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
+ # The slow path is triggered here by setting SO_SNDBUF, see code and comment below.
+
+ CHUNK = 1024 * 128
+ SIZE = 32
+
+ sslctx = self._create_server_ssl_context(
+ test_utils.ONLYCERT,
+ test_utils.ONLYKEY
+ )
+ client_sslctx = self._create_client_ssl_context()
+ future = None
+
+ def server(sock):
+ incoming = ssl.MemoryBIO()
+ outgoing = ssl.MemoryBIO()
+ sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
+
+ while True:
+ try:
+ sslobj.do_handshake()
+ except ssl.SSLWantReadError:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ incoming.write(sock.recv(16384))
+ else:
+ if outgoing.pending:
+ sock.send(outgoing.read())
+ break
+
+ while True:
+ try:
+ data = sslobj.read(4)
+ except ssl.SSLWantReadError:
+ incoming.write(sock.recv(16384))
+ else:
+ break
+
+ self.assertEqual(data, b'ping')
+ sslobj.write(b'pong')
+ sock.send(outgoing.read())
+
+ time.sleep(0.2) # wait for the peer to fill its backlog
+
+ # send close_notify but don't wait for response
+ with self.assertRaises(ssl.SSLWantReadError):
+ sslobj.unwrap()
+ sock.send(outgoing.read())
+
+ # should receive all data
+ data_len = 0
+ while True:
+ try:
+ chunk = len(sslobj.read(16384))
+ data_len += chunk
+ except ssl.SSLWantReadError:
+ incoming.write(sock.recv(16384))
+ except ssl.SSLZeroReturnError:
+ break
+
+ self.assertEqual(data_len, CHUNK * SIZE*2)
+
+ # verify that close_notify is received
+ sslobj.unwrap()
+
+ sock.close()
+
+ def eof_server(sock):
+ sock.starttls(sslctx, server_side=True)
+ self.assertEqual(sock.recv_all(4), b'ping')
+ sock.send(b'pong')
+
+ time.sleep(0.2) # wait for the peer to fill its backlog
+
+ # send EOF
+ sock.shutdown(socket.SHUT_WR)
+
+ # should receive all data
+ data = sock.recv_all(CHUNK * SIZE)
+ self.assertEqual(len(data), CHUNK * SIZE)
+
+ sock.close()
+
+ async def client(addr):
+ nonlocal future
+ future = self.loop.create_future()
+
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='')
+ writer.write(b'ping')
+ data = await reader.readexactly(4)
+ self.assertEqual(data, b'pong')
+
+ # fill write backlog in a hacky way - renegotiation won't help
+ for _ in range(SIZE*2):
+ writer.transport._test__append_write_backlog(b'x' * CHUNK)
+
+ try:
+ data = await reader.read()
+ self.assertEqual(data, b'')
+ except (BrokenPipeError, ConnectionResetError):
+ pass
+
+ # Make sure _SelectorSocketTransport enters the delayed write
+ # path in its `write` method by wrapping socket in a fake class
+ # that acts as if there is not enough space in socket buffer.
+ # This triggers bug gh-115514, also tested using mocks in
+ # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close
+ socket_transport = writer.transport._ssl_protocol._transport
+
+ class SocketWrapper:
+ def __init__(self, sock) -> None:
+ self.sock = sock
+
+ def __getattr__(self, name):
+ return getattr(self.sock, name)
+
+ def send(self, data):
+ # Fake that our write buffer is full, send only half
+ to_send = len(data)//2
+ return self.sock.send(data[:to_send])
+
+ def _fake_full_write_buffer(data):
+ if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper):
+ socket_transport._sock = SocketWrapper(socket_transport._sock)
+ return unittest.mock.DEFAULT
+
+ with unittest.mock.patch.object(
+ socket_transport, "write",
+ wraps=socket_transport.write,
+ side_effect=_fake_full_write_buffer
+ ):
+ await future
+
+ writer.close()
+ await self.wait_closed(writer)
+
+ def run(meth):
+ def wrapper(sock):
+ try:
+ meth(sock)
+ except Exception as ex:
+ self.loop.call_soon_threadsafe(future.set_exception, ex)
+ else:
+ self.loop.call_soon_threadsafe(future.set_result, None)
+ return wrapper
+
+ with self.tcp_server(run(server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
+ with self.tcp_server(run(eof_server)) as srv:
+ self.loop.run_until_complete(client(srv.addr))
+
def test_connect_timeout_warning(self):
s = socket.socket(socket.AF_INET)
s.bind(('127.0.0.1', 0))