import ssl
import typing
+import h11
+
from ..concurrency.base import ConcurrencyBackend, lookup_backend
from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
-from ..models import Origin, Request, Response
+from ..models import URL, Origin, Request, Response
from ..utils import get_logger
from .base import Dispatcher
from .http2 import HTTP2Connection
assert http_version == "HTTP/1.1"
self.h11_connection = HTTP11Connection(stream, on_release=on_release)
+ async def tunnel_start_tls(
+ self,
+ origin: Origin,
+ proxy_url: URL,
+ timeout: Timeout = None,
+ cert: CertTypes = None,
+ verify: VerifyTypes = True,
+ ) -> None:
+ """
+ Upgrade this connection to use TLS, assuming it represents a TCP tunnel.
+ """
+ timeout = Timeout() if timeout is None else timeout
+
+ # First, check that we are in the correct state to start TLS, i.e. we've
+ # just agreed to switch protocols with the server via HTTP/1.1.
+ h11_connection = self.h11_connection
+ assert h11_connection is not None
+ assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
+
+ # Store this information here so that we can transfer
+ # it to the new internal connection object after
+ # the old one goes to 'SWITCHED_PROTOCOL'.
+ # Note that the negotiated 'http_version' may change after the TLS upgrade.
+ http_version = "HTTP/1.1"
+ socket = h11_connection.socket
+ on_release = h11_connection.on_release
+
+ if origin.is_ssl:
+ # Pull the socket stream off the internal HTTP connection object,
+ # and run start_tls().
+ ssl_config = SSLConfig(cert=cert, verify=verify)
+ ssl_context = await self.get_ssl_context(ssl_config)
+ assert ssl_context is not None
+
+ logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
+ socket = await socket.start_tls(
+ hostname=origin.host, ssl_context=ssl_context, timeout=timeout
+ )
+ http_version = socket.get_http_version()
+ logger.trace(
+ f"tunnel_tls_complete "
+ f"proxy_url={proxy_url!r} "
+ f"origin={origin!r} "
+ f"http_version={http_version!r}"
+ )
+ else:
+ # User requested the use of a tunnel, but they're performing a plain-text
+ # HTTP request. Don't try to upgrade to TLS in this case.
+ pass
+
+ if http_version == "HTTP/2":
+ self.h2_connection = HTTP2Connection(
+ socket, self.backend, on_release=on_release
+ )
+ else:
+ assert http_version == "HTTP/1.1"
+ self.h11_connection = HTTP11Connection(socket, on_release=on_release)
+
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
return None
import typing
from base64 import b64encode
-import h11
-
from ..concurrency.base import ConcurrencyBackend
-from ..config import (
- DEFAULT_POOL_LIMITS,
- CertTypes,
- PoolLimits,
- SSLConfig,
- Timeout,
- VerifyTypes,
-)
+from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
from ..exceptions import ProxyError
from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
from ..utils import get_logger
from .connection import HTTPConnection
from .connection_pool import ConnectionPool
-from .http2 import HTTP2Connection
-from .http11 import HTTP11Connection
logger = get_logger(__name__)
connection.origin = origin
self.active_connections.add(connection)
- timeout = Timeout() if timeout is None else timeout
- await self.tunnel_start_tls(origin, connection, timeout)
+ await connection.tunnel_start_tls(
+ origin=origin,
+ proxy_url=self.proxy_url,
+ timeout=timeout,
+ cert=self.cert,
+ verify=self.verify,
+ )
else:
self.active_connections.add(connection)
return connection
- async def tunnel_start_tls(
- self, origin: Origin, connection: HTTPConnection, timeout: Timeout = None
- ) -> None:
- """Runs start_tls() on a TCP-tunneled connection"""
- timeout = Timeout() if timeout is None else timeout
-
- # Store this information here so that we can transfer
- # it to the new internal connection object after
- # the old one goes to 'SWITCHED_PROTOCOL'.
- http_version = "HTTP/1.1"
- http_connection = connection.h11_connection
- assert http_connection is not None
- assert http_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
- on_release = http_connection.on_release
- socket = http_connection.socket
-
- # If we need to start TLS again for the target server
- # we need to pull the socket stream off the internal
- # HTTP connection object and run start_tls()
- if origin.is_ssl:
- ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
- ssl_context = await connection.get_ssl_context(ssl_config)
- assert ssl_context is not None
-
- logger.trace(
- f"tunnel_start_tls "
- f"proxy_url={self.proxy_url!r} "
- f"origin={origin!r}"
- )
- socket = await socket.start_tls(
- hostname=origin.host, ssl_context=ssl_context, timeout=timeout
- )
- http_version = socket.get_http_version()
- logger.trace(
- f"tunnel_tls_complete "
- f"proxy_url={self.proxy_url!r} "
- f"origin={origin!r} "
- f"http_version={http_version!r}"
- )
-
- if http_version == "HTTP/2":
- connection.h2_connection = HTTP2Connection(
- socket, self.backend, on_release=on_release
- )
- else:
- assert http_version == "HTTP/1.1"
- connection.h11_connection = HTTP11Connection(socket, on_release=on_release)
-
def should_forward_origin(self, origin: Origin) -> bool:
"""Determines if the given origin should
be forwarded or tunneled. If 'proxy_mode' is 'DEFAULT'