return TCPStream(stream=stream, timeout=timeout)
+ async def start_tls(
+ self,
+ stream: BaseTCPStream,
+ hostname: str,
+ ssl_context: ssl.SSLContext,
+ timeout: TimeoutConfig,
+ ) -> BaseTCPStream:
+ assert isinstance(stream, TCPStream)
+
+ connect_timeout = _or_inf(timeout.connect_timeout)
+ ssl_stream = trio.SSLStream(
+ stream.stream, ssl_context=ssl_context, server_hostname=hostname
+ )
+
+ with trio.move_on_after(connect_timeout) as cancel_scope:
+ await ssl_stream.do_handshake()
+
+ if cancel_scope.cancelled_caught:
+ raise ConnectTimeout()
+
+ stream.stream = ssl_stream
+
+ return stream
+
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
import sys
import pytest
+import trio
from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
-
-
-@pytest.mark.xfail(
- sys.version_info < (3, 7),
- reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
+from httpx.concurrency.trio import TrioBackend
+
+
+@pytest.mark.parametrize(
+ "backend, get_cipher",
+ [
+ pytest.param(
+ AsyncioBackend(),
+ lambda stream: stream.stream_writer.get_extra_info("cipher", default=None),
+ marks=pytest.mark.asyncio,
+ ),
+ pytest.param(
+ TrioBackend(),
+ lambda stream: (
+ stream.stream.cipher()
+ if isinstance(stream.stream, trio.SSLStream)
+ else None
+ ),
+ marks=pytest.mark.trio,
+ ),
+ ],
)
-@pytest.mark.asyncio
-async def test_start_tls_on_socket_stream(https_server):
+async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
"""
- See that the backend can make a connection without TLS then
+ See that the concurrency backend can make a connection without TLS then
start TLS on an existing connection.
"""
- backend = AsyncioBackend()
+ if isinstance(backend, AsyncioBackend) and sys.version_info < (3, 7):
+ pytest.xfail(reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()")
+
ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
timeout = TimeoutConfig(5)
try:
assert stream.is_connection_dropped() is False
- assert stream.stream_writer.get_extra_info("cipher", default=None) is None
+ assert get_cipher(stream) is None
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
- assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+ assert get_cipher(stream) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")