]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Support client cert passwords, new TLS options (#118)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Tue, 16 Jul 2019 09:08:57 +0000 (04:08 -0500)
committerTom Christie <tom@tomchristie.com>
Tue, 16 Jul 2019 09:08:57 +0000 (10:08 +0100)
* Support client cert passwords, new TLS options

* Update test_config.py

* Switch to try-except for post_handshake_auth=True

SSLContext.post_handshake_auth raises AttributeError if the property is available but cannot be written to (needs OpenSSL 1.1.1+)

* Also try-except for hostname_checks_common_name=False

* Custom implementation of trustme.CA() that emits encrypted PKs

* lint

* Split name of test

* Updates from review comments

* Don't load default CAs yet

http3/config.py
requirements.txt
tests/conftest.py
tests/test_config.py

index 4b0fe9e28bc3c12d0c61ec6040fba87f549bc110..8ef3570861ad8158e8c76f02d42851e36d6d6181 100644 (file)
@@ -7,7 +7,7 @@ import certifi
 
 from .__version__ import __version__
 
-CertTypes = typing.Union[str, typing.Tuple[str, str]]
+CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
 VerifyTypes = typing.Union[str, bool]
 TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
 
@@ -43,6 +43,8 @@ class SSLConfig:
         self.cert = cert
         self.verify = verify
 
+        self.ssl_context: typing.Optional[ssl.SSLContext] = None
+
     def __eq__(self, other: typing.Any) -> bool:
         return (
             isinstance(other, self.__class__)
@@ -64,7 +66,7 @@ class SSLConfig:
         return SSLConfig(cert=cert, verify=verify)
 
     async def load_ssl_context(self) -> ssl.SSLContext:
-        if not hasattr(self, "ssl_context"):
+        if self.ssl_context is None:
             if not self.verify:
                 self.ssl_context = self.load_ssl_context_no_verify()
             else:
@@ -80,11 +82,9 @@ class SSLConfig:
         """
         Return an SSL context for unverified connections.
         """
-        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
-        context.options |= ssl.OP_NO_SSLv2
-        context.options |= ssl.OP_NO_SSLv3
-        context.options |= ssl.OP_NO_COMPRESSION
-        context.set_default_verify_paths()
+        context = self._create_default_ssl_context()
+        context.verify_mode = ssl.CERT_NONE
+        context.check_hostname = False
         return context
 
     def load_ssl_context_verify(self) -> ssl.SSLContext:
@@ -101,20 +101,23 @@ class SSLConfig:
                 "invalid path: {}".format(self.verify)
             )
 
-        context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
-
+        context = self._create_default_ssl_context()
         context.verify_mode = ssl.CERT_REQUIRED
-
-        context.options |= ssl.OP_NO_SSLv2
-        context.options |= ssl.OP_NO_SSLv3
-        context.options |= ssl.OP_NO_COMPRESSION
-
-        context.set_ciphers(DEFAULT_CIPHERS)
-
-        if ssl.HAS_ALPN:
-            context.set_alpn_protocols(["h2", "http/1.1"])
-        if ssl.HAS_NPN:
-            context.set_npn_protocols(["h2", "http/1.1"])
+        context.check_hostname = True
+
+        # Signal to server support for PHA in TLS 1.3. Raises an
+        # AttributeError if only read-only access is implemented.
+        try:
+            context.post_handshake_auth = True
+        except AttributeError:
+            pass
+
+        # Disable using 'commonName' for SSLContext.check_hostname
+        # when the 'subjectAltName' extension isn't available.
+        try:
+            context.hostname_checks_common_name = False
+        except AttributeError:
+            pass
 
         if os.path.isfile(ca_bundle_path):
             context.load_verify_locations(cafile=ca_bundle_path)
@@ -124,8 +127,30 @@ class SSLConfig:
         if self.cert is not None:
             if isinstance(self.cert, str):
                 context.load_cert_chain(certfile=self.cert)
-            else:
+            elif isinstance(self.cert, tuple) and len(self.cert) == 2:
                 context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
+            else:
+                context.load_cert_chain(
+                    certfile=self.cert[0], keyfile=self.cert[1], password=self.cert[2]
+                )
+
+        return context
+
+    def _create_default_ssl_context(self) -> ssl.SSLContext:
+        """
+        Creates the default SSLContext object that's used for both verified
+        and unverified connections.
+        """
+        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+        context.options |= ssl.OP_NO_SSLv2
+        context.options |= ssl.OP_NO_SSLv3
+        context.options |= ssl.OP_NO_COMPRESSION
+        context.set_ciphers(DEFAULT_CIPHERS)
+
+        if ssl.HAS_ALPN:
+            context.set_alpn_protocols(["h2", "http/1.1"])
+        if ssl.HAS_NPN:
+            context.set_npn_protocols(["h2", "http/1.1"])
 
         return context
 
@@ -175,7 +200,7 @@ class TimeoutConfig:
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
-        if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1:
+        if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1:
             return f"{class_name}(timeout={self.connect_timeout})"
         return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
 
index 4dc74df3962e5b93c8c687b95dd31a58efbe561f..5e147046947be0c6abd7d84da25a127fcc7a98ee 100644 (file)
@@ -15,6 +15,7 @@ mkdocs-material
 # Testing
 autoflake
 black
+cryptography
 isort
 mypy
 pytest
index 92b058329330f4face8c34eaa03778e8848ad56e..f8a6d0dc04f386558f73c785bcb27b5dbdb9746e 100644 (file)
@@ -2,6 +2,11 @@ import asyncio
 
 import pytest
 import trustme
+from cryptography.hazmat.primitives.serialization import (
+    BestAvailableEncryption,
+    Encoding,
+    PrivateFormat,
+)
 from uvicorn.config import Config
 from uvicorn.main import Server
 
@@ -72,12 +77,45 @@ async def echo_body(scope, receive, send):
     await send({"type": "http.response.body", "body": body})
 
 
+class CAWithPKEncryption(trustme.CA):
+    """Implementation of trustme.CA() that can emit
+    private keys that are encrypted with a password.
+    """
+
+    @property
+    def encrypted_private_key_pem(self):
+        return trustme.Blob(
+            self._private_key.private_bytes(
+                Encoding.PEM,
+                PrivateFormat.TraditionalOpenSSL,
+                BestAvailableEncryption(password=b"password"),
+            )
+        )
+
+
 @pytest.fixture
-def cert_and_key_paths():
-    ca = trustme.CA()
+def example_cert():
+    ca = CAWithPKEncryption()
     ca.issue_cert("example.org")
-    with ca.cert_pem.tempfile() as cert_temp_path, ca.private_key_pem.tempfile() as key_temp_path:
-        yield cert_temp_path, key_temp_path
+    return ca
+
+
+@pytest.fixture
+def cert_pem_file(example_cert):
+    with example_cert.cert_pem.tempfile() as tmp:
+        yield tmp
+
+
+@pytest.fixture
+def cert_private_key_file(example_cert):
+    with example_cert.private_key_pem.tempfile() as tmp:
+        yield tmp
+
+
+@pytest.fixture
+def cert_encrypted_private_key_file(example_cert):
+    with example_cert.encrypted_private_key_pem.tempfile() as tmp:
+        yield tmp
 
 
 @pytest.fixture
@@ -95,10 +133,13 @@ async def server():
 
 
 @pytest.fixture
-async def https_server(cert_and_key_paths):
-    cert_path, key_path = cert_and_key_paths
+async def https_server(cert_pem_file, cert_private_key_file):
     config = Config(
-        app=app, lifespan="off", ssl_certfile=cert_path, ssl_keyfile=key_path, port=8001
+        app=app,
+        lifespan="off",
+        ssl_certfile=cert_pem_file,
+        ssl_keyfile=cert_private_key_file,
+        port=8001,
     )
     server = Server(config=config)
     task = asyncio.ensure_future(server.serve())
index a77f2189d09ca1ca9ed9ba45216e8e7088b6ba1a..7a68b899fe7c6414bad3ca6065d501c53451aa08 100644 (file)
@@ -11,6 +11,7 @@ async def test_load_ssl_config():
     ssl_config = http3.SSLConfig()
     context = await ssl_config.load_ssl_context()
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+    assert context.check_hostname is True
 
 
 @pytest.mark.asyncio
@@ -25,6 +26,7 @@ async def test_load_ssl_config_verify_existing_file():
     ssl_config = http3.SSLConfig(verify=http3.config.DEFAULT_CA_BUNDLE_PATH)
     context = await ssl_config.load_ssl_context()
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+    assert context.check_hostname is True
 
 
 @pytest.mark.asyncio
@@ -33,29 +35,55 @@ async def test_load_ssl_config_verify_directory():
     ssl_config = http3.SSLConfig(verify=path)
     context = await ssl_config.load_ssl_context()
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+    assert context.check_hostname is True
 
 
 @pytest.mark.asyncio
-async def test_load_ssl_config_cert_and_key(cert_and_key_paths):
-    cert_path, key_path = cert_and_key_paths
-    ssl_config = http3.SSLConfig(cert=(cert_path, key_path))
+async def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
+    ssl_config = http3.SSLConfig(cert=(cert_pem_file, cert_private_key_file))
     context = await ssl_config.load_ssl_context()
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+    assert context.check_hostname is True
 
 
 @pytest.mark.asyncio
-async def test_load_ssl_config_cert_without_key_raises(cert_and_key_paths):
-    cert_path, _ = cert_and_key_paths
-    ssl_config = http3.SSLConfig(cert=cert_path)
+@pytest.mark.parametrize("password", [b"password", "password"])
+async def test_load_ssl_config_cert_and_encrypted_key(
+    cert_pem_file, cert_encrypted_private_key_file, password
+):
+    ssl_config = http3.SSLConfig(
+        cert=(cert_pem_file, cert_encrypted_private_key_file, password)
+    )
+    context = await ssl_config.load_ssl_context()
+    assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+    assert context.check_hostname is True
+
+
+@pytest.mark.asyncio
+async def test_load_ssl_config_cert_and_key_invalid_password(
+    cert_pem_file, cert_encrypted_private_key_file
+):
+    ssl_config = http3.SSLConfig(
+        cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
+    )
+
     with pytest.raises(ssl.SSLError):
         await ssl_config.load_ssl_context()
 
 
 @pytest.mark.asyncio
-async def test_load_ssl_config_no_verify(verify=False):
+async def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
+    ssl_config = http3.SSLConfig(cert=cert_pem_file)
+    with pytest.raises(ssl.SSLError):
+        await ssl_config.load_ssl_context()
+
+
+@pytest.mark.asyncio
+async def test_load_ssl_config_no_verify():
     ssl_config = http3.SSLConfig(verify=False)
     context = await ssl_config.load_ssl_context()
     assert context.verify_mode == ssl.VerifyMode.CERT_NONE
+    assert context.check_hostname is False
 
 
 def test_ssl_repr():
@@ -102,5 +130,5 @@ def test_timeout_from_tuple():
 
 
 def test_timeout_from_config_instance():
-    timeout = http3.TimeoutConfig(timeout=(5.0))
+    timeout = http3.TimeoutConfig(timeout=5.0)
     assert http3.TimeoutConfig(timeout) == http3.TimeoutConfig(timeout=5.0)