]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add support for proxy tunnels for Python 3.6 + asyncio. (#521)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Sun, 17 Nov 2019 11:50:54 +0000 (05:50 -0600)
committerFlorimond Manca <florimond.manca@gmail.com>
Sun, 17 Nov 2019 11:50:54 +0000 (12:50 +0100)
* Backport start_tls() support for Python 3.6

* Remove version check in start_tls test

httpx/concurrency/asyncio.py
tests/test_concurrency.py

index 010d8215a8b15a3df80f76eed24ead276b970a91..019876e43e883cdf659fbdf5350458d8c0710bd2 100644 (file)
@@ -58,16 +58,49 @@ class SocketStream(BaseSocketStream):
         self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
     ) -> "SocketStream":
         loop = asyncio.get_event_loop()
-        if not hasattr(loop, "start_tls"):  # pragma: no cover
-            raise NotImplementedError(
-                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
-            )
 
         stream_reader = asyncio.StreamReader()
         protocol = asyncio.StreamReaderProtocol(stream_reader)
         transport = self.stream_writer.transport
 
-        loop_start_tls = loop.start_tls  # type: ignore
+        if hasattr(loop, "start_tls"):
+            loop_start_tls = loop.start_tls  # type: ignore
+        else:
+
+            async def loop_start_tls(
+                transport: asyncio.BaseTransport,
+                protocol: asyncio.BaseProtocol,
+                sslcontext: ssl.SSLContext = None,
+                *,
+                server_side: bool = False,
+                server_hostname: str = None,
+                ssl_handshake_timeout: float = None,
+            ) -> asyncio.Transport:
+                """Python 3.6 asyncio doesn't have a start_tls() method on the loop
+                so we use this function in place of the loop's start_tls() method.
+                Adapted from this comment:
+                https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839
+                """
+                import asyncio.sslproto
+
+                waiter = loop.create_future()
+                ssl_protocol = asyncio.sslproto.SSLProtocol(
+                    loop,
+                    protocol,
+                    sslcontext,
+                    waiter,
+                    server_side=False,
+                    server_hostname=server_hostname,
+                    call_connection_made=False,
+                )
+
+                transport.set_protocol(ssl_protocol)
+                loop.call_soon(ssl_protocol.connection_made, transport)
+                loop.call_soon(transport.resume_reading)  # type: ignore
+
+                await waiter
+                return ssl_protocol._app_transport
+
         transport = await asyncio.wait_for(
             loop_start_tls(
                 transport=transport,
index 8bb933b6975b946849502991ff0fb5c97acf5bef..7477ea3b96ba8c208576c3f0c5ca5f281a4f6850 100644 (file)
@@ -1,5 +1,3 @@
-import sys
-
 import pytest
 import trio
 
@@ -31,9 +29,6 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
     See that the concurrency backend can make a connection without TLS then
     start TLS on an existing connection.
     """
-    if isinstance(backend, AsyncioBackend) and sys.version_info < (3, 7):
-        pytest.xfail(reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()")
-
     ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
     timeout = TimeoutConfig(5)