http2: bool = False,
):
self.cert = cert
-
- # Allow passing in our own SSLContext object that's pre-configured.
- # If you do this we assume that you want verify=True as well.
- ssl_context = None
- if isinstance(verify, ssl.SSLContext):
- ssl_context = verify
- verify = True
- self._load_client_certs(ssl_context)
-
- self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
- self.verify: typing.Union[str, bool] = verify
+ self.verify = verify
self.trust_env = trust_env
self.http2 = http2
+ self.ssl_context = self.load_ssl_context()
def __eq__(self, other: typing.Any) -> bool:
return (
f"http2={self.http2!r}"
)
- if self.ssl_context is None:
- self.ssl_context = (
- self.load_ssl_context_verify()
- if self.verify
- else self.load_ssl_context_no_verify()
- )
-
- assert self.ssl_context is not None
- return self.ssl_context
+ if self.verify:
+ return self.load_ssl_context_verify()
+ return self.load_ssl_context_no_verify()
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
"""
if ca_bundle is not None:
self.verify = ca_bundle # type: ignore
- if isinstance(self.verify, bool):
+ if isinstance(self.verify, ssl.SSLContext):
+ # Allow passing in our own SSLContext object that's pre-configured.
+ context = self.verify
+ self._load_client_certs(context)
+ return context
+ elif isinstance(self.verify, bool):
ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
elif Path(self.verify).exists():
ca_bundle_path = Path(self.verify)
import functools
-import ssl
import typing
import h11
) -> typing.Union[HTTP11Connection, HTTP2Connection]:
host = self.origin.host
port = self.origin.port
- ssl_context = self.get_ssl_context()
+ ssl_context = None if not self.origin.is_ssl else self.ssl.ssl_context
if self.release_func is None:
on_release = None
if origin.is_ssl:
# Pull the socket stream off the internal HTTP connection object,
# and run start_tls().
- ssl_context = self.get_ssl_context()
- assert ssl_context is not None
+ ssl_context = self.ssl.ssl_context
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
socket = await socket.start_tls(
else:
self.connection = HTTP11Connection(socket, on_release=on_release)
- def get_ssl_context(self) -> typing.Optional[ssl.SSLContext]:
- if not self.origin.is_ssl:
- return None
- return self.ssl.load_ssl_context()
-
async def close(self) -> None:
logger.trace("close_connection")
if self.connection is not None:
def test_load_ssl_config():
ssl_config = SSLConfig()
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
def test_load_ssl_config_verify_non_existing_path():
- ssl_config = SSLConfig(verify="/path/to/nowhere")
with pytest.raises(IOError):
- ssl_config.load_ssl_context()
+ SSLConfig(verify="/path/to/nowhere")
def test_load_ssl_config_verify_existing_file():
ssl_config = SSLConfig(verify=certifi.where())
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
else str(Path(ca_cert_pem_file).parent)
)
ssl_config = SSLConfig(trust_env=True)
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
assert ssl_config.verify == os.environ[config]
def test_load_ssl_config_verify_directory():
path = Path(certifi.where()).parent
ssl_config = SSLConfig(verify=path)
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
ssl_config = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
ssl_config = SSLConfig(
cert=(cert_pem_file, cert_encrypted_private_key_file, password)
)
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
def test_load_ssl_config_cert_and_key_invalid_password(
cert_pem_file, cert_encrypted_private_key_file
):
- ssl_config = SSLConfig(
- cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
- )
-
with pytest.raises(ssl.SSLError):
- ssl_config.load_ssl_context()
+ SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password1"))
def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
- ssl_config = SSLConfig(cert=cert_pem_file)
with pytest.raises(ssl.SSLError):
- ssl_config.load_ssl_context()
+ SSLConfig(cert=cert_pem_file)
def test_load_ssl_config_no_verify():
ssl_config = SSLConfig(verify=False)
- context = ssl_config.load_ssl_context()
+ context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
assert context.check_hostname is False
ssl_context = ssl.create_default_context()
ssl_config = SSLConfig(verify=ssl_context)
- assert ssl_config.verify is True
assert ssl_config.ssl_context is ssl_context
- assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
def test_ssl_repr():
m.delenv("SSLKEYLOGFILE", raising=False)
ssl_config = SSLConfig(trust_env=True)
- ssl_config.load_ssl_context()
assert ssl_config.ssl_context.keylog_filename is None
m.setenv("SSLKEYLOGFILE", filename)
ssl_config = SSLConfig(trust_env=True)
- ssl_config.load_ssl_context()
assert ssl_config.ssl_context.keylog_filename == filename
ssl_config = SSLConfig(trust_env=False)
- ssl_config.load_ssl_context()
assert ssl_config.ssl_context.keylog_filename is None