From ba83e29aee162c16b3ba8e4ee0629e88bf4ecaeb Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 16 Jul 2019 04:08:57 -0500 Subject: [PATCH] Support client cert passwords, new TLS options (#118) * Support client cert passwords, new TLS options * Update test_config.py * Switch to try-except for post_handshake_auth=True SSLContext.post_handshake_auth raises AttributeError if the property is available but cannot be written to (needs OpenSSL 1.1.1+) * Also try-except for hostname_checks_common_name=False * Custom implementation of trustme.CA() that emits encrypted PKs * lint * Split name of test * Updates from review comments * Don't load default CAs yet --- http3/config.py | 69 ++++++++++++++++++++++++++++++-------------- requirements.txt | 1 + tests/conftest.py | 55 ++++++++++++++++++++++++++++++----- tests/test_config.py | 44 +++++++++++++++++++++++----- 4 files changed, 132 insertions(+), 37 deletions(-) diff --git a/http3/config.py b/http3/config.py index 4b0fe9e2..8ef35708 100644 --- a/http3/config.py +++ b/http3/config.py @@ -7,7 +7,7 @@ import certifi from .__version__ import __version__ -CertTypes = typing.Union[str, typing.Tuple[str, str]] +CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]] VerifyTypes = typing.Union[str, bool] TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] @@ -43,6 +43,8 @@ class SSLConfig: self.cert = cert self.verify = verify + self.ssl_context: typing.Optional[ssl.SSLContext] = None + def __eq__(self, other: typing.Any) -> bool: return ( isinstance(other, self.__class__) @@ -64,7 +66,7 @@ class SSLConfig: return SSLConfig(cert=cert, verify=verify) async def load_ssl_context(self) -> ssl.SSLContext: - if not hasattr(self, "ssl_context"): + if self.ssl_context is None: if not self.verify: self.ssl_context = self.load_ssl_context_no_verify() else: @@ -80,11 +82,9 @@ class SSLConfig: """ Return an SSL context for unverified connections. """ - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_COMPRESSION - context.set_default_verify_paths() + context = self._create_default_ssl_context() + context.verify_mode = ssl.CERT_NONE + context.check_hostname = False return context def load_ssl_context_verify(self) -> ssl.SSLContext: @@ -101,20 +101,23 @@ class SSLConfig: "invalid path: {}".format(self.verify) ) - context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) - + context = self._create_default_ssl_context() context.verify_mode = ssl.CERT_REQUIRED - - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_COMPRESSION - - context.set_ciphers(DEFAULT_CIPHERS) - - if ssl.HAS_ALPN: - context.set_alpn_protocols(["h2", "http/1.1"]) - if ssl.HAS_NPN: - context.set_npn_protocols(["h2", "http/1.1"]) + context.check_hostname = True + + # Signal to server support for PHA in TLS 1.3. Raises an + # AttributeError if only read-only access is implemented. + try: + context.post_handshake_auth = True + except AttributeError: + pass + + # Disable using 'commonName' for SSLContext.check_hostname + # when the 'subjectAltName' extension isn't available. + try: + context.hostname_checks_common_name = False + except AttributeError: + pass if os.path.isfile(ca_bundle_path): context.load_verify_locations(cafile=ca_bundle_path) @@ -124,8 +127,30 @@ class SSLConfig: if self.cert is not None: if isinstance(self.cert, str): context.load_cert_chain(certfile=self.cert) - else: + elif isinstance(self.cert, tuple) and len(self.cert) == 2: context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) + else: + context.load_cert_chain( + certfile=self.cert[0], keyfile=self.cert[1], password=self.cert[2] + ) + + return context + + def _create_default_ssl_context(self) -> ssl.SSLContext: + """ + Creates the default SSLContext object that's used for both verified + and unverified connections. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_COMPRESSION + context.set_ciphers(DEFAULT_CIPHERS) + + if ssl.HAS_ALPN: + context.set_alpn_protocols(["h2", "http/1.1"]) + if ssl.HAS_NPN: + context.set_npn_protocols(["h2", "http/1.1"]) return context @@ -175,7 +200,7 @@ class TimeoutConfig: def __repr__(self) -> str: class_name = self.__class__.__name__ - if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1: + if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1: return f"{class_name}(timeout={self.connect_timeout})" return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})" diff --git a/requirements.txt b/requirements.txt index 4dc74df3..5e147046 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ mkdocs-material # Testing autoflake black +cryptography isort mypy pytest diff --git a/tests/conftest.py b/tests/conftest.py index 92b05832..f8a6d0dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,11 @@ import asyncio import pytest import trustme +from cryptography.hazmat.primitives.serialization import ( + BestAvailableEncryption, + Encoding, + PrivateFormat, +) from uvicorn.config import Config from uvicorn.main import Server @@ -72,12 +77,45 @@ async def echo_body(scope, receive, send): await send({"type": "http.response.body", "body": body}) +class CAWithPKEncryption(trustme.CA): + """Implementation of trustme.CA() that can emit + private keys that are encrypted with a password. + """ + + @property + def encrypted_private_key_pem(self): + return trustme.Blob( + self._private_key.private_bytes( + Encoding.PEM, + PrivateFormat.TraditionalOpenSSL, + BestAvailableEncryption(password=b"password"), + ) + ) + + @pytest.fixture -def cert_and_key_paths(): - ca = trustme.CA() +def example_cert(): + ca = CAWithPKEncryption() ca.issue_cert("example.org") - with ca.cert_pem.tempfile() as cert_temp_path, ca.private_key_pem.tempfile() as key_temp_path: - yield cert_temp_path, key_temp_path + return ca + + +@pytest.fixture +def cert_pem_file(example_cert): + with example_cert.cert_pem.tempfile() as tmp: + yield tmp + + +@pytest.fixture +def cert_private_key_file(example_cert): + with example_cert.private_key_pem.tempfile() as tmp: + yield tmp + + +@pytest.fixture +def cert_encrypted_private_key_file(example_cert): + with example_cert.encrypted_private_key_pem.tempfile() as tmp: + yield tmp @pytest.fixture @@ -95,10 +133,13 @@ async def server(): @pytest.fixture -async def https_server(cert_and_key_paths): - cert_path, key_path = cert_and_key_paths +async def https_server(cert_pem_file, cert_private_key_file): config = Config( - app=app, lifespan="off", ssl_certfile=cert_path, ssl_keyfile=key_path, port=8001 + app=app, + lifespan="off", + ssl_certfile=cert_pem_file, + ssl_keyfile=cert_private_key_file, + port=8001, ) server = Server(config=config) task = asyncio.ensure_future(server.serve()) diff --git a/tests/test_config.py b/tests/test_config.py index a77f2189..7a68b899 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,6 +11,7 @@ async def test_load_ssl_config(): ssl_config = http3.SSLConfig() context = await ssl_config.load_ssl_context() assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED + assert context.check_hostname is True @pytest.mark.asyncio @@ -25,6 +26,7 @@ async def test_load_ssl_config_verify_existing_file(): ssl_config = http3.SSLConfig(verify=http3.config.DEFAULT_CA_BUNDLE_PATH) context = await ssl_config.load_ssl_context() assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED + assert context.check_hostname is True @pytest.mark.asyncio @@ -33,29 +35,55 @@ async def test_load_ssl_config_verify_directory(): ssl_config = http3.SSLConfig(verify=path) context = await ssl_config.load_ssl_context() assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED + assert context.check_hostname is True @pytest.mark.asyncio -async def test_load_ssl_config_cert_and_key(cert_and_key_paths): - cert_path, key_path = cert_and_key_paths - ssl_config = http3.SSLConfig(cert=(cert_path, key_path)) +async def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file): + ssl_config = http3.SSLConfig(cert=(cert_pem_file, cert_private_key_file)) context = await ssl_config.load_ssl_context() assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED + assert context.check_hostname is True @pytest.mark.asyncio -async def test_load_ssl_config_cert_without_key_raises(cert_and_key_paths): - cert_path, _ = cert_and_key_paths - ssl_config = http3.SSLConfig(cert=cert_path) +@pytest.mark.parametrize("password", [b"password", "password"]) +async def test_load_ssl_config_cert_and_encrypted_key( + cert_pem_file, cert_encrypted_private_key_file, password +): + ssl_config = http3.SSLConfig( + cert=(cert_pem_file, cert_encrypted_private_key_file, password) + ) + context = await ssl_config.load_ssl_context() + assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED + assert context.check_hostname is True + + +@pytest.mark.asyncio +async def test_load_ssl_config_cert_and_key_invalid_password( + cert_pem_file, cert_encrypted_private_key_file +): + ssl_config = http3.SSLConfig( + cert=(cert_pem_file, cert_encrypted_private_key_file, "password1") + ) + with pytest.raises(ssl.SSLError): await ssl_config.load_ssl_context() @pytest.mark.asyncio -async def test_load_ssl_config_no_verify(verify=False): +async def test_load_ssl_config_cert_without_key_raises(cert_pem_file): + ssl_config = http3.SSLConfig(cert=cert_pem_file) + with pytest.raises(ssl.SSLError): + await ssl_config.load_ssl_context() + + +@pytest.mark.asyncio +async def test_load_ssl_config_no_verify(): ssl_config = http3.SSLConfig(verify=False) context = await ssl_config.load_ssl_context() assert context.verify_mode == ssl.VerifyMode.CERT_NONE + assert context.check_hostname is False def test_ssl_repr(): @@ -102,5 +130,5 @@ def test_timeout_from_tuple(): def test_timeout_from_config_instance(): - timeout = http3.TimeoutConfig(timeout=(5.0)) + timeout = http3.TimeoutConfig(timeout=5.0) assert http3.TimeoutConfig(timeout) == http3.TimeoutConfig(timeout=5.0) -- 2.47.3