From: Jamie Hewland Date: Sun, 20 Oct 2019 10:59:16 +0000 (+0200) Subject: Make start_tls a method on streams & return a new stream (#484) X-Git-Tag: 0.7.6~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=644e8fc5b6c5678fc1c1916293c1afab56d60bad;p=thirdparty%2Fhttpx.git Make start_tls a method on streams & return a new stream (#484) * Move start_tls to stream & return a new stream * asyncio: Keep a reference to the inner stream when upgrading to TLS --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index dfee4254..70833454 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -51,6 +51,44 @@ class TCPStream(BaseTCPStream): self.stream_writer = stream_writer self.timeout = timeout + self._inner: typing.Optional[TCPStream] = None + + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> BaseTCPStream: + loop = asyncio.get_event_loop() + if not hasattr(loop, "start_tls"): # pragma: no cover + raise NotImplementedError( + "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" + ) + + stream_reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(stream_reader) + transport = self.stream_writer.transport + + loop_start_tls = loop.start_tls # type: ignore + transport = await asyncio.wait_for( + loop_start_tls( + transport=transport, + protocol=protocol, + sslcontext=ssl_context, + server_hostname=hostname, + ), + timeout=timeout.connect_timeout, + ) + + stream_reader.set_transport(transport) + stream_writer = asyncio.StreamWriter( + transport=transport, protocol=protocol, reader=stream_reader, loop=loop + ) + + ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout) + # When we return a new TCPStream with new StreamReader/StreamWriter instances, + # we need to keep references to the old StreamReader/StreamWriter so that they + # are not garbage collected and closed while we're still using them. + ssl_stream._inner = self + return ssl_stream + def get_http_version(self) -> str: ssl_object = self.stream_writer.get_extra_info("ssl_object") @@ -201,44 +239,6 @@ class AsyncioBackend(ConcurrencyBackend): stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout ) - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - - loop = self.loop - if not hasattr(loop, "start_tls"): # pragma: no cover - raise NotImplementedError( - "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" - ) - - assert isinstance(stream, TCPStream) - - stream_reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(stream_reader) - transport = stream.stream_writer.transport - - loop_start_tls = loop.start_tls # type: ignore - transport = await asyncio.wait_for( - loop_start_tls( - transport=transport, - protocol=protocol, - sslcontext=ssl_context, - server_hostname=hostname, - ), - timeout=timeout.connect_timeout, - ) - - stream_reader.set_transport(transport) - stream.stream_reader = stream_reader - stream.stream_writer = asyncio.StreamWriter( - transport=transport, protocol=protocol, reader=stream_reader, loop=loop - ) - return stream - async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index fc784b30..a23d89bd 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -47,6 +47,11 @@ class BaseTCPStream: def get_http_version(self) -> str: raise NotImplementedError() # pragma: no cover + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> "BaseTCPStream": + raise NotImplementedError() # pragma: no cover + async def read( self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None ) -> bytes: @@ -119,15 +124,6 @@ class ConcurrencyBackend: ) -> BaseTCPStream: raise NotImplementedError() # pragma: no cover - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - raise NotImplementedError() # pragma: no cover - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 3de3d140..da8e38a0 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -34,6 +34,26 @@ class TCPStream(BaseTCPStream): self.write_buffer = b"" self.write_lock = trio.Lock() + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> BaseTCPStream: + # Check that the write buffer is empty. We should never start a TLS stream + # while there is still pending data to write. + assert self.write_buffer == b"" + + connect_timeout = _or_inf(timeout.connect_timeout) + ssl_stream = trio.SSLStream( + self.stream, ssl_context=ssl_context, server_hostname=hostname + ) + + with trio.move_on_after(connect_timeout) as cancel_scope: + await ssl_stream.do_handshake() + + if cancel_scope.cancelled_caught: + raise ConnectTimeout() + + return TCPStream(ssl_stream, self.timeout) + def get_http_version(self) -> str: if not isinstance(self.stream, trio.SSLStream): return "HTTP/1.1" @@ -171,30 +191,6 @@ class TrioBackend(ConcurrencyBackend): return TCPStream(stream=stream, timeout=timeout) - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - assert isinstance(stream, TCPStream) - - connect_timeout = _or_inf(timeout.connect_timeout) - ssl_stream = trio.SSLStream( - stream.stream, ssl_context=ssl_context, server_hostname=hostname - ) - - with trio.move_on_after(connect_timeout) as cancel_scope: - await ssl_stream.do_handshake() - - if cancel_scope.cancelled_caught: - raise ConnectTimeout() - - stream.stream = ssl_stream - - return stream - async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py index be2e289f..e54a5b30 100644 --- a/httpx/dispatch/proxy_http.py +++ b/httpx/dispatch/proxy_http.py @@ -192,11 +192,8 @@ class HTTPProxy(ConnectionPool): f"proxy_url={self.proxy_url!r} " f"origin={origin!r}" ) - stream = await self.backend.start_tls( - stream=stream, - hostname=origin.host, - ssl_context=ssl_context, - timeout=timeout, + stream = await stream.start_tls( + hostname=origin.host, ssl_context=ssl_context, timeout=timeout ) http_version = stream.get_http_version() logger.debug( diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 0798b31f..3b0d5340 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -184,16 +184,6 @@ class MockRawSocketBackend: ) return self.stream - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode()) - return self.stream - # Defer all other attributes and methods to the underlying backend. def __getattr__(self, name: str) -> typing.Any: return getattr(self.backend, name) @@ -203,6 +193,12 @@ class MockRawSocketStream(BaseTCPStream): def __init__(self, backend: MockRawSocketBackend): self.backend = backend + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> BaseTCPStream: + self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode()) + return MockRawSocketStream(self.backend) + def get_http_version(self) -> str: return "HTTP/1.1" diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 3f9e8262..27bbeaf2 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -45,7 +45,7 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): assert stream.is_connection_dropped() is False assert get_cipher(stream) is None - stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout) + stream = await stream.start_tls(https_server.url.host, ctx, timeout) assert stream.is_connection_dropped() is False assert get_cipher(stream) is not None