]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Tweak backport_start_tls implementation, and add 'nocover'. (#622)
authorTom Christie <tom@tomchristie.com>
Mon, 9 Dec 2019 18:30:00 +0000 (18:30 +0000)
committerGitHub <noreply@github.com>
Mon, 9 Dec 2019 18:30:00 +0000 (18:30 +0000)
httpx/concurrency/asyncio.py

index 39bccdc933254448bcde554ecd2f7a2f984e078b..e06dc2091fd2eaadac9433043a2b56d6ed7e1d29 100644 (file)
@@ -32,6 +32,45 @@ def ssl_monkey_patch() -> None:
     MonkeyPatch.write = _fixed_write
 
 
+async def backport_start_tls(
+    transport: asyncio.BaseTransport,
+    protocol: asyncio.BaseProtocol,
+    sslcontext: ssl.SSLContext = None,
+    *,
+    server_side: bool = False,
+    server_hostname: str = None,
+    ssl_handshake_timeout: float = None,
+) -> asyncio.Transport:  # pragma: nocover (Since it's not used on all Python versions.)
+    """
+    Python 3.6 asyncio doesn't have a start_tls() method on the loop
+    so we use this function in place of the loop's start_tls() method.
+
+    Adapted from this comment:
+
+    https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839
+    """
+    import asyncio.sslproto
+
+    loop = asyncio.get_event_loop()
+    waiter = loop.create_future()
+    ssl_protocol = asyncio.sslproto.SSLProtocol(
+        loop,
+        protocol,
+        sslcontext,
+        waiter,
+        server_side=False,
+        server_hostname=server_hostname,
+        call_connection_made=False,
+    )
+
+    transport.set_protocol(ssl_protocol)
+    loop.call_soon(ssl_protocol.connection_made, transport)
+    loop.call_soon(transport.resume_reading)  # type: ignore
+
+    await waiter
+    return ssl_protocol._app_transport
+
+
 class SocketStream(BaseSocketStream):
     def __init__(
         self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter,
@@ -51,43 +90,7 @@ class SocketStream(BaseSocketStream):
         protocol = asyncio.StreamReaderProtocol(stream_reader)
         transport = self.stream_writer.transport
 
-        if hasattr(loop, "start_tls"):
-            loop_start_tls = loop.start_tls  # type: ignore
-        else:
-
-            async def loop_start_tls(
-                transport: asyncio.BaseTransport,
-                protocol: asyncio.BaseProtocol,
-                sslcontext: ssl.SSLContext = None,
-                *,
-                server_side: bool = False,
-                server_hostname: str = None,
-                ssl_handshake_timeout: float = None,
-            ) -> asyncio.Transport:
-                """Python 3.6 asyncio doesn't have a start_tls() method on the loop
-                so we use this function in place of the loop's start_tls() method.
-                Adapted from this comment:
-                https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839
-                """
-                import asyncio.sslproto
-
-                waiter = loop.create_future()
-                ssl_protocol = asyncio.sslproto.SSLProtocol(
-                    loop,
-                    protocol,
-                    sslcontext,
-                    waiter,
-                    server_side=False,
-                    server_hostname=server_hostname,
-                    call_connection_made=False,
-                )
-
-                transport.set_protocol(ssl_protocol)
-                loop.call_soon(ssl_protocol.connection_made, transport)
-                loop.call_soon(transport.resume_reading)  # type: ignore
-
-                await waiter
-                return ssl_protocol._app_transport
+        loop_start_tls = getattr(loop, "start_tls", backport_start_tls)
 
         transport = await asyncio.wait_for(
             loop_start_tls(