from .__version__ import __version__
CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
-VerifyTypes = typing.Union[str, bool]
+VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True):
self.cert = cert
- self.verify = verify
- self.ssl_context: typing.Optional[ssl.SSLContext] = None
+ # Allow passing in our own SSLContext object that's pre-configured.
+ # If you do this we assume that you want verify=True as well.
+ ssl_context = None
+ if isinstance(verify, ssl.SSLContext):
+ ssl_context = verify
+ verify = True
+ self._load_client_certs(ssl_context)
+
+ self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
+ self.verify: typing.Union[str, bool] = verify
def __eq__(self, other: typing.Any) -> bool:
return (
elif ca_bundle_path.is_dir():
context.load_verify_locations(capath=str(ca_bundle_path))
- if self.cert is not None:
- if isinstance(self.cert, str):
- context.load_cert_chain(certfile=self.cert)
- elif isinstance(self.cert, tuple) and len(self.cert) == 2:
- context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
- elif isinstance(self.cert, tuple) and len(self.cert) == 3:
- context.load_cert_chain(
- certfile=self.cert[0],
- keyfile=self.cert[1],
- password=self.cert[2], # type: ignore
- )
+ self._load_client_certs(context)
return context
return context
+ def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
+ """
+ Loads client certificates into our SSLContext object
+ """
+ if self.cert is not None:
+ if isinstance(self.cert, str):
+ ssl_context.load_cert_chain(certfile=self.cert)
+ elif isinstance(self.cert, tuple) and len(self.cert) == 2:
+ ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
+ elif isinstance(self.cert, tuple) and len(self.cert) == 3:
+ ssl_context.load_cert_chain(
+ certfile=self.cert[0],
+ keyfile=self.cert[1],
+ password=self.cert[2], # type: ignore
+ )
+
class TimeoutConfig:
"""
assert context.check_hostname is False
+def test_load_ssl_context():
+ ssl_context = ssl.create_default_context()
+ ssl_config = httpx.SSLConfig(verify=ssl_context)
+
+ assert ssl_config.verify is True
+ assert ssl_config.ssl_context is ssl_context
+ assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
+
+
def test_ssl_repr():
ssl = httpx.SSLConfig(verify=False)
assert repr(ssl) == "SSLConfig(cert=None, verify=False)"