]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add start_tls to Trio backend (#467)
authorFlorimond Manca <florimond.manca@gmail.com>
Thu, 10 Oct 2019 12:01:23 +0000 (14:01 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Thu, 10 Oct 2019 12:01:23 +0000 (07:01 -0500)
httpx/concurrency/trio.py
tests/test_concurrency.py

index 2c36f186f5bb05f54fcd8371c7dce578eb8649b1..3de3d1407812f2df1cd1a1abeb5bb3d66e75c0e8 100644 (file)
@@ -171,6 +171,30 @@ class TrioBackend(ConcurrencyBackend):
 
         return TCPStream(stream=stream, timeout=timeout)
 
+    async def start_tls(
+        self,
+        stream: BaseTCPStream,
+        hostname: str,
+        ssl_context: ssl.SSLContext,
+        timeout: TimeoutConfig,
+    ) -> BaseTCPStream:
+        assert isinstance(stream, TCPStream)
+
+        connect_timeout = _or_inf(timeout.connect_timeout)
+        ssl_stream = trio.SSLStream(
+            stream.stream, ssl_context=ssl_context, server_hostname=hostname
+        )
+
+        with trio.move_on_after(connect_timeout) as cancel_scope:
+            await ssl_stream.do_handshake()
+
+        if cancel_scope.cancelled_caught:
+            raise ConnectTimeout()
+
+        stream.stream = ssl_stream
+
+        return stream
+
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index ab93b302829d2293eb35845fd77276d69811de95..3f9e8262351b1a01d209043a45b9763a9dc7d3a9 100644 (file)
@@ -1,21 +1,39 @@
 import sys
 
 import pytest
+import trio
 
 from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
-
-
-@pytest.mark.xfail(
-    sys.version_info < (3, 7),
-    reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
+from httpx.concurrency.trio import TrioBackend
+
+
+@pytest.mark.parametrize(
+    "backend, get_cipher",
+    [
+        pytest.param(
+            AsyncioBackend(),
+            lambda stream: stream.stream_writer.get_extra_info("cipher", default=None),
+            marks=pytest.mark.asyncio,
+        ),
+        pytest.param(
+            TrioBackend(),
+            lambda stream: (
+                stream.stream.cipher()
+                if isinstance(stream.stream, trio.SSLStream)
+                else None
+            ),
+            marks=pytest.mark.trio,
+        ),
+    ],
 )
-@pytest.mark.asyncio
-async def test_start_tls_on_socket_stream(https_server):
+async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
     """
-    See that the backend can make a connection without TLS then
+    See that the concurrency backend can make a connection without TLS then
     start TLS on an existing connection.
     """
-    backend = AsyncioBackend()
+    if isinstance(backend, AsyncioBackend) and sys.version_info < (3, 7):
+        pytest.xfail(reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()")
+
     ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
     timeout = TimeoutConfig(5)
 
@@ -25,11 +43,11 @@ async def test_start_tls_on_socket_stream(https_server):
 
     try:
         assert stream.is_connection_dropped() is False
-        assert stream.stream_writer.get_extra_info("cipher", default=None) is None
+        assert get_cipher(stream) is None
 
         stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+        assert get_cipher(stream) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n")
         assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")