]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Initial pass at configuring supported protocol versions
authorTom Christie <tom@tomchristie.com>
Mon, 19 Aug 2019 19:13:37 +0000 (20:13 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 19 Aug 2019 19:13:37 +0000 (20:13 +0100)
httpx/__init__.py
httpx/concurrency.py
httpx/config.py
httpx/dispatch/connection.py
httpx/interfaces.py
tests/dispatch/utils.py

index dec5cff19c749bf77a819e28ae86856234d50dcf..0f1a7b4002804b052ed4dda54e30cc981aeedc97 100644 (file)
@@ -6,6 +6,8 @@ from .config import (
     USER_AGENT,
     CertTypes,
     PoolLimits,
+    ProtocolConfig,
+    ProtocolTypes,
     SSLConfig,
     TimeoutConfig,
     TimeoutTypes,
index f1bf585448f27d2b3b3f8941d70b051e0585a5ad..307b5dcc1179c6d251edcc7211d98aadf90149d4 100644 (file)
@@ -14,7 +14,7 @@ import ssl
 import typing
 from types import TracebackType
 
-from .config import PoolLimits, TimeoutConfig
+from .config import PoolLimits, ProtocolConfig, TimeoutConfig
 from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .interfaces import (
     BaseBackgroundManager,
@@ -202,6 +202,7 @@ class AsyncioBackend(ConcurrencyBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
+        protocols: ProtocolConfig
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         try:
             stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
index 796da35023a821f72a0ee7096525db80c7b08062..4cbe76b510c3c8388486e4271c99a273689a66f2 100644 (file)
@@ -9,6 +9,7 @@ from .__version__ import __version__
 CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
 VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
 TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
+ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"]
 
 
 USER_AGENT = f"python-httpx/{__version__}"
@@ -72,27 +73,29 @@ class SSLConfig:
             return self
         return SSLConfig(cert=cert, verify=verify)
 
-    def load_ssl_context(self) -> ssl.SSLContext:
+    def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext:
+        protocols = ProtocolConfig() if protocols is None else protocols
+
         if self.ssl_context is None:
             self.ssl_context = (
-                self.load_ssl_context_verify()
+                self.load_ssl_context_verify(protocols=protocols)
                 if self.verify
-                else self.load_ssl_context_no_verify()
+                else self.load_ssl_context_no_verify(protocols=protocols)
             )
 
         assert self.ssl_context is not None
         return self.ssl_context
 
-    def load_ssl_context_no_verify(self) -> ssl.SSLContext:
+    def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
         """
         Return an SSL context for unverified connections.
         """
-        context = self._create_default_ssl_context()
+        context = self._create_default_ssl_context(protocols=protocols)
         context.verify_mode = ssl.CERT_NONE
         context.check_hostname = False
         return context
 
-    def load_ssl_context_verify(self) -> ssl.SSLContext:
+    def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
         """
         Return an SSL context for verified connections.
         """
@@ -106,7 +109,7 @@ class SSLConfig:
                 "invalid path: {}".format(self.verify)
             )
 
-        context = self._create_default_ssl_context()
+        context = self._create_default_ssl_context(protocols=protocols)
         context.verify_mode = ssl.CERT_REQUIRED
         context.check_hostname = True
 
@@ -133,7 +136,7 @@ class SSLConfig:
 
         return context
 
-    def _create_default_ssl_context(self) -> ssl.SSLContext:
+    def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
         """
         Creates the default SSLContext object that's used for both verified
         and unverified connections.
@@ -147,9 +150,9 @@ class SSLConfig:
         context.set_ciphers(DEFAULT_CIPHERS)
 
         if ssl.HAS_ALPN:
-            context.set_alpn_protocols(["h2", "http/1.1"])
+            context.set_alpn_protocols(protocols.protocol_ident_strings)
         if ssl.HAS_NPN:  # pragma: no cover
-            context.set_npn_protocols(["h2", "http/1.1"])
+            context.set_npn_protocols(protocols.protocol_ident_strings)
 
         return context
 
@@ -223,6 +226,40 @@ class TimeoutConfig:
         )
 
 
+class ProtocolConfig:
+    """
+    Configure which HTTP protocol versions are supported.
+    """
+
+    def __init__(self, protocols: ProtocolTypes = None):
+        if protocols is None:
+            protocols = ['HTTP/1.1', 'HTTP/2']
+
+        if isinstance(protocols, str):
+            self.protocols = set([protocol])
+        elif isinstance(protocols, ProtocolConfig):
+            self.protocols = protocols.protocols
+        else:
+            self.protocols = set(sorted(protocols))
+
+        for protocol in self.protocols:
+            if protocol not in ('HTTP/1.1', 'HTTP/2'):
+                raise ValueError(f"Unsupported protocol value {protocol!r}")
+
+    @property
+    def protocol_ident_strings(self) -> typing.List[str]:
+        mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"}
+        return [mapping[protocol] for protocol in self.protocols]
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        if len(self.protocols) == 1:
+            value = self.protocols[0]
+            return f"{class_name}(protocols={value!r})"
+        value = list(self.protocols)
+        return f"{class_name}(protocols={value!r})"
+
+
 class PoolLimits:
     """
     Limits on the number of connections in a connection pool.
index b51fec688b57ab78133d595d3c5f350706b65b08..48271763967f4742c2cea10ca6821b199498b214 100644 (file)
@@ -5,6 +5,8 @@ from ..concurrency import AsyncioBackend
 from ..config import (
     DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
+    ProtocolTypes,
+    ProtocolConfig,
     SSLConfig,
     TimeoutConfig,
     TimeoutTypes,
@@ -28,10 +30,12 @@ class HTTPConnection(AsyncDispatcher):
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         backend: ConcurrencyBackend = None,
         release_func: typing.Optional[ReleaseCallback] = None,
+        protocols: ProtocolTypes = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = SSLConfig(cert=cert, verify=verify)
         self.timeout = TimeoutConfig(timeout)
+        self.protocols = ProtocolConfig(protocols)
         self.backend = AsyncioBackend() if backend is None else backend
         self.release_func = release_func
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
@@ -43,9 +47,10 @@ class HTTPConnection(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        protocols: ProtocolTypes = None,
     ) -> AsyncResponse:
         if self.h11_connection is None and self.h2_connection is None:
-            await self.connect(verify=verify, cert=cert, timeout=timeout)
+            await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols)
 
         if self.h2_connection is not None:
             response = await self.h2_connection.send(request, timeout=timeout)
@@ -60,9 +65,11 @@ class HTTPConnection(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        protocols: ProtocolTypes = None,
     ) -> None:
         ssl = self.ssl.with_overrides(verify=verify, cert=cert)
         timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
+        protocols = self.protocols if protocols is None else ProtocolConfig(protocols)
 
         host = self.origin.host
         port = self.origin.port
@@ -79,7 +86,7 @@ class HTTPConnection(AsyncDispatcher):
             on_release = functools.partial(self.release_func, self)
 
         reader, writer, protocol = await self.backend.connect(
-            host, port, ssl_context, timeout
+            host, port, ssl_context, timeout, protocols
         )
         if protocol == Protocol.HTTP_2:
             self.h2_connection = HTTP2Connection(
index 2b4edf4d3cbda1452ea42932e0a87efc44f3100e..ca16cec7ca30528fc269ea38611abf563ad745e7 100644 (file)
@@ -3,7 +3,7 @@ import ssl
 import typing
 from types import TracebackType
 
-from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
+from .config import CertTypes, PoolLimits, ProtocolConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
 from .models import (
     AsyncRequest,
     AsyncRequestData,
@@ -172,6 +172,7 @@ class ConcurrencyBackend:
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
+        protocols: ProtocolConfig
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         raise NotImplementedError()  # pragma: no cover
 
index c92fa7a310b32bdd3039364be03fdc07251f55d7..1bc70cffcbf042ca34ab6d93b2b9d763521fa02d 100644 (file)
@@ -13,6 +13,7 @@ from httpx import (
     Protocol,
     Request,
     TimeoutConfig,
+    ProtocolConfig
 )
 
 
@@ -27,6 +28,7 @@ class MockHTTP2Backend(AsyncioBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
+        protocols: ProtocolConfig
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         self.server = MockHTTP2Server(self.app)
         return self.server, self.server, Protocol.HTTP_2