self.stream_writer = stream_writer
self.timeout = timeout
+ self._inner: typing.Optional[TCPStream] = None
+
+ async def start_tls(
+ self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+ ) -> BaseTCPStream:
+ loop = asyncio.get_event_loop()
+ if not hasattr(loop, "start_tls"): # pragma: no cover
+ raise NotImplementedError(
+ "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+ )
+
+ stream_reader = asyncio.StreamReader()
+ protocol = asyncio.StreamReaderProtocol(stream_reader)
+ transport = self.stream_writer.transport
+
+ loop_start_tls = loop.start_tls # type: ignore
+ transport = await asyncio.wait_for(
+ loop_start_tls(
+ transport=transport,
+ protocol=protocol,
+ sslcontext=ssl_context,
+ server_hostname=hostname,
+ ),
+ timeout=timeout.connect_timeout,
+ )
+
+ stream_reader.set_transport(transport)
+ stream_writer = asyncio.StreamWriter(
+ transport=transport, protocol=protocol, reader=stream_reader, loop=loop
+ )
+
+ ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
+ # When we return a new TCPStream with new StreamReader/StreamWriter instances,
+ # we need to keep references to the old StreamReader/StreamWriter so that they
+ # are not garbage collected and closed while we're still using them.
+ ssl_stream._inner = self
+ return ssl_stream
+
def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
- async def start_tls(
- self,
- stream: BaseTCPStream,
- hostname: str,
- ssl_context: ssl.SSLContext,
- timeout: TimeoutConfig,
- ) -> BaseTCPStream:
-
- loop = self.loop
- if not hasattr(loop, "start_tls"): # pragma: no cover
- raise NotImplementedError(
- "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
- )
-
- assert isinstance(stream, TCPStream)
-
- stream_reader = asyncio.StreamReader()
- protocol = asyncio.StreamReaderProtocol(stream_reader)
- transport = stream.stream_writer.transport
-
- loop_start_tls = loop.start_tls # type: ignore
- transport = await asyncio.wait_for(
- loop_start_tls(
- transport=transport,
- protocol=protocol,
- sslcontext=ssl_context,
- server_hostname=hostname,
- ),
- timeout=timeout.connect_timeout,
- )
-
- stream_reader.set_transport(transport)
- stream.stream_reader = stream_reader
- stream.stream_writer = asyncio.StreamWriter(
- transport=transport, protocol=protocol, reader=stream_reader, loop=loop
- )
- return stream
-
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
def get_http_version(self) -> str:
raise NotImplementedError() # pragma: no cover
+ async def start_tls(
+ self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+ ) -> "BaseTCPStream":
+ raise NotImplementedError() # pragma: no cover
+
async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover
- async def start_tls(
- self,
- stream: BaseTCPStream,
- hostname: str,
- ssl_context: ssl.SSLContext,
- timeout: TimeoutConfig,
- ) -> BaseTCPStream:
- raise NotImplementedError() # pragma: no cover
-
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
self.write_buffer = b""
self.write_lock = trio.Lock()
+ async def start_tls(
+ self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+ ) -> BaseTCPStream:
+ # Check that the write buffer is empty. We should never start a TLS stream
+ # while there is still pending data to write.
+ assert self.write_buffer == b""
+
+ connect_timeout = _or_inf(timeout.connect_timeout)
+ ssl_stream = trio.SSLStream(
+ self.stream, ssl_context=ssl_context, server_hostname=hostname
+ )
+
+ with trio.move_on_after(connect_timeout) as cancel_scope:
+ await ssl_stream.do_handshake()
+
+ if cancel_scope.cancelled_caught:
+ raise ConnectTimeout()
+
+ return TCPStream(ssl_stream, self.timeout)
+
def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
return "HTTP/1.1"
return TCPStream(stream=stream, timeout=timeout)
- async def start_tls(
- self,
- stream: BaseTCPStream,
- hostname: str,
- ssl_context: ssl.SSLContext,
- timeout: TimeoutConfig,
- ) -> BaseTCPStream:
- assert isinstance(stream, TCPStream)
-
- connect_timeout = _or_inf(timeout.connect_timeout)
- ssl_stream = trio.SSLStream(
- stream.stream, ssl_context=ssl_context, server_hostname=hostname
- )
-
- with trio.move_on_after(connect_timeout) as cancel_scope:
- await ssl_stream.do_handshake()
-
- if cancel_scope.cancelled_caught:
- raise ConnectTimeout()
-
- stream.stream = ssl_stream
-
- return stream
-
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
f"proxy_url={self.proxy_url!r} "
f"origin={origin!r}"
)
- stream = await self.backend.start_tls(
- stream=stream,
- hostname=origin.host,
- ssl_context=ssl_context,
- timeout=timeout,
+ stream = await stream.start_tls(
+ hostname=origin.host, ssl_context=ssl_context, timeout=timeout
)
http_version = stream.get_http_version()
logger.debug(
)
return self.stream
- async def start_tls(
- self,
- stream: BaseTCPStream,
- hostname: str,
- ssl_context: ssl.SSLContext,
- timeout: TimeoutConfig,
- ) -> BaseTCPStream:
- self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
- return self.stream
-
# Defer all other attributes and methods to the underlying backend.
def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)
def __init__(self, backend: MockRawSocketBackend):
self.backend = backend
+ async def start_tls(
+ self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+ ) -> BaseTCPStream:
+ self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
+ return MockRawSocketStream(self.backend)
+
def get_http_version(self) -> str:
return "HTTP/1.1"
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is None
- stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
+ stream = await stream.start_tls(https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is not None