# have a chance to get called before "ssl_protocol.connection_made()".
transport.pause_reading()
+ # gh-142352: move buffered StreamReader data to SSLProtocol
+ if server_side:
+ from .streams import StreamReaderProtocol
+ if isinstance(protocol, StreamReaderProtocol):
+ stream_reader = getattr(protocol, '_stream_reader', None)
+ if stream_reader is not None:
+ buffer = stream_reader._buffer
+ if buffer:
+ ssl_protocol._incoming.write(buffer)
+ buffer.clear()
+
transport.set_protocol(ssl_protocol)
conmade_cb = self.call_soon(ssl_protocol.connection_made, transport)
resume_cb = self.call_soon(transport.resume_reading)
self.assertEqual(msg1, b"hello world 1!\n")
self.assertEqual(msg2, b"hello world 2!\n")
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_start_tls_buffered_data(self):
+ # gh-142352: test start_tls() with buffered data
+
+ async def server_handler(client_reader, client_writer):
+ # Wait for TLS ClientHello to be buffered before start_tls().
+ await client_reader._wait_for_data('test_start_tls_buffered_data'),
+ self.assertTrue(client_reader._buffer)
+ await client_writer.start_tls(test_utils.simple_server_sslcontext())
+
+ line = await client_reader.readline()
+ self.assertEqual(line, b"ping\n")
+ client_writer.write(b"pong\n")
+ await client_writer.drain()
+ client_writer.close()
+ await client_writer.wait_closed()
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr)
+ await writer.start_tls(test_utils.simple_client_sslcontext())
+
+ writer.write(b"ping\n")
+ await writer.drain()
+ line = await reader.readline()
+ self.assertEqual(line, b"pong\n")
+ writer.close()
+ await writer.wait_closed()
+
+ async def run_test():
+ server = await asyncio.start_server(
+ server_handler, socket_helper.HOSTv4, 0)
+ server_addr = server.sockets[0].getsockname()
+
+ await client(server_addr)
+ server.close()
+ await server.wait_closed()
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(run_test())
+ self.assertEqual(messages, [])
+
def test_streamreader_constructor_without_loop(self):
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
asyncio.StreamReader()