From: Florimond Manca Date: Thu, 10 Oct 2019 12:01:23 +0000 (+0200) Subject: Add start_tls to Trio backend (#467) X-Git-Tag: 0.7.5~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=38a136833f7c7f3a17f362b1223ef7cc7e38253d;p=thirdparty%2Fhttpx.git Add start_tls to Trio backend (#467) --- diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 2c36f186..3de3d140 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -171,6 +171,30 @@ class TrioBackend(ConcurrencyBackend): return TCPStream(stream=stream, timeout=timeout) + async def start_tls( + self, + stream: BaseTCPStream, + hostname: str, + ssl_context: ssl.SSLContext, + timeout: TimeoutConfig, + ) -> BaseTCPStream: + assert isinstance(stream, TCPStream) + + connect_timeout = _or_inf(timeout.connect_timeout) + ssl_stream = trio.SSLStream( + stream.stream, ssl_context=ssl_context, server_hostname=hostname + ) + + with trio.move_on_after(connect_timeout) as cancel_scope: + await ssl_stream.do_handshake() + + if cancel_scope.cancelled_caught: + raise ConnectTimeout() + + stream.stream = ssl_stream + + return stream + async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index ab93b302..3f9e8262 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,21 +1,39 @@ import sys import pytest +import trio from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig - - -@pytest.mark.xfail( - sys.version_info < (3, 7), - reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()", +from httpx.concurrency.trio import TrioBackend + + +@pytest.mark.parametrize( + "backend, get_cipher", + [ + pytest.param( + AsyncioBackend(), + lambda stream: stream.stream_writer.get_extra_info("cipher", default=None), + marks=pytest.mark.asyncio, + ), + pytest.param( + TrioBackend(), + lambda stream: ( + stream.stream.cipher() + if isinstance(stream.stream, trio.SSLStream) + else None + ), + marks=pytest.mark.trio, + ), + ], ) -@pytest.mark.asyncio -async def test_start_tls_on_socket_stream(https_server): +async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): """ - See that the backend can make a connection without TLS then + See that the concurrency backend can make a connection without TLS then start TLS on an existing connection. """ - backend = AsyncioBackend() + if isinstance(backend, AsyncioBackend) and sys.version_info < (3, 7): + pytest.xfail(reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()") + ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig()) timeout = TimeoutConfig(5) @@ -25,11 +43,11 @@ async def test_start_tls_on_socket_stream(https_server): try: assert stream.is_connection_dropped() is False - assert stream.stream_writer.get_extra_info("cipher", default=None) is None + assert get_cipher(stream) is None stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout) assert stream.is_connection_dropped() is False - assert stream.stream_writer.get_extra_info("cipher", default=None) is not None + assert get_cipher(stream) is not None await stream.write(b"GET / HTTP/1.1\r\n\r\n") assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")