]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-79156: Add start_tls() method to streams API (#91453)
authorOleg Iarygin <oleg@arhadthedev.net>
Fri, 15 Apr 2022 12:23:14 +0000 (15:23 +0300)
committerGitHub <noreply@github.com>
Fri, 15 Apr 2022 12:23:14 +0000 (14:23 +0200)
The existing event loop `start_tls()` method is not sufficient for
connections using the streams API. The existing StreamReader works
because the new transport passes received data to the original protocol.
The StreamWriter must then write data to the new transport, and the
StreamReaderProtocol must be updated to close the new transport
correctly.

The new StreamWriter `start_tls()` updates itself and the reader
protocol to the new SSL transport.

Co-authored-by: Ian Good <icgood@gmail.com>
Doc/library/asyncio-stream.rst
Doc/whatsnew/3.11.rst
Lib/asyncio/streams.py
Lib/test/test_asyncio/test_streams.py
Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst [new file with mode: 0644]

index ba534f9903fb49170bc59afdda6efe8ebdad3942..72355d356f20520201ebdb470afa2c9c3879e84d 100644 (file)
@@ -295,6 +295,24 @@ StreamWriter
       be resumed.  When there is nothing to wait for, the :meth:`drain`
       returns immediately.
 
+   .. coroutinemethod:: start_tls(sslcontext, \*, server_hostname=None, \
+                          ssl_handshake_timeout=None)
+
+      Upgrade an existing stream-based connection to TLS.
+
+      Parameters:
+
+      * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
+
+      * *server_hostname*: sets or overrides the host name that the target
+        server's certificate will be matched against.
+
+      * *ssl_handshake_timeout* is the time in seconds to wait for the TLS
+        handshake to complete before aborting the connection.  ``60.0`` seconds
+        if ``None`` (default).
+
+      .. versionadded:: 3.8
+
    .. method:: is_closing()
 
       Return ``True`` if the stream is closed or in the process of
index dba554cc834eca1f617ee69a692b06f2d4a39dfe..9f7f6f52a8e9e02c84e5079dcd6751771d472946 100644 (file)
@@ -246,6 +246,10 @@ asyncio
   :meth:`~asyncio.AbstractEventLoop.sock_recvfrom_into`.
   (Contributed by Alex Grönholm in :issue:`46805`.)
 
+* Add :meth:`~asyncio.streams.StreamWriter.start_tls` method for upgrading
+  existing stream-based connections to TLS. (Contributed by Ian Good in
+  :issue:`34975`.)
+
 fractions
 ---------
 
index 080d8a62cde1e219fa0866f8c3e7aa1751d0bc9b..a568c4e4b295f0f11919604f8182e192efe3bf53 100644 (file)
@@ -217,6 +217,13 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
             return None
         return self._stream_reader_wr()
 
+    def _replace_writer(self, writer):
+        loop = self._loop
+        transport = writer.transport
+        self._stream_writer = writer
+        self._transport = transport
+        self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
     def connection_made(self, transport):
         if self._reject_connection:
             context = {
@@ -371,6 +378,20 @@ class StreamWriter:
             await sleep(0)
         await self._protocol._drain_helper()
 
+    async def start_tls(self, sslcontext, *,
+                        server_hostname=None,
+                        ssl_handshake_timeout=None):
+        """Upgrade an existing stream-based connection to TLS."""
+        server_side = self._protocol._client_connected_cb is not None
+        protocol = self._protocol
+        await self.drain()
+        new_transport = await self._loop.start_tls(  # type: ignore
+            self._transport, protocol, sslcontext,
+            server_side=server_side, server_hostname=server_hostname,
+            ssl_handshake_timeout=ssl_handshake_timeout)
+        self._transport = new_transport
+        protocol._replace_writer(self)
+
 
 class StreamReader:
 
index 227b2279e172c8ed0838360271b38b15684237da..a7d17894e1c5265f091a083eefcbf5782ec27550 100644 (file)
@@ -706,6 +706,69 @@ class StreamTests(test_utils.TestCase):
 
         self.assertEqual(messages, [])
 
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    def test_start_tls(self):
+
+        class MyServer:
+
+            def __init__(self, loop):
+                self.server = None
+                self.loop = loop
+
+            async def handle_client(self, client_reader, client_writer):
+                data1 = await client_reader.readline()
+                client_writer.write(data1)
+                await client_writer.drain()
+                assert client_writer.get_extra_info('sslcontext') is None
+                await client_writer.start_tls(
+                    test_utils.simple_server_sslcontext())
+                assert client_writer.get_extra_info('sslcontext') is not None
+                data2 = await client_reader.readline()
+                client_writer.write(data2)
+                await client_writer.drain()
+                client_writer.close()
+                await client_writer.wait_closed()
+
+            def start(self):
+                sock = socket.create_server(('127.0.0.1', 0))
+                self.server = self.loop.run_until_complete(
+                    asyncio.start_server(self.handle_client,
+                                         sock=sock))
+                return sock.getsockname()
+
+            def stop(self):
+                if self.server is not None:
+                    self.server.close()
+                    self.loop.run_until_complete(self.server.wait_closed())
+                    self.server = None
+
+        async def client(addr):
+            reader, writer = await asyncio.open_connection(*addr)
+            writer.write(b"hello world 1!\n")
+            await writer.drain()
+            msgback1 = await reader.readline()
+            assert writer.get_extra_info('sslcontext') is None
+            await writer.start_tls(test_utils.simple_client_sslcontext())
+            assert writer.get_extra_info('sslcontext') is not None
+            writer.write(b"hello world 2!\n")
+            await writer.drain()
+            msgback2 = await reader.readline()
+            writer.close()
+            await writer.wait_closed()
+            return msgback1, msgback2
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        server = MyServer(self.loop)
+        addr = server.start()
+        msg1, msg2 = self.loop.run_until_complete(client(addr))
+        server.stop()
+
+        self.assertEqual(messages, [])
+        self.assertEqual(msg1, b"hello world 1!\n")
+        self.assertEqual(msg2, b"hello world 2!\n")
+
     @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
     def test_read_all_from_pipe_reader(self):
         # See asyncio issue 168.  This test is derived from the example
diff --git a/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst b/Misc/NEWS.d/next/Library/2019-05-06-23-36-34.bpo-34975.eb49jr.rst
new file mode 100644 (file)
index 0000000..1576269
--- /dev/null
@@ -0,0 +1,3 @@
+Adds a ``start_tls()`` method to :class:`~asyncio.streams.StreamWriter`,
+which upgrades the connection with TLS using the given
+:class:`~ssl.SSLContext`.