]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
SSLConfig refactor (#706)
authorTom Christie <tom@tomchristie.com>
Thu, 2 Jan 2020 10:54:04 +0000 (10:54 +0000)
committerGitHub <noreply@github.com>
Thu, 2 Jan 2020 10:54:04 +0000 (10:54 +0000)
* SSLConfig includes 'http2' argument on init.

* Pass SSL config to HTTPConnection as a single argument

* Don't run SSL context loading in threadpool

httpx/config.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
httpx/dispatch/proxy_http.py
tests/client/test_proxies.py
tests/dispatch/test_connections.py

index 088642606799fc52aef33296f3db2ce02cf2e4e7..38245a3e7fb9d86978077fe63ce21a23208c53dd 100644 (file)
@@ -60,6 +60,7 @@ class SSLConfig:
         cert: CertTypes = None,
         verify: VerifyTypes = True,
         trust_env: bool = None,
+        http2: bool = False,
     ):
         self.cert = cert
 
@@ -74,6 +75,7 @@ class SSLConfig:
         self.ssl_context: typing.Optional[ssl.SSLContext] = ssl_context
         self.verify: typing.Union[str, bool] = verify
         self.trust_env = trust_env
+        self.http2 = http2
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
@@ -86,35 +88,35 @@ class SSLConfig:
         class_name = self.__class__.__name__
         return f"{class_name}(cert={self.cert}, verify={self.verify})"
 
-    def load_ssl_context(self, http2: bool = False) -> ssl.SSLContext:
+    def load_ssl_context(self) -> ssl.SSLContext:
         logger.trace(
             f"load_ssl_context "
             f"verify={self.verify!r} "
             f"cert={self.cert!r} "
             f"trust_env={self.trust_env!r} "
-            f"http2={http2!r}"
+            f"http2={self.http2!r}"
         )
 
         if self.ssl_context is None:
             self.ssl_context = (
-                self.load_ssl_context_verify(http2=http2)
+                self.load_ssl_context_verify()
                 if self.verify
-                else self.load_ssl_context_no_verify(http2=http2)
+                else self.load_ssl_context_no_verify()
             )
 
         assert self.ssl_context is not None
         return self.ssl_context
 
-    def load_ssl_context_no_verify(self, http2: bool = False) -> ssl.SSLContext:
+    def load_ssl_context_no_verify(self) -> ssl.SSLContext:
         """
         Return an SSL context for unverified connections.
         """
-        context = self._create_default_ssl_context(http2=http2)
+        context = self._create_default_ssl_context()
         context.verify_mode = ssl.CERT_NONE
         context.check_hostname = False
         return context
 
-    def load_ssl_context_verify(self, http2: bool = False) -> ssl.SSLContext:
+    def load_ssl_context_verify(self) -> ssl.SSLContext:
         """
         Return an SSL context for verified connections.
         """
@@ -133,7 +135,7 @@ class SSLConfig:
                 "invalid path: {}".format(self.verify)
             )
 
-        context = self._create_default_ssl_context(http2=http2)
+        context = self._create_default_ssl_context()
         context.verify_mode = ssl.CERT_REQUIRED
         context.check_hostname = True
 
@@ -162,7 +164,7 @@ class SSLConfig:
 
         return context
 
-    def _create_default_ssl_context(self, http2: bool) -> ssl.SSLContext:
+    def _create_default_ssl_context(self) -> ssl.SSLContext:
         """
         Creates the default SSLContext object that's used for both verified
         and unverified connections.
@@ -176,7 +178,7 @@ class SSLConfig:
         context.set_ciphers(DEFAULT_CIPHERS)
 
         if ssl.HAS_ALPN:
-            alpn_idents = ["http/1.1", "h2"] if http2 else ["http/1.1"]
+            alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
             context.set_alpn_protocols(alpn_idents)
 
         if hasattr(context, "keylog_filename"):  # pragma: nocover (Available in 3.8+)
index 43dbd963cfc5a4045710c7082c88dcc125671537..358bdcf7a7af3ce68fd36580796e6ac8665a5d28 100644 (file)
@@ -5,7 +5,7 @@ import typing
 import h11
 
 from ..backends.base import ConcurrencyBackend, lookup_backend
-from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
+from ..config import SSLConfig, Timeout
 from ..models import URL, Origin, Request, Response
 from ..utils import get_logger
 from .base import Dispatcher
@@ -23,17 +23,13 @@ class HTTPConnection(Dispatcher):
     def __init__(
         self,
         origin: typing.Union[str, Origin],
-        verify: VerifyTypes = True,
-        cert: CertTypes = None,
-        trust_env: bool = None,
-        http2: bool = False,
+        ssl: SSLConfig = None,
         backend: typing.Union[str, ConcurrencyBackend] = "auto",
         release_func: typing.Optional[ReleaseCallback] = None,
         uds: typing.Optional[str] = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
-        self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
-        self.http2 = http2
+        self.ssl = SSLConfig() if ssl is None else ssl
         self.backend = lookup_backend(backend)
         self.release_func = release_func
         self.uds = uds
@@ -53,7 +49,7 @@ class HTTPConnection(Dispatcher):
     ) -> typing.Union[HTTP11Connection, HTTP2Connection]:
         host = self.origin.host
         port = self.origin.port
-        ssl_context = await self.get_ssl_context(self.ssl)
+        ssl_context = self.get_ssl_context()
 
         if self.release_func is None:
             on_release = None
@@ -108,7 +104,7 @@ class HTTPConnection(Dispatcher):
         if origin.is_ssl:
             # Pull the socket stream off the internal HTTP connection object,
             # and run start_tls().
-            ssl_context = await self.get_ssl_context(self.ssl)
+            ssl_context = self.get_ssl_context()
             assert ssl_context is not None
 
             logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
@@ -134,12 +130,10 @@ class HTTPConnection(Dispatcher):
         else:
             self.connection = HTTP11Connection(socket, on_release=on_release)
 
-    async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
+    def get_ssl_context(self) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
             return None
-
-        # Run the SSL loading in a threadpool, since it may make disk accesses.
-        return await self.backend.run_in_threadpool(ssl.load_ssl_context, self.http2)
+        return self.ssl.load_ssl_context()
 
     async def close(self) -> None:
         logger.trace("close_connection")
index bfe147de28f702f4e0623560643900f7035e31c1..545e20452147a0a0fd2a494cfa3ab63ecfa064a3 100644 (file)
@@ -1,7 +1,14 @@
 import typing
 
 from ..backends.base import BaseSemaphore, ConcurrencyBackend, lookup_backend
-from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
+from ..config import (
+    DEFAULT_POOL_LIMITS,
+    CertTypes,
+    PoolLimits,
+    SSLConfig,
+    Timeout,
+    VerifyTypes,
+)
 from ..exceptions import PoolTimeout
 from ..models import Origin, Request, Response
 from ..utils import get_logger
@@ -92,12 +99,9 @@ class ConnectionPool(Dispatcher):
         backend: typing.Union[str, ConcurrencyBackend] = "auto",
         uds: typing.Optional[str] = None,
     ):
-        self.verify = verify
-        self.cert = cert
+        self.ssl = SSLConfig(verify=verify, cert=cert, trust_env=trust_env, http2=http2)
         self.pool_limits = pool_limits
-        self.http2 = http2
         self.is_closed = False
-        self.trust_env = trust_env
         self.uds = uds
 
         self.keepalive_connections = ConnectionStore()
@@ -166,12 +170,9 @@ class ConnectionPool(Dispatcher):
             await self.max_connections.acquire(timeout=pool_timeout)
             connection = HTTPConnection(
                 origin,
-                verify=self.verify,
-                cert=self.cert,
-                http2=self.http2,
+                ssl=self.ssl,
                 backend=self.backend,
                 release_func=self.release_connection,
-                trust_env=self.trust_env,
                 uds=self.uds,
             )
             logger.trace(f"new_connection connection={connection!r}")
index 748cffa5c1f2404ada13282599e5f8f635634b6e..915983601f24357921539119eec8da95248ffbbf 100644 (file)
@@ -4,7 +4,14 @@ import warnings
 from base64 import b64encode
 
 from ..backends.base import ConcurrencyBackend
-from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
+from ..config import (
+    DEFAULT_POOL_LIMITS,
+    CertTypes,
+    PoolLimits,
+    SSLConfig,
+    Timeout,
+    VerifyTypes,
+)
 from ..exceptions import ProxyError
 from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
 from ..utils import get_logger
@@ -55,6 +62,10 @@ class HTTPProxy(ConnectionPool):
             proxy_mode = proxy_mode.value
         assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
 
+        self.tunnel_ssl = SSLConfig(
+            verify=verify, cert=cert, trust_env=trust_env, http2=False
+        )
+
         super(HTTPProxy, self).__init__(
             verify=verify,
             cert=cert,
@@ -137,11 +148,8 @@ class HTTPProxy(ConnectionPool):
 
         connection = HTTPConnection(
             self.proxy_url.origin,
-            verify=self.verify,
-            cert=self.cert,
+            ssl=self.tunnel_ssl,
             backend=self.backend,
-            http2=False,  # Short-lived 'connection'
-            trust_env=self.trust_env,
             release_func=self.release_connection,
         )
         self.active_connections.add(connection)
index ff117147573e8b0be638fa55d566b3b0c3fe690a..b94774e6e448997fdf76adca38721378f02fb60f 100644 (file)
@@ -29,27 +29,6 @@ def test_proxies_parameter(proxies, expected_proxies):
     assert len(expected_proxies) == len(client.proxies)
 
 
-def test_proxies_has_same_properties_as_dispatch():
-    client = httpx.AsyncClient(
-        proxies="http://127.0.0.1",
-        verify="/path/to/verify",
-        cert="/path/to/cert",
-        trust_env=False,
-        timeout=30,
-    )
-    pool = client.dispatch
-    proxy = client.proxies["all"]
-
-    assert isinstance(proxy, httpx.HTTPProxy)
-
-    for prop in [
-        "verify",
-        "cert",
-        "pool_limits",
-    ]:
-        assert getattr(pool, prop) == getattr(proxy, prop)
-
-
 PROXY_URL = "http://[::1]"
 
 
index 65eb6c51ddea2d764441ad373a2214381cd8eb49..80dd6a46b371c74480da0a3aeb3a943e39cda5a7 100644 (file)
@@ -1,6 +1,7 @@
 import pytest
 
 import httpx
+from httpx.config import SSLConfig
 from httpx.dispatch.connection import HTTPConnection
 
 
@@ -35,7 +36,8 @@ async def test_https_get_with_ssl(https_server, ca_cert_pem_file):
     """
     An HTTPS request, with SSL configuration set on the client.
     """
-    async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn:
+    ssl = SSLConfig(verify=ca_cert_pem_file)
+    async with HTTPConnection(origin=https_server.url, ssl=ssl) as conn:
         response = await conn.request("GET", https_server.url)
         await response.aread()
         assert response.status_code == 200