be resumed. When there is nothing to wait for, the :meth:`drain`
returns immediately.
+ .. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \
+ ssl_handshake_timeout=None)
+
+ Upgrade an existing stream-based connection to TLS.
+
+ Parameters:
+
+ * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
+
+ * *server_hostname*: sets or overrides the host name that the target
+ server's certificate will be matched against.
+
+ * *ssl_handshake_timeout* is the time in seconds to wait for the TLS
+ handshake to complete before aborting the connection. ``60.0`` seconds
+ if ``None`` (default).
+
+ .. versionadded:: 3.8
+
.. method:: is_closing()
Return ``True`` if the stream is closed or in the process of
return None
return self._stream_reader_wr()
+ def _replace_writer(self, writer):
+ loop = self._loop
+ transport = writer.transport
+ self._stream_writer = writer
+ self._transport = transport
+ self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
def connection_made(self, transport):
if self._reject_connection:
context = {
await sleep(0)
await self._protocol._drain_helper()
+ async def start_tls(self, sslcontext, *,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ """Upgrade an existing stream-based connection to TLS."""
+ server_side = self._protocol._client_connected_cb is not None
+ protocol = self._protocol
+ await self.drain()
+ new_transport = await self._loop.start_tls( # type: ignore
+ self._transport, protocol, sslcontext,
+ server_side=server_side, server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ self._transport = new_transport
+ protocol._replace_writer(self)
+
class StreamReader:
self.assertEqual(messages, [])
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_start_tls(self):
+
+ class MyServer:
+
+ def __init__(self, loop):
+ self.server = None
+ self.loop = loop
+
+ async def handle_client(self, client_reader, client_writer):
+ data1 = await client_reader.readline()
+ client_writer.write(data1)
+ await client_writer.drain()
+ assert client_writer.get_extra_info('sslcontext') is None
+ await client_writer.start_tls(
+ test_utils.simple_server_sslcontext())
+ assert client_writer.get_extra_info('sslcontext') is not None
+ data2 = await client_reader.readline()
+ client_writer.write(data2)
+ await client_writer.drain()
+ client_writer.close()
+ await client_writer.wait_closed()
+
+ def start(self):
+ sock = socket.create_server(('127.0.0.1', 0))
+ self.server = self.loop.run_until_complete(
+ asyncio.start_server(self.handle_client,
+ sock=sock))
+ return sock.getsockname()
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr)
+ writer.write(b"hello world 1!\n")
+ await writer.drain()
+ msgback1 = await reader.readline()
+ assert writer.get_extra_info('sslcontext') is None
+ await writer.start_tls(test_utils.simple_client_sslcontext())
+ assert writer.get_extra_info('sslcontext') is not None
+ writer.write(b"hello world 2!\n")
+ await writer.drain()
+ msgback2 = await reader.readline()
+ writer.close()
+ await writer.wait_closed()
+ return msgback1, msgback2
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ server = MyServer(self.loop)
+ addr = server.start()
+ msg1, msg2 = self.loop.run_until_complete(client(addr))
+ server.stop()
+
+ self.assertEqual(messages, [])
+ self.assertEqual(msg1, b"hello world 1!\n")
+ self.assertEqual(msg2, b"hello world 2!\n")
+
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example