]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Move tunnel_start_tls() to HTTPConnection (#609)
authorFlorimond Manca <florimond.manca@gmail.com>
Fri, 6 Dec 2019 13:00:38 +0000 (14:00 +0100)
committerTom Christie <tom@tomchristie.com>
Fri, 6 Dec 2019 13:00:38 +0000 (13:00 +0000)
httpx/dispatch/connection.py
httpx/dispatch/proxy_http.py

index ecb84a748bd1616ce3e46a3d905163ee621295f0..6d7790a4f1813cab6761199f2c510ffa87bd13ef 100644 (file)
@@ -2,9 +2,11 @@ import functools
 import ssl
 import typing
 
+import h11
+
 from ..concurrency.base import ConcurrencyBackend, lookup_backend
 from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
-from ..models import Origin, Request, Response
+from ..models import URL, Origin, Request, Response
 from ..utils import get_logger
 from .base import Dispatcher
 from .http2 import HTTP2Connection
@@ -98,6 +100,64 @@ class HTTPConnection(Dispatcher):
             assert http_version == "HTTP/1.1"
             self.h11_connection = HTTP11Connection(stream, on_release=on_release)
 
+    async def tunnel_start_tls(
+        self,
+        origin: Origin,
+        proxy_url: URL,
+        timeout: Timeout = None,
+        cert: CertTypes = None,
+        verify: VerifyTypes = True,
+    ) -> None:
+        """
+        Upgrade this connection to use TLS, assuming it represents a TCP tunnel.
+        """
+        timeout = Timeout() if timeout is None else timeout
+
+        # First, check that we are in the correct state to start TLS, i.e. we've
+        # just agreed to switch protocols with the server via HTTP/1.1.
+        h11_connection = self.h11_connection
+        assert h11_connection is not None
+        assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
+
+        # Store this information here so that we can transfer
+        # it to the new internal connection object after
+        # the old one goes to 'SWITCHED_PROTOCOL'.
+        # Note that the negotiated 'http_version' may change after the TLS upgrade.
+        http_version = "HTTP/1.1"
+        socket = h11_connection.socket
+        on_release = h11_connection.on_release
+
+        if origin.is_ssl:
+            # Pull the socket stream off the internal HTTP connection object,
+            # and run start_tls().
+            ssl_config = SSLConfig(cert=cert, verify=verify)
+            ssl_context = await self.get_ssl_context(ssl_config)
+            assert ssl_context is not None
+
+            logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
+            socket = await socket.start_tls(
+                hostname=origin.host, ssl_context=ssl_context, timeout=timeout
+            )
+            http_version = socket.get_http_version()
+            logger.trace(
+                f"tunnel_tls_complete "
+                f"proxy_url={proxy_url!r} "
+                f"origin={origin!r} "
+                f"http_version={http_version!r}"
+            )
+        else:
+            # User requested the use of a tunnel, but they're performing a plain-text
+            # HTTP request. Don't try to upgrade to TLS in this case.
+            pass
+
+        if http_version == "HTTP/2":
+            self.h2_connection = HTTP2Connection(
+                socket, self.backend, on_release=on_release
+            )
+        else:
+            assert http_version == "HTTP/1.1"
+            self.h11_connection = HTTP11Connection(socket, on_release=on_release)
+
     async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
             return None
index aec042aa6e7b53d710114a2eae188422beef887a..a6e4973661d3502ec8f047c12d8402f8f1dc62a0 100644 (file)
@@ -2,24 +2,13 @@ import enum
 import typing
 from base64 import b64encode
 
-import h11
-
 from ..concurrency.base import ConcurrencyBackend
-from ..config import (
-    DEFAULT_POOL_LIMITS,
-    CertTypes,
-    PoolLimits,
-    SSLConfig,
-    Timeout,
-    VerifyTypes,
-)
+from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
 from ..exceptions import ProxyError
 from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
 from ..utils import get_logger
 from .connection import HTTPConnection
 from .connection_pool import ConnectionPool
-from .http2 import HTTP2Connection
-from .http11 import HTTP11Connection
 
 logger = get_logger(__name__)
 
@@ -110,8 +99,13 @@ class HTTPProxy(ConnectionPool):
             connection.origin = origin
             self.active_connections.add(connection)
 
-            timeout = Timeout() if timeout is None else timeout
-            await self.tunnel_start_tls(origin, connection, timeout)
+            await connection.tunnel_start_tls(
+                origin=origin,
+                proxy_url=self.proxy_url,
+                timeout=timeout,
+                cert=self.cert,
+                verify=self.verify,
+            )
         else:
             self.active_connections.add(connection)
 
@@ -161,54 +155,6 @@ class HTTPProxy(ConnectionPool):
 
         return connection
 
-    async def tunnel_start_tls(
-        self, origin: Origin, connection: HTTPConnection, timeout: Timeout = None
-    ) -> None:
-        """Runs start_tls() on a TCP-tunneled connection"""
-        timeout = Timeout() if timeout is None else timeout
-
-        # Store this information here so that we can transfer
-        # it to the new internal connection object after
-        # the old one goes to 'SWITCHED_PROTOCOL'.
-        http_version = "HTTP/1.1"
-        http_connection = connection.h11_connection
-        assert http_connection is not None
-        assert http_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
-        on_release = http_connection.on_release
-        socket = http_connection.socket
-
-        # If we need to start TLS again for the target server
-        # we need to pull the socket stream off the internal
-        # HTTP connection object and run start_tls()
-        if origin.is_ssl:
-            ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
-            ssl_context = await connection.get_ssl_context(ssl_config)
-            assert ssl_context is not None
-
-            logger.trace(
-                f"tunnel_start_tls "
-                f"proxy_url={self.proxy_url!r} "
-                f"origin={origin!r}"
-            )
-            socket = await socket.start_tls(
-                hostname=origin.host, ssl_context=ssl_context, timeout=timeout
-            )
-            http_version = socket.get_http_version()
-            logger.trace(
-                f"tunnel_tls_complete "
-                f"proxy_url={self.proxy_url!r} "
-                f"origin={origin!r} "
-                f"http_version={http_version!r}"
-            )
-
-        if http_version == "HTTP/2":
-            connection.h2_connection = HTTP2Connection(
-                socket, self.backend, on_release=on_release
-            )
-        else:
-            assert http_version == "HTTP/1.1"
-            connection.h11_connection = HTTP11Connection(socket, on_release=on_release)
-
     def should_forward_origin(self, origin: Origin) -> bool:
         """Determines if the given origin should
         be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'