From: Florimond Manca Date: Fri, 6 Dec 2019 13:00:38 +0000 (+0100) Subject: Move tunnel_start_tls() to HTTPConnection (#609) X-Git-Tag: 0.9.0~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e1f5b8ba574e1fea4a9a79042f688f143c546f9d;p=thirdparty%2Fhttpx.git Move tunnel_start_tls() to HTTPConnection (#609) --- diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index ecb84a74..6d7790a4 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -2,9 +2,11 @@ import functools import ssl import typing +import h11 + from ..concurrency.base import ConcurrencyBackend, lookup_backend from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes -from ..models import Origin, Request, Response +from ..models import URL, Origin, Request, Response from ..utils import get_logger from .base import Dispatcher from .http2 import HTTP2Connection @@ -98,6 +100,64 @@ class HTTPConnection(Dispatcher): assert http_version == "HTTP/1.1" self.h11_connection = HTTP11Connection(stream, on_release=on_release) + async def tunnel_start_tls( + self, + origin: Origin, + proxy_url: URL, + timeout: Timeout = None, + cert: CertTypes = None, + verify: VerifyTypes = True, + ) -> None: + """ + Upgrade this connection to use TLS, assuming it represents a TCP tunnel. + """ + timeout = Timeout() if timeout is None else timeout + + # First, check that we are in the correct state to start TLS, i.e. we've + # just agreed to switch protocols with the server via HTTP/1.1. + h11_connection = self.h11_connection + assert h11_connection is not None + assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL + + # Store this information here so that we can transfer + # it to the new internal connection object after + # the old one goes to 'SWITCHED_PROTOCOL'. + # Note that the negotiated 'http_version' may change after the TLS upgrade. + http_version = "HTTP/1.1" + socket = h11_connection.socket + on_release = h11_connection.on_release + + if origin.is_ssl: + # Pull the socket stream off the internal HTTP connection object, + # and run start_tls(). + ssl_config = SSLConfig(cert=cert, verify=verify) + ssl_context = await self.get_ssl_context(ssl_config) + assert ssl_context is not None + + logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}") + socket = await socket.start_tls( + hostname=origin.host, ssl_context=ssl_context, timeout=timeout + ) + http_version = socket.get_http_version() + logger.trace( + f"tunnel_tls_complete " + f"proxy_url={proxy_url!r} " + f"origin={origin!r} " + f"http_version={http_version!r}" + ) + else: + # User requested the use of a tunnel, but they're performing a plain-text + # HTTP request. Don't try to upgrade to TLS in this case. + pass + + if http_version == "HTTP/2": + self.h2_connection = HTTP2Connection( + socket, self.backend, on_release=on_release + ) + else: + assert http_version == "HTTP/1.1" + self.h11_connection = HTTP11Connection(socket, on_release=on_release) + async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]: if not self.origin.is_ssl: return None diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py index aec042aa..a6e49736 100644 --- a/httpx/dispatch/proxy_http.py +++ b/httpx/dispatch/proxy_http.py @@ -2,24 +2,13 @@ import enum import typing from base64 import b64encode -import h11 - from ..concurrency.base import ConcurrencyBackend -from ..config import ( - DEFAULT_POOL_LIMITS, - CertTypes, - PoolLimits, - SSLConfig, - Timeout, - VerifyTypes, -) +from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes from ..exceptions import ProxyError from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes from ..utils import get_logger from .connection import HTTPConnection from .connection_pool import ConnectionPool -from .http2 import HTTP2Connection -from .http11 import HTTP11Connection logger = get_logger(__name__) @@ -110,8 +99,13 @@ class HTTPProxy(ConnectionPool): connection.origin = origin self.active_connections.add(connection) - timeout = Timeout() if timeout is None else timeout - await self.tunnel_start_tls(origin, connection, timeout) + await connection.tunnel_start_tls( + origin=origin, + proxy_url=self.proxy_url, + timeout=timeout, + cert=self.cert, + verify=self.verify, + ) else: self.active_connections.add(connection) @@ -161,54 +155,6 @@ class HTTPProxy(ConnectionPool): return connection - async def tunnel_start_tls( - self, origin: Origin, connection: HTTPConnection, timeout: Timeout = None - ) -> None: - """Runs start_tls() on a TCP-tunneled connection""" - timeout = Timeout() if timeout is None else timeout - - # Store this information here so that we can transfer - # it to the new internal connection object after - # the old one goes to 'SWITCHED_PROTOCOL'. - http_version = "HTTP/1.1" - http_connection = connection.h11_connection - assert http_connection is not None - assert http_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL - on_release = http_connection.on_release - socket = http_connection.socket - - # If we need to start TLS again for the target server - # we need to pull the socket stream off the internal - # HTTP connection object and run start_tls() - if origin.is_ssl: - ssl_config = SSLConfig(cert=self.cert, verify=self.verify) - ssl_context = await connection.get_ssl_context(ssl_config) - assert ssl_context is not None - - logger.trace( - f"tunnel_start_tls " - f"proxy_url={self.proxy_url!r} " - f"origin={origin!r}" - ) - socket = await socket.start_tls( - hostname=origin.host, ssl_context=ssl_context, timeout=timeout - ) - http_version = socket.get_http_version() - logger.trace( - f"tunnel_tls_complete " - f"proxy_url={self.proxy_url!r} " - f"origin={origin!r} " - f"http_version={http_version!r}" - ) - - if http_version == "HTTP/2": - connection.h2_connection = HTTP2Connection( - socket, self.backend, on_release=on_release - ) - else: - assert http_version == "HTTP/1.1" - connection.h11_connection = HTTP11Connection(socket, on_release=on_release) - def should_forward_origin(self, origin: Origin) -> bool: """Determines if the given origin should be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'