]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Make start_tls a method on streams & return a new stream (#484)
authorJamie Hewland <jhewland@gmail.com>
Sun, 20 Oct 2019 10:59:16 +0000 (12:59 +0200)
committerFlorimond Manca <florimond.manca@gmail.com>
Sun, 20 Oct 2019 10:59:16 +0000 (12:59 +0200)
* Move start_tls to stream & return a new stream

* asyncio: Keep a reference to the inner stream when upgrading to TLS

httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/proxy_http.py
tests/dispatch/utils.py
tests/test_concurrency.py

index dfee42545688c7517f4b48a8a29d9af0b10662ce..7083345426c1805406fe44e6dbeb54ea26c74d81 100644 (file)
@@ -51,6 +51,44 @@ class TCPStream(BaseTCPStream):
         self.stream_writer = stream_writer
         self.timeout = timeout
 
+        self._inner: typing.Optional[TCPStream] = None
+
+    async def start_tls(
+        self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+    ) -> BaseTCPStream:
+        loop = asyncio.get_event_loop()
+        if not hasattr(loop, "start_tls"):  # pragma: no cover
+            raise NotImplementedError(
+                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+            )
+
+        stream_reader = asyncio.StreamReader()
+        protocol = asyncio.StreamReaderProtocol(stream_reader)
+        transport = self.stream_writer.transport
+
+        loop_start_tls = loop.start_tls  # type: ignore
+        transport = await asyncio.wait_for(
+            loop_start_tls(
+                transport=transport,
+                protocol=protocol,
+                sslcontext=ssl_context,
+                server_hostname=hostname,
+            ),
+            timeout=timeout.connect_timeout,
+        )
+
+        stream_reader.set_transport(transport)
+        stream_writer = asyncio.StreamWriter(
+            transport=transport, protocol=protocol, reader=stream_reader, loop=loop
+        )
+
+        ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
+        # When we return a new TCPStream with new StreamReader/StreamWriter instances,
+        # we need to keep references to the old StreamReader/StreamWriter so that they
+        # are not garbage collected and closed while we're still using them.
+        ssl_stream._inner = self
+        return ssl_stream
+
     def get_http_version(self) -> str:
         ssl_object = self.stream_writer.get_extra_info("ssl_object")
 
@@ -201,44 +239,6 @@ class AsyncioBackend(ConcurrencyBackend):
             stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
         )
 
-    async def start_tls(
-        self,
-        stream: BaseTCPStream,
-        hostname: str,
-        ssl_context: ssl.SSLContext,
-        timeout: TimeoutConfig,
-    ) -> BaseTCPStream:
-
-        loop = self.loop
-        if not hasattr(loop, "start_tls"):  # pragma: no cover
-            raise NotImplementedError(
-                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
-            )
-
-        assert isinstance(stream, TCPStream)
-
-        stream_reader = asyncio.StreamReader()
-        protocol = asyncio.StreamReaderProtocol(stream_reader)
-        transport = stream.stream_writer.transport
-
-        loop_start_tls = loop.start_tls  # type: ignore
-        transport = await asyncio.wait_for(
-            loop_start_tls(
-                transport=transport,
-                protocol=protocol,
-                sslcontext=ssl_context,
-                server_hostname=hostname,
-            ),
-            timeout=timeout.connect_timeout,
-        )
-
-        stream_reader.set_transport(transport)
-        stream.stream_reader = stream_reader
-        stream.stream_writer = asyncio.StreamWriter(
-            transport=transport, protocol=protocol, reader=stream_reader, loop=loop
-        )
-        return stream
-
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index fc784b30f8f823bebdbcc4d1b7030d88ae59f891..a23d89bd30a90113777afcf88e3dc223b7afb703 100644 (file)
@@ -47,6 +47,11 @@ class BaseTCPStream:
     def get_http_version(self) -> str:
         raise NotImplementedError()  # pragma: no cover
 
+    async def start_tls(
+        self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+    ) -> "BaseTCPStream":
+        raise NotImplementedError()  # pragma: no cover
+
     async def read(
         self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
     ) -> bytes:
@@ -119,15 +124,6 @@ class ConcurrencyBackend:
     ) -> BaseTCPStream:
         raise NotImplementedError()  # pragma: no cover
 
-    async def start_tls(
-        self,
-        stream: BaseTCPStream,
-        hostname: str,
-        ssl_context: ssl.SSLContext,
-        timeout: TimeoutConfig,
-    ) -> BaseTCPStream:
-        raise NotImplementedError()  # pragma: no cover
-
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
index 3de3d1407812f2df1cd1a1abeb5bb3d66e75c0e8..da8e38a0ef768509c5f8cc4f72383a10051d4f01 100644 (file)
@@ -34,6 +34,26 @@ class TCPStream(BaseTCPStream):
         self.write_buffer = b""
         self.write_lock = trio.Lock()
 
+    async def start_tls(
+        self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+    ) -> BaseTCPStream:
+        # Check that the write buffer is empty. We should never start a TLS stream
+        # while there is still pending data to write.
+        assert self.write_buffer == b""
+
+        connect_timeout = _or_inf(timeout.connect_timeout)
+        ssl_stream = trio.SSLStream(
+            self.stream, ssl_context=ssl_context, server_hostname=hostname
+        )
+
+        with trio.move_on_after(connect_timeout) as cancel_scope:
+            await ssl_stream.do_handshake()
+
+        if cancel_scope.cancelled_caught:
+            raise ConnectTimeout()
+
+        return TCPStream(ssl_stream, self.timeout)
+
     def get_http_version(self) -> str:
         if not isinstance(self.stream, trio.SSLStream):
             return "HTTP/1.1"
@@ -171,30 +191,6 @@ class TrioBackend(ConcurrencyBackend):
 
         return TCPStream(stream=stream, timeout=timeout)
 
-    async def start_tls(
-        self,
-        stream: BaseTCPStream,
-        hostname: str,
-        ssl_context: ssl.SSLContext,
-        timeout: TimeoutConfig,
-    ) -> BaseTCPStream:
-        assert isinstance(stream, TCPStream)
-
-        connect_timeout = _or_inf(timeout.connect_timeout)
-        ssl_stream = trio.SSLStream(
-            stream.stream, ssl_context=ssl_context, server_hostname=hostname
-        )
-
-        with trio.move_on_after(connect_timeout) as cancel_scope:
-            await ssl_stream.do_handshake()
-
-        if cancel_scope.cancelled_caught:
-            raise ConnectTimeout()
-
-        stream.stream = ssl_stream
-
-        return stream
-
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index be2e289ffcc94796c202ad6d93763654ff8e9dd4..e54a5b30731415bd9adf2f84bc5deb59ad2e6ffd 100644 (file)
@@ -192,11 +192,8 @@ class HTTPProxy(ConnectionPool):
                 f"proxy_url={self.proxy_url!r} "
                 f"origin={origin!r}"
             )
-            stream = await self.backend.start_tls(
-                stream=stream,
-                hostname=origin.host,
-                ssl_context=ssl_context,
-                timeout=timeout,
+            stream = await stream.start_tls(
+                hostname=origin.host, ssl_context=ssl_context, timeout=timeout
             )
             http_version = stream.get_http_version()
             logger.debug(
index 0798b31ff2b0d957bd3debd582df274ce7687e43..3b0d5340000c99e9b6a4240707af3f893846c5aa 100644 (file)
@@ -184,16 +184,6 @@ class MockRawSocketBackend:
         )
         return self.stream
 
-    async def start_tls(
-        self,
-        stream: BaseTCPStream,
-        hostname: str,
-        ssl_context: ssl.SSLContext,
-        timeout: TimeoutConfig,
-    ) -> BaseTCPStream:
-        self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
-        return self.stream
-
     # Defer all other attributes and methods to the underlying backend.
     def __getattr__(self, name: str) -> typing.Any:
         return getattr(self.backend, name)
@@ -203,6 +193,12 @@ class MockRawSocketStream(BaseTCPStream):
     def __init__(self, backend: MockRawSocketBackend):
         self.backend = backend
 
+    async def start_tls(
+        self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
+    ) -> BaseTCPStream:
+        self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
+        return MockRawSocketStream(self.backend)
+
     def get_http_version(self) -> str:
         return "HTTP/1.1"
 
index 3f9e8262351b1a01d209043a45b9763a9dc7d3a9..27bbeaf28048286f9466f1498f56402807ec495f 100644 (file)
@@ -45,7 +45,7 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
         assert stream.is_connection_dropped() is False
         assert get_cipher(stream) is None
 
-        stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
+        stream = await stream.start_tls(https_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
         assert get_cipher(stream) is not None