From: Tom Christie Date: Mon, 9 Dec 2019 18:30:00 +0000 (+0000) Subject: Tweak backport_start_tls implementation, and add 'nocover'. (#622) X-Git-Tag: 0.9.4~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=66e0f4b6f8df5c744681cb8306f416f48b0a0e35;p=thirdparty%2Fhttpx.git Tweak backport_start_tls implementation, and add 'nocover'. (#622) --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 39bccdc9..e06dc209 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -32,6 +32,45 @@ def ssl_monkey_patch() -> None: MonkeyPatch.write = _fixed_write +async def backport_start_tls( + transport: asyncio.BaseTransport, + protocol: asyncio.BaseProtocol, + sslcontext: ssl.SSLContext = None, + *, + server_side: bool = False, + server_hostname: str = None, + ssl_handshake_timeout: float = None, +) -> asyncio.Transport: # pragma: nocover (Since it's not used on all Python versions.) + """ + Python 3.6 asyncio doesn't have a start_tls() method on the loop + so we use this function in place of the loop's start_tls() method. + + Adapted from this comment: + + https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839 + """ + import asyncio.sslproto + + loop = asyncio.get_event_loop() + waiter = loop.create_future() + ssl_protocol = asyncio.sslproto.SSLProtocol( + loop, + protocol, + sslcontext, + waiter, + server_side=False, + server_hostname=server_hostname, + call_connection_made=False, + ) + + transport.set_protocol(ssl_protocol) + loop.call_soon(ssl_protocol.connection_made, transport) + loop.call_soon(transport.resume_reading) # type: ignore + + await waiter + return ssl_protocol._app_transport + + class SocketStream(BaseSocketStream): def __init__( self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter, @@ -51,43 +90,7 @@ class SocketStream(BaseSocketStream): protocol = asyncio.StreamReaderProtocol(stream_reader) transport = self.stream_writer.transport - if hasattr(loop, "start_tls"): - loop_start_tls = loop.start_tls # type: ignore - else: - - async def loop_start_tls( - transport: asyncio.BaseTransport, - protocol: asyncio.BaseProtocol, - sslcontext: ssl.SSLContext = None, - *, - server_side: bool = False, - server_hostname: str = None, - ssl_handshake_timeout: float = None, - ) -> asyncio.Transport: - """Python 3.6 asyncio doesn't have a start_tls() method on the loop - so we use this function in place of the loop's start_tls() method. - Adapted from this comment: - https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839 - """ - import asyncio.sslproto - - waiter = loop.create_future() - ssl_protocol = asyncio.sslproto.SSLProtocol( - loop, - protocol, - sslcontext, - waiter, - server_side=False, - server_hostname=server_hostname, - call_connection_made=False, - ) - - transport.set_protocol(ssl_protocol) - loop.call_soon(ssl_protocol.connection_made, transport) - loop.call_soon(transport.resume_reading) # type: ignore - - await waiter - return ssl_protocol._app_transport + loop_start_tls = getattr(loop, "start_tls", backport_start_tls) transport = await asyncio.wait_for( loop_start_tls(