From: Seth Michael Larson Date: Thu, 15 Aug 2019 03:30:02 +0000 (-0500) Subject: Accept SSLContext into SSLConfig(verify=...) (#215) X-Git-Tag: 0.7.0~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=df8874b7339c9dcb7963689b304ce0179fc0d862;p=thirdparty%2Fhttpx.git Accept SSLContext into SSLConfig(verify=...) (#215) --- diff --git a/httpx/config.py b/httpx/config.py index 3d2fe1c7..6d427b85 100644 --- a/httpx/config.py +++ b/httpx/config.py @@ -7,7 +7,7 @@ import certifi from .__version__ import __version__ CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]] -VerifyTypes = typing.Union[str, bool] +VerifyTypes = typing.Union[str, bool, ssl.SSLContext] TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] @@ -40,9 +40,17 @@ class SSLConfig: def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True): self.cert = cert - self.verify = verify - self.ssl_context: typing.Optional[ssl.SSLContext] = None + # Allow passing in our own SSLContext object that's pre-configured. + # If you do this we assume that you want verify=True as well. + ssl_context = None + if isinstance(verify, ssl.SSLContext): + ssl_context = verify + verify = True + self._load_client_certs(ssl_context) + + self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context + self.verify: typing.Union[str, bool] = verify def __eq__(self, other: typing.Any) -> bool: return ( @@ -121,17 +129,7 @@ class SSLConfig: elif ca_bundle_path.is_dir(): context.load_verify_locations(capath=str(ca_bundle_path)) - if self.cert is not None: - if isinstance(self.cert, str): - context.load_cert_chain(certfile=self.cert) - elif isinstance(self.cert, tuple) and len(self.cert) == 2: - context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) - elif isinstance(self.cert, tuple) and len(self.cert) == 3: - context.load_cert_chain( - certfile=self.cert[0], - keyfile=self.cert[1], - password=self.cert[2], # type: ignore - ) + self._load_client_certs(context) return context @@ -155,6 +153,22 @@ class SSLConfig: return context + def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None: + """ + Loads client certificates into our SSLContext object + """ + if self.cert is not None: + if isinstance(self.cert, str): + ssl_context.load_cert_chain(certfile=self.cert) + elif isinstance(self.cert, tuple) and len(self.cert) == 2: + ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) + elif isinstance(self.cert, tuple) and len(self.cert) == 3: + ssl_context.load_cert_chain( + certfile=self.cert[0], + keyfile=self.cert[1], + password=self.cert[2], # type: ignore + ) + class TimeoutConfig: """ diff --git a/tests/test_config.py b/tests/test_config.py index b587e931..20befa7c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -76,6 +76,15 @@ def test_load_ssl_config_no_verify(): assert context.check_hostname is False +def test_load_ssl_context(): + ssl_context = ssl.create_default_context() + ssl_config = httpx.SSLConfig(verify=ssl_context) + + assert ssl_config.verify is True + assert ssl_config.ssl_context is ssl_context + assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)" + + def test_ssl_repr(): ssl = httpx.SSLConfig(verify=False) assert repr(ssl) == "SSLConfig(cert=None, verify=False)"