]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Included create_ssl_context function to create the same context with SSLConfig and...
authorCan Sarıgöl <ertugrulsarigol@gmail.com>
Mon, 27 Jul 2020 18:46:46 +0000 (20:46 +0200)
committerGitHub <noreply@github.com>
Mon, 27 Jul 2020 18:46:46 +0000 (19:46 +0100)
* Included create_ssl_context function to create the same context with SSLConfig and serve as API.

* Changed create_ssl_context with SSLConfig into the client implementation and tests.

* Dropped the __repr__ and __eq__ methods from SSLConfig and removed SSLConfig using from tests

* Fixed test issue regarding cert_authority trust of ssl context

Co-authored-by: Tom Christie <tom@tomchristie.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
docs/advanced.md
httpx/__init__.py
httpx/_client.py
httpx/_config.py
httpx/_transports/urllib3.py
tests/test_config.py

index 9d416e6f01d25b38e9787705a67e7ff80f97b622..a55aacd810b1ef7969f3e7aeedcbbaf39748b495 100644 (file)
@@ -600,6 +600,23 @@ Alternatively, you can pass a standard library `ssl.SSLContext`.
 <Response [200 OK]>
 ```
 
+We also include a helper function for creating properly configured `SSLContext` instances.
+
+```python
+>>> context = httpx.create_ssl_context()
+```
+
+The `create_ssl_context` function accepts the same set of SSL configuration arguments 
+(`trust_env`, `verify`, `cert` and `http2` arguments) 
+as `httpx.Client` or `httpx.AsyncClient`
+
+```python
+>>> import httpx
+>>> context = httpx.create_ssl_context(verify="/tmp/client.pem")
+>>> httpx.get('https://example.org', verify=context)
+<Response [200 OK]>
+```
+
 Or you can also disable the SSL verification entirely, which is _not_ recommended.
 
 ```python
index 644656fd19bfb839675e46e0af80b2f60d3d1114..8aca4f7ceeba1e23030d7dd959a52fc0ef1bcf54 100644 (file)
@@ -2,7 +2,7 @@ from .__version__ import __description__, __title__, __version__
 from ._api import delete, get, head, options, patch, post, put, request, stream
 from ._auth import Auth, BasicAuth, DigestAuth
 from ._client import AsyncClient, Client
-from ._config import PoolLimits, Proxy, Timeout
+from ._config import PoolLimits, Proxy, Timeout, create_ssl_context
 from ._exceptions import (
     CloseError,
     ConnectError,
@@ -61,6 +61,7 @@ __all__ = [
     "PoolLimits",
     "Proxy",
     "Timeout",
+    "create_ssl_context",
     "CloseError",
     "ConnectError",
     "ConnectTimeout",
index 6aa576fe8a9f6d735e643e07f81382a1c60f2565..518531986d4e49345d75babb4261e194d41893fc 100644 (file)
@@ -13,9 +13,9 @@ from ._config import (
     UNSET,
     PoolLimits,
     Proxy,
-    SSLConfig,
     Timeout,
     UnsetType,
+    create_ssl_context,
 )
 from ._content_streams import ContentStream
 from ._exceptions import (
@@ -499,9 +499,7 @@ class Client(BaseClient):
         if app is not None:
             return WSGITransport(app=app)
 
-        ssl_context = SSLConfig(
-            verify=verify, cert=cert, trust_env=trust_env
-        ).ssl_context
+        ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
 
         return httpcore.SyncConnectionPool(
             ssl_context=ssl_context,
@@ -520,9 +518,7 @@ class Client(BaseClient):
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         trust_env: bool = True,
     ) -> httpcore.SyncHTTPTransport:
-        ssl_context = SSLConfig(
-            verify=verify, cert=cert, trust_env=trust_env
-        ).ssl_context
+        ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
 
         return httpcore.SyncHTTPProxy(
             proxy_url=proxy.url.raw,
@@ -1032,9 +1028,7 @@ class AsyncClient(BaseClient):
         if app is not None:
             return ASGITransport(app=app)
 
-        ssl_context = SSLConfig(
-            verify=verify, cert=cert, trust_env=trust_env
-        ).ssl_context
+        ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
 
         return httpcore.AsyncConnectionPool(
             ssl_context=ssl_context,
@@ -1053,9 +1047,7 @@ class AsyncClient(BaseClient):
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         trust_env: bool = True,
     ) -> httpcore.AsyncHTTPTransport:
-        ssl_context = SSLConfig(
-            verify=verify, cert=cert, trust_env=trust_env
-        ).ssl_context
+        ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
 
         return httpcore.AsyncHTTPProxy(
             proxy_url=proxy.url.raw,
index 9b2c8cba76d342ed0b7a644b3513ad00dd919b4e..80dda876d512bc8ece67e1e2bc6a7d34c264c173 100644 (file)
@@ -40,6 +40,17 @@ class UnsetType:
 UNSET = UnsetType()
 
 
+def create_ssl_context(
+    cert: CertTypes = None,
+    verify: VerifyTypes = True,
+    trust_env: bool = None,
+    http2: bool = False,
+) -> ssl.SSLContext:
+    return SSLConfig(
+        cert=cert, verify=verify, trust_env=trust_env, http2=http2
+    ).ssl_context
+
+
 class SSLConfig:
     """
     SSL Configuration.
@@ -61,17 +72,6 @@ class SSLConfig:
         self.http2 = http2
         self.ssl_context = self.load_ssl_context()
 
-    def __eq__(self, other: typing.Any) -> bool:
-        return (
-            isinstance(other, self.__class__)
-            and self.cert == other.cert
-            and self.verify == other.verify
-        )
-
-    def __repr__(self) -> str:
-        class_name = self.__class__.__name__
-        return f"{class_name}(cert={self.cert}, verify={self.verify})"
-
     def load_ssl_context(self) -> ssl.SSLContext:
         logger.trace(
             f"load_ssl_context "
index fab0a122b7feb0cf5ca87d78883aae3685f92520..7f076d2b38bfce04647fb4b7c4485215742c2a9a 100644 (file)
@@ -3,7 +3,7 @@ from typing import Dict, Iterator, List, Optional, Tuple
 
 import httpcore
 
-from .._config import SSLConfig
+from .._config import create_ssl_context
 from .._content_streams import ByteStream, IteratorStream
 from .._exceptions import NetworkError, map_exceptions
 from .._types import CertTypes, VerifyTypes
@@ -30,12 +30,10 @@ class URLLib3Transport(httpcore.SyncHTTPTransport):
             urllib3 is not None
         ), "urllib3 must be installed in order to use URLLib3Transport"
 
-        ssl_config = SSLConfig(
-            verify=verify, cert=cert, trust_env=trust_env, http2=False
-        )
-
         self.pool = urllib3.PoolManager(
-            ssl_context=ssl_config.ssl_context,
+            ssl_context=create_ssl_context(
+                verify=verify, cert=cert, trust_env=trust_env, http2=False
+            ),
             num_pools=pool_connections,
             maxsize=pool_maxsize,
             block=pool_block,
@@ -139,14 +137,12 @@ class URLLib3ProxyTransport(URLLib3Transport):
             urllib3 is not None
         ), "urllib3 must be installed in order to use URLLib3ProxyTransport"
 
-        ssl_config = SSLConfig(
-            verify=verify, cert=cert, trust_env=trust_env, http2=False
-        )
-
         self.pool = urllib3.ProxyManager(
             proxy_url=proxy_url,
             proxy_headers=proxy_headers,
-            ssl_context=ssl_config.ssl_context,
+            ssl_context=create_ssl_context(
+                verify=verify, cert=cert, trust_env=trust_env, http2=False
+            ),
             num_pools=pool_connections,
             maxsize=pool_maxsize,
             block=pool_block,
index 0c2c2f73ca28f868c650c7b0874e78a897165e7e..d9e68db37ae7a5988469f2bcd28e75ea85b6142c 100644 (file)
@@ -1,5 +1,4 @@
 import os
-import socket
 import ssl
 import sys
 from pathlib import Path
@@ -8,63 +7,51 @@ import certifi
 import pytest
 
 import httpx
-from httpx._config import SSLConfig
 
 
 def test_load_ssl_config():
-    ssl_config = SSLConfig()
-    context = ssl_config.ssl_context
+    context = httpx.create_ssl_context()
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
 
 def test_load_ssl_config_verify_non_existing_path():
     with pytest.raises(IOError):
-        SSLConfig(verify="/path/to/nowhere")
+        httpx.create_ssl_context(verify="/path/to/nowhere")
 
 
 def test_load_ssl_config_verify_existing_file():
-    ssl_config = SSLConfig(verify=certifi.where())
-    context = ssl_config.ssl_context
+    context = httpx.create_ssl_context(verify=certifi.where())
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
 
 @pytest.mark.parametrize("config", ("SSL_CERT_FILE", "SSL_CERT_DIR"))
-def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config):
+def test_load_ssl_config_verify_env_file(
+    https_server, ca_cert_pem_file, config, cert_authority
+):
     os.environ[config] = (
         ca_cert_pem_file
         if config.endswith("_FILE")
         else str(Path(ca_cert_pem_file).parent)
     )
-    ssl_config = SSLConfig(trust_env=True)
-    context = ssl_config.ssl_context
+    context = httpx.create_ssl_context(trust_env=True)
+    cert_authority.configure_trust(context)
+
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
-    assert ssl_config.verify == os.environ[config]
-
-    # Skipping 'SSL_CERT_DIR' functional test for now because
-    # we're unable to get the certificate within the directory to
-    # load into the SSLContext. :(
-    if config == "SSL_CERT_FILE":
-        host = https_server.url.host
-        port = https_server.url.port
-        conn = socket.create_connection((host, port))
-        context.wrap_socket(conn, server_hostname=host)
-        assert len(context.get_ca_certs()) == 1
+    assert len(context.get_ca_certs()) == 1
 
 
 def test_load_ssl_config_verify_directory():
     path = Path(certifi.where()).parent
-    ssl_config = SSLConfig(verify=str(path))
-    context = ssl_config.ssl_context
+    context = httpx.create_ssl_context(verify=str(path))
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
 
 def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
-    ssl_config = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
-    context = ssl_config.ssl_context
+    context = httpx.create_ssl_context(cert=(cert_pem_file, cert_private_key_file))
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
@@ -73,10 +60,9 @@ def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
 def test_load_ssl_config_cert_and_encrypted_key(
     cert_pem_file, cert_encrypted_private_key_file, password
 ):
-    ssl_config = SSLConfig(
+    context = httpx.create_ssl_context(
         cert=(cert_pem_file, cert_encrypted_private_key_file, password)
     )
-    context = ssl_config.ssl_context
     assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
     assert context.check_hostname is True
 
@@ -85,36 +71,33 @@ def test_load_ssl_config_cert_and_key_invalid_password(
     cert_pem_file, cert_encrypted_private_key_file
 ):
     with pytest.raises(ssl.SSLError):
-        SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password1"))
+        httpx.create_ssl_context(
+            cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
+        )
 
 
 def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
     with pytest.raises(ssl.SSLError):
-        SSLConfig(cert=cert_pem_file)
+        httpx.create_ssl_context(cert=cert_pem_file)
 
 
 def test_load_ssl_config_no_verify():
-    ssl_config = SSLConfig(verify=False)
-    context = ssl_config.ssl_context
+    context = httpx.create_ssl_context(verify=False)
     assert context.verify_mode == ssl.VerifyMode.CERT_NONE
     assert context.check_hostname is False
 
 
 def test_load_ssl_context():
     ssl_context = ssl.create_default_context()
-    ssl_config = SSLConfig(verify=ssl_context)
-
-    assert ssl_config.ssl_context is ssl_context
-
+    context = httpx.create_ssl_context(verify=ssl_context)
 
-def test_ssl_repr():
-    ssl = SSLConfig(verify=False)
-    assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
+    assert context is ssl_context
 
 
-def test_ssl_eq():
-    ssl = SSLConfig(verify=False)
-    assert ssl == SSLConfig(verify=False)
+def test_create_ssl_context_with_get_request(server, cert_pem_file):
+    context = httpx.create_ssl_context(verify=cert_pem_file)
+    response = httpx.get(server.url, verify=context)
+    assert response.status_code == 200
 
 
 def test_limits_repr():
@@ -190,22 +173,22 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch):  # pragma: noc
     with monkeypatch.context() as m:
         m.delenv("SSLKEYLOGFILE", raising=False)
 
-        ssl_config = SSLConfig(trust_env=True)
+        context = httpx.create_ssl_context(trust_env=True)
 
-        assert ssl_config.ssl_context.keylog_filename is None  # type: ignore
+        assert context.keylog_filename is None  # type: ignore
 
     filename = str(tmpdir.join("test.log"))
 
     with monkeypatch.context() as m:
         m.setenv("SSLKEYLOGFILE", filename)
 
-        ssl_config = SSLConfig(trust_env=True)
+        context = httpx.create_ssl_context(trust_env=True)
 
-        assert ssl_config.ssl_context.keylog_filename == filename  # type: ignore
+        assert context.keylog_filename == filename  # type: ignore
 
-        ssl_config = SSLConfig(trust_env=False)
+        context = httpx.create_ssl_context(trust_env=False)
 
-        assert ssl_config.ssl_context.keylog_filename is None  # type: ignore
+        assert context.keylog_filename is None  # type: ignore
 
 
 @pytest.mark.parametrize(