from .__version__ import __version__
-CertTypes = typing.Union[str, typing.Tuple[str, str]]
+CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
VerifyTypes = typing.Union[str, bool]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
self.cert = cert
self.verify = verify
+ self.ssl_context: typing.Optional[ssl.SSLContext] = None
+
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
return SSLConfig(cert=cert, verify=verify)
async def load_ssl_context(self) -> ssl.SSLContext:
- if not hasattr(self, "ssl_context"):
+ if self.ssl_context is None:
if not self.verify:
self.ssl_context = self.load_ssl_context_no_verify()
else:
"""
Return an SSL context for unverified connections.
"""
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.options |= ssl.OP_NO_SSLv2
- context.options |= ssl.OP_NO_SSLv3
- context.options |= ssl.OP_NO_COMPRESSION
- context.set_default_verify_paths()
+ context = self._create_default_ssl_context()
+ context.verify_mode = ssl.CERT_NONE
+ context.check_hostname = False
return context
def load_ssl_context_verify(self) -> ssl.SSLContext:
"invalid path: {}".format(self.verify)
)
- context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
-
+ context = self._create_default_ssl_context()
context.verify_mode = ssl.CERT_REQUIRED
-
- context.options |= ssl.OP_NO_SSLv2
- context.options |= ssl.OP_NO_SSLv3
- context.options |= ssl.OP_NO_COMPRESSION
-
- context.set_ciphers(DEFAULT_CIPHERS)
-
- if ssl.HAS_ALPN:
- context.set_alpn_protocols(["h2", "http/1.1"])
- if ssl.HAS_NPN:
- context.set_npn_protocols(["h2", "http/1.1"])
+ context.check_hostname = True
+
+ # Signal to server support for PHA in TLS 1.3. Raises an
+ # AttributeError if only read-only access is implemented.
+ try:
+ context.post_handshake_auth = True
+ except AttributeError:
+ pass
+
+ # Disable using 'commonName' for SSLContext.check_hostname
+ # when the 'subjectAltName' extension isn't available.
+ try:
+ context.hostname_checks_common_name = False
+ except AttributeError:
+ pass
if os.path.isfile(ca_bundle_path):
context.load_verify_locations(cafile=ca_bundle_path)
if self.cert is not None:
if isinstance(self.cert, str):
context.load_cert_chain(certfile=self.cert)
- else:
+ elif isinstance(self.cert, tuple) and len(self.cert) == 2:
context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
+ else:
+ context.load_cert_chain(
+ certfile=self.cert[0], keyfile=self.cert[1], password=self.cert[2]
+ )
+
+ return context
+
+ def _create_default_ssl_context(self) -> ssl.SSLContext:
+ """
+ Creates the default SSLContext object that's used for both verified
+ and unverified connections.
+ """
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ context.options |= ssl.OP_NO_SSLv2
+ context.options |= ssl.OP_NO_SSLv3
+ context.options |= ssl.OP_NO_COMPRESSION
+ context.set_ciphers(DEFAULT_CIPHERS)
+
+ if ssl.HAS_ALPN:
+ context.set_alpn_protocols(["h2", "http/1.1"])
+ if ssl.HAS_NPN:
+ context.set_npn_protocols(["h2", "http/1.1"])
return context
def __repr__(self) -> str:
class_name = self.__class__.__name__
- if len(set([self.connect_timeout, self.read_timeout, self.write_timeout])) == 1:
+ if len({self.connect_timeout, self.read_timeout, self.write_timeout}) == 1:
return f"{class_name}(timeout={self.connect_timeout})"
return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
# Testing
autoflake
black
+cryptography
isort
mypy
pytest
import pytest
import trustme
+from cryptography.hazmat.primitives.serialization import (
+ BestAvailableEncryption,
+ Encoding,
+ PrivateFormat,
+)
from uvicorn.config import Config
from uvicorn.main import Server
await send({"type": "http.response.body", "body": body})
+class CAWithPKEncryption(trustme.CA):
+ """Implementation of trustme.CA() that can emit
+ private keys that are encrypted with a password.
+ """
+
+ @property
+ def encrypted_private_key_pem(self):
+ return trustme.Blob(
+ self._private_key.private_bytes(
+ Encoding.PEM,
+ PrivateFormat.TraditionalOpenSSL,
+ BestAvailableEncryption(password=b"password"),
+ )
+ )
+
+
@pytest.fixture
-def cert_and_key_paths():
- ca = trustme.CA()
+def example_cert():
+ ca = CAWithPKEncryption()
ca.issue_cert("example.org")
- with ca.cert_pem.tempfile() as cert_temp_path, ca.private_key_pem.tempfile() as key_temp_path:
- yield cert_temp_path, key_temp_path
+ return ca
+
+
+@pytest.fixture
+def cert_pem_file(example_cert):
+ with example_cert.cert_pem.tempfile() as tmp:
+ yield tmp
+
+
+@pytest.fixture
+def cert_private_key_file(example_cert):
+ with example_cert.private_key_pem.tempfile() as tmp:
+ yield tmp
+
+
+@pytest.fixture
+def cert_encrypted_private_key_file(example_cert):
+ with example_cert.encrypted_private_key_pem.tempfile() as tmp:
+ yield tmp
@pytest.fixture
@pytest.fixture
-async def https_server(cert_and_key_paths):
- cert_path, key_path = cert_and_key_paths
+async def https_server(cert_pem_file, cert_private_key_file):
config = Config(
- app=app, lifespan="off", ssl_certfile=cert_path, ssl_keyfile=key_path, port=8001
+ app=app,
+ lifespan="off",
+ ssl_certfile=cert_pem_file,
+ ssl_keyfile=cert_private_key_file,
+ port=8001,
)
server = Server(config=config)
task = asyncio.ensure_future(server.serve())
ssl_config = http3.SSLConfig()
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+ assert context.check_hostname is True
@pytest.mark.asyncio
ssl_config = http3.SSLConfig(verify=http3.config.DEFAULT_CA_BUNDLE_PATH)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+ assert context.check_hostname is True
@pytest.mark.asyncio
ssl_config = http3.SSLConfig(verify=path)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+ assert context.check_hostname is True
@pytest.mark.asyncio
-async def test_load_ssl_config_cert_and_key(cert_and_key_paths):
- cert_path, key_path = cert_and_key_paths
- ssl_config = http3.SSLConfig(cert=(cert_path, key_path))
+async def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
+ ssl_config = http3.SSLConfig(cert=(cert_pem_file, cert_private_key_file))
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+ assert context.check_hostname is True
@pytest.mark.asyncio
-async def test_load_ssl_config_cert_without_key_raises(cert_and_key_paths):
- cert_path, _ = cert_and_key_paths
- ssl_config = http3.SSLConfig(cert=cert_path)
+@pytest.mark.parametrize("password", [b"password", "password"])
+async def test_load_ssl_config_cert_and_encrypted_key(
+ cert_pem_file, cert_encrypted_private_key_file, password
+):
+ ssl_config = http3.SSLConfig(
+ cert=(cert_pem_file, cert_encrypted_private_key_file, password)
+ )
+ context = await ssl_config.load_ssl_context()
+ assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+ assert context.check_hostname is True
+
+
+@pytest.mark.asyncio
+async def test_load_ssl_config_cert_and_key_invalid_password(
+ cert_pem_file, cert_encrypted_private_key_file
+):
+ ssl_config = http3.SSLConfig(
+ cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
+ )
+
with pytest.raises(ssl.SSLError):
await ssl_config.load_ssl_context()
@pytest.mark.asyncio
-async def test_load_ssl_config_no_verify(verify=False):
+async def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
+ ssl_config = http3.SSLConfig(cert=cert_pem_file)
+ with pytest.raises(ssl.SSLError):
+ await ssl_config.load_ssl_context()
+
+
+@pytest.mark.asyncio
+async def test_load_ssl_config_no_verify():
ssl_config = http3.SSLConfig(verify=False)
context = await ssl_config.load_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
+ assert context.check_hostname is False
def test_ssl_repr():
def test_timeout_from_config_instance():
- timeout = http3.TimeoutConfig(timeout=(5.0))
+ timeout = http3.TimeoutConfig(timeout=5.0)
assert http3.TimeoutConfig(timeout) == http3.TimeoutConfig(timeout=5.0)