]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Use a certificate issued from the trustme CA in tests (#357)
authorJamie Hewland <jhewland@gmail.com>
Wed, 18 Sep 2019 23:59:27 +0000 (01:59 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Wed, 18 Sep 2019 23:59:27 +0000 (18:59 -0500)
tests/conftest.py
tests/dispatch/test_connections.py

index c7ca24bdd6894b9f4d6906c824f7a786fe5e8a1c..6b68649d6959e22fb8aad70c421132461cfb00d6 100644 (file)
@@ -7,7 +7,9 @@ import typing
 
 import pytest
 import trustme
+from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives.serialization import (
+    load_pem_private_key,
     BestAvailableEncryption,
     Encoding,
     PrivateFormat,
@@ -139,47 +141,51 @@ async def echo_headers(scope, receive, send):
     await send({"type": "http.response.body", "body": json.dumps(body).encode()})
 
 
-class CAWithPKEncryption(trustme.CA):
-    """Implementation of trustme.CA() that can emit
-    private keys that are encrypted with a password.
-    """
+SERVER_SCOPE = "session"
 
-    @property
-    def encrypted_private_key_pem(self):
-        return trustme.Blob(
-            self._private_key.private_bytes(
-                Encoding.PEM,
-                PrivateFormat.TraditionalOpenSSL,
-                BestAvailableEncryption(password=b"password"),
-            )
-        )
 
+@pytest.fixture(scope=SERVER_SCOPE)
+def cert_authority():
+    return trustme.CA()
 
-SERVER_SCOPE = "session"
+
+@pytest.fixture(scope=SERVER_SCOPE)
+def ca_cert_pem_file(cert_authority):
+    with cert_authority.cert_pem.tempfile() as tmp:
+        yield tmp
 
 
 @pytest.fixture(scope=SERVER_SCOPE)
-def example_cert():
-    ca = CAWithPKEncryption()
-    ca.issue_cert("example.org")
-    return ca
+def localhost_cert(cert_authority):
+    return cert_authority.issue_cert("localhost")
 
 
 @pytest.fixture(scope=SERVER_SCOPE)
-def cert_pem_file(example_cert):
-    with example_cert.cert_pem.tempfile() as tmp:
+def cert_pem_file(localhost_cert):
+    with localhost_cert.cert_chain_pems[0].tempfile() as tmp:
         yield tmp
 
 
 @pytest.fixture(scope=SERVER_SCOPE)
-def cert_private_key_file(example_cert):
-    with example_cert.private_key_pem.tempfile() as tmp:
+def cert_private_key_file(localhost_cert):
+    with localhost_cert.private_key_pem.tempfile() as tmp:
         yield tmp
 
 
 @pytest.fixture(scope=SERVER_SCOPE)
-def cert_encrypted_private_key_file(example_cert):
-    with example_cert.encrypted_private_key_pem.tempfile() as tmp:
+def cert_encrypted_private_key_file(localhost_cert):
+    # Deserialize the private key and then reserialize with a password
+    private_key = load_pem_private_key(
+        localhost_cert.private_key_pem.bytes(), password=None, backend=default_backend()
+    )
+    encrypted_private_key_pem = trustme.Blob(
+        private_key.private_bytes(
+            Encoding.PEM,
+            PrivateFormat.TraditionalOpenSSL,
+            BestAvailableEncryption(password=b"password"),
+        )
+    )
+    with encrypted_private_key_pem.tempfile() as tmp:
         yield tmp
 
 
@@ -270,6 +276,7 @@ def https_server(cert_pem_file, cert_private_key_file):
         lifespan="off",
         ssl_certfile=cert_pem_file,
         ssl_keyfile=cert_private_key_file,
+        host="localhost",
         port=8001,
         loop="asyncio",
     )
index 14386aaa7d1f74b5ad2c4717b611263350b57284..b4aea17b08183043440ea2894810aca7444bf772 100644 (file)
@@ -15,12 +15,12 @@ async def test_post(server, backend):
         assert response.status_code == 200
 
 
-async def test_https_get_with_ssl_defaults(https_server, backend):
+async def test_https_get_with_ssl_defaults(https_server, ca_cert_pem_file, backend):
     """
     An HTTPS request, with default SSL configuration set on the client.
     """
     async with HTTPConnection(
-        origin=https_server.url, verify=False, backend=backend
+        origin=https_server.url, verify=ca_cert_pem_file, backend=backend
     ) as conn:
         response = await conn.request("GET", https_server.url)
         await response.read()
@@ -28,12 +28,12 @@ async def test_https_get_with_ssl_defaults(https_server, backend):
         assert response.content == b"Hello, world!"
 
 
-async def test_https_get_with_sll_overrides(https_server, backend):
+async def test_https_get_with_sll_overrides(https_server, ca_cert_pem_file, backend):
     """
     An HTTPS request, with SSL configuration set on the request.
     """
     async with HTTPConnection(origin=https_server.url, backend=backend) as conn:
-        response = await conn.request("GET", https_server.url, verify=False)
+        response = await conn.request("GET", https_server.url, verify=ca_cert_pem_file)
         await response.read()
         assert response.status_code == 200
         assert response.content == b"Hello, world!"