MonkeyPatch.write = _fixed_write
+async def backport_start_tls(
+ transport: asyncio.BaseTransport,
+ protocol: asyncio.BaseProtocol,
+ sslcontext: ssl.SSLContext = None,
+ *,
+ server_side: bool = False,
+ server_hostname: str = None,
+ ssl_handshake_timeout: float = None,
+) -> asyncio.Transport: # pragma: nocover (Since it's not used on all Python versions.)
+ """
+ Python 3.6 asyncio doesn't have a start_tls() method on the loop
+ so we use this function in place of the loop's start_tls() method.
+
+ Adapted from this comment:
+
+ https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839
+ """
+ import asyncio.sslproto
+
+ loop = asyncio.get_event_loop()
+ waiter = loop.create_future()
+ ssl_protocol = asyncio.sslproto.SSLProtocol(
+ loop,
+ protocol,
+ sslcontext,
+ waiter,
+ server_side=False,
+ server_hostname=server_hostname,
+ call_connection_made=False,
+ )
+
+ transport.set_protocol(ssl_protocol)
+ loop.call_soon(ssl_protocol.connection_made, transport)
+ loop.call_soon(transport.resume_reading) # type: ignore
+
+ await waiter
+ return ssl_protocol._app_transport
+
+
class SocketStream(BaseSocketStream):
def __init__(
self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter,
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = self.stream_writer.transport
- if hasattr(loop, "start_tls"):
- loop_start_tls = loop.start_tls # type: ignore
- else:
-
- async def loop_start_tls(
- transport: asyncio.BaseTransport,
- protocol: asyncio.BaseProtocol,
- sslcontext: ssl.SSLContext = None,
- *,
- server_side: bool = False,
- server_hostname: str = None,
- ssl_handshake_timeout: float = None,
- ) -> asyncio.Transport:
- """Python 3.6 asyncio doesn't have a start_tls() method on the loop
- so we use this function in place of the loop's start_tls() method.
- Adapted from this comment:
- https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839
- """
- import asyncio.sslproto
-
- waiter = loop.create_future()
- ssl_protocol = asyncio.sslproto.SSLProtocol(
- loop,
- protocol,
- sslcontext,
- waiter,
- server_side=False,
- server_hostname=server_hostname,
- call_connection_made=False,
- )
-
- transport.set_protocol(ssl_protocol)
- loop.call_soon(ssl_protocol.connection_made, transport)
- loop.call_soon(transport.resume_reading) # type: ignore
-
- await waiter
- return ssl_protocol._app_transport
+ loop_start_tls = getattr(loop, "start_tls", backport_start_tls)
transport = await asyncio.wait_for(
loop_start_tls(