]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Accept SSLContext into SSLConfig(verify=...) (#215)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Thu, 15 Aug 2019 03:30:02 +0000 (22:30 -0500)
committerGitHub <noreply@github.com>
Thu, 15 Aug 2019 03:30:02 +0000 (22:30 -0500)
httpx/config.py
tests/test_config.py

index 3d2fe1c7d6f8ea8e4a5fc9171c3a59751f51a792..6d427b85c38ac0fa03c76c636770ef8ae62d8a3d 100644 (file)
@@ -7,7 +7,7 @@ import certifi
 from .__version__ import __version__
 
 CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
-VerifyTypes = typing.Union[str, bool]
+VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
 TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
 
 
@@ -40,9 +40,17 @@ class SSLConfig:
 
     def __init__(self, *, cert: CertTypes = None, verify: VerifyTypes = True):
         self.cert = cert
-        self.verify = verify
 
-        self.ssl_context: typing.Optional[ssl.SSLContext] = None
+        # Allow passing in our own SSLContext object that's pre-configured.
+        # If you do this we assume that you want verify=True as well.
+        ssl_context = None
+        if isinstance(verify, ssl.SSLContext):
+            ssl_context = verify
+            verify = True
+            self._load_client_certs(ssl_context)
+
+        self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
+        self.verify: typing.Union[str, bool] = verify
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
@@ -121,17 +129,7 @@ class SSLConfig:
         elif ca_bundle_path.is_dir():
             context.load_verify_locations(capath=str(ca_bundle_path))
 
-        if self.cert is not None:
-            if isinstance(self.cert, str):
-                context.load_cert_chain(certfile=self.cert)
-            elif isinstance(self.cert, tuple) and len(self.cert) == 2:
-                context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
-            elif isinstance(self.cert, tuple) and len(self.cert) == 3:
-                context.load_cert_chain(
-                    certfile=self.cert[0],
-                    keyfile=self.cert[1],
-                    password=self.cert[2],  # type: ignore
-                )
+        self._load_client_certs(context)
 
         return context
 
@@ -155,6 +153,22 @@ class SSLConfig:
 
         return context
 
+    def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
+        """
+        Loads client certificates into our SSLContext object
+        """
+        if self.cert is not None:
+            if isinstance(self.cert, str):
+                ssl_context.load_cert_chain(certfile=self.cert)
+            elif isinstance(self.cert, tuple) and len(self.cert) == 2:
+                ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
+            elif isinstance(self.cert, tuple) and len(self.cert) == 3:
+                ssl_context.load_cert_chain(
+                    certfile=self.cert[0],
+                    keyfile=self.cert[1],
+                    password=self.cert[2],  # type: ignore
+                )
+
 
 class TimeoutConfig:
     """
index b587e931a0c1cfb923b5bcc4efe190a1b177f2cd..20befa7c7ba4e682cedbcfdb200b9ce231bcf6a1 100644 (file)
@@ -76,6 +76,15 @@ def test_load_ssl_config_no_verify():
     assert context.check_hostname is False
 
 
+def test_load_ssl_context():
+    ssl_context = ssl.create_default_context()
+    ssl_config = httpx.SSLConfig(verify=ssl_context)
+
+    assert ssl_config.verify is True
+    assert ssl_config.ssl_context is ssl_context
+    assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
+
+
 def test_ssl_repr():
     ssl = httpx.SSLConfig(verify=False)
     assert repr(ssl) == "SSLConfig(cert=None, verify=False)"