]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Load SSL Context on init (#709)
authorTom Christie <tom@tomchristie.com>
Thu, 2 Jan 2020 16:52:23 +0000 (16:52 +0000)
committerGitHub <noreply@github.com>
Thu, 2 Jan 2020 16:52:23 +0000 (16:52 +0000)
httpx/config.py
httpx/dispatch/connection.py
tests/test_config.py

index 38245a3e7fb9d86978077fe63ce21a23208c53dd..5d1825b6b004698bddb88d8c30ad1df7e654a8c7 100644 (file)
@@ -63,19 +63,10 @@ class SSLConfig:
         http2: bool = False,
     ):
         self.cert = cert
-
-        # Allow passing in our own SSLContext object that's pre-configured.
-        # If you do this we assume that you want verify=True as well.
-        ssl_context = None
-        if isinstance(verify, ssl.SSLContext):
-            ssl_context = verify
-            verify = True
-            self._load_client_certs(ssl_context)
-
-        self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
-        self.verify: typing.Union[str, bool] = verify
+        self.verify = verify
         self.trust_env = trust_env
         self.http2 = http2
+        self.ssl_context = self.load_ssl_context()
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
@@ -97,15 +88,9 @@ class SSLConfig:
             f"http2={self.http2!r}"
         )
 
-        if self.ssl_context is None:
-            self.ssl_context = (
-                self.load_ssl_context_verify()
-                if self.verify
-                else self.load_ssl_context_no_verify()
-            )
-
-        assert self.ssl_context is not None
-        return self.ssl_context
+        if self.verify:
+            return self.load_ssl_context_verify()
+        return self.load_ssl_context_no_verify()
 
     def load_ssl_context_no_verify(self) -> ssl.SSLContext:
         """
@@ -125,7 +110,12 @@ class SSLConfig:
             if ca_bundle is not None:
                 self.verify = ca_bundle  # type: ignore
 
-        if isinstance(self.verify, bool):
+        if isinstance(self.verify, ssl.SSLContext):
+            # Allow passing in our own SSLContext object that's pre-configured.
+            context = self.verify
+            self._load_client_certs(context)
+            return context
+        elif isinstance(self.verify, bool):
             ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
         elif Path(self.verify).exists():
             ca_bundle_path = Path(self.verify)
index 358bdcf7a7af3ce68fd36580796e6ac8665a5d28..d9e104b0c88018bda745a36709cbd56abfcac3bf 100644 (file)
@@ -1,5 +1,4 @@
 import functools
-import ssl
 import typing
 
 import h11
@@ -49,7 +48,7 @@ class HTTPConnection(Dispatcher):
     ) -> typing.Union[HTTP11Connection, HTTP2Connection]:
         host = self.origin.host
         port = self.origin.port
-        ssl_context = self.get_ssl_context()
+        ssl_context = None if not self.origin.is_ssl else self.ssl.ssl_context
 
         if self.release_func is None:
             on_release = None
@@ -104,8 +103,7 @@ class HTTPConnection(Dispatcher):
         if origin.is_ssl:
             # Pull the socket stream off the internal HTTP connection object,
             # and run start_tls().
-            ssl_context = self.get_ssl_context()
-            assert ssl_context is not None
+            ssl_context = self.ssl.ssl_context
 
             logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
             socket = await socket.start_tls(
@@ -130,11 +128,6 @@ class HTTPConnection(Dispatcher):
         else:
             self.connection = HTTP11Connection(socket, on_release=on_release)
 
-    def get_ssl_context(self) -> typing.Optional[ssl.SSLContext]:
-        if not self.origin.is_ssl:
-            return None
-        return self.ssl.load_ssl_context()
-
     async def close(self) -> None:
         logger.trace("close_connection")
         if self.connection is not None:
index c2701c085ddabc9801b6487f399e9b166bb18746..1aee9ed4d8cb261cb760aa4aa28cc008670eac02 100644 (file)
@@ -13,20 +13,19 @@ from httpx.config import SSLConfig
 
 def test_load_ssl_config():
     ssl_config = SSLConfig()
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
 
 def test_load_ssl_config_verify_non_existing_path():
-    ssl_config = SSLConfig(verify="/path/to/nowhere")
     with pytest.raises(IOError):
-        ssl_config.load_ssl_context()
+        SSLConfig(verify="/path/to/nowhere")
 
 
 def test_load_ssl_config_verify_existing_file():
     ssl_config = SSLConfig(verify=certifi.where())
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
@@ -39,7 +38,7 @@ def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config)
         else str(Path(ca_cert_pem_file).parent)
     )
     ssl_config = SSLConfig(trust_env=True)
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
     assert ssl_config.verify == os.environ[config]
@@ -58,14 +57,14 @@ def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config)
 def test_load_ssl_config_verify_directory():
     path = Path(certifi.where()).parent
     ssl_config = SSLConfig(verify=path)
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
 
 def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
     ssl_config = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
@@ -77,7 +76,7 @@ def test_load_ssl_config_cert_and_encrypted_key(
     ssl_config = SSLConfig(
         cert=(cert_pem_file, cert_encrypted_private_key_file, password)
     )
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
@@ -85,23 +84,18 @@ def test_load_ssl_config_cert_and_encrypted_key(
 def test_load_ssl_config_cert_and_key_invalid_password(
     cert_pem_file, cert_encrypted_private_key_file
 ):
-    ssl_config = SSLConfig(
-        cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
-    )
-
     with pytest.raises(ssl.SSLError):
-        ssl_config.load_ssl_context()
+        SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password1"))
 
 
 def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
-    ssl_config = SSLConfig(cert=cert_pem_file)
     with pytest.raises(ssl.SSLError):
-        ssl_config.load_ssl_context()
+        SSLConfig(cert=cert_pem_file)
 
 
 def test_load_ssl_config_no_verify():
     ssl_config = SSLConfig(verify=False)
-    context = ssl_config.load_ssl_context()
+    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_NONE
     assert context.check_hostname is False
 
@@ -110,9 +104,7 @@ def test_load_ssl_context():
     ssl_context = ssl.create_default_context()
     ssl_config = SSLConfig(verify=ssl_context)
 
-    assert ssl_config.verify is True
     assert ssl_config.ssl_context is ssl_context
-    assert repr(ssl_config) == "SSLConfig(cert=None, verify=True)"
 
 
 def test_ssl_repr():
@@ -199,7 +191,6 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch):  # pragma: noc
         m.delenv("SSLKEYLOGFILE", raising=False)
 
         ssl_config = SSLConfig(trust_env=True)
-        ssl_config.load_ssl_context()
 
         assert ssl_config.ssl_context.keylog_filename is None
 
@@ -209,11 +200,9 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch):  # pragma: noc
         m.setenv("SSLKEYLOGFILE", filename)
 
         ssl_config = SSLConfig(trust_env=True)
-        ssl_config.load_ssl_context()
 
         assert ssl_config.ssl_context.keylog_filename == filename
 
         ssl_config = SSLConfig(trust_env=False)
-        ssl_config.load_ssl_context()
 
         assert ssl_config.ssl_context.keylog_filename is None