From: Tom Christie Date: Mon, 19 Aug 2019 19:13:37 +0000 (+0100) Subject: Initial pass at configuring supported protocol versions X-Git-Tag: 0.7.2~19^2~7^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0e49f23b7537bd8fef2917a739c522aea166a918;p=thirdparty%2Fhttpx.git Initial pass at configuring supported protocol versions --- diff --git a/httpx/__init__.py b/httpx/__init__.py index dec5cff1..0f1a7b40 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -6,6 +6,8 @@ from .config import ( USER_AGENT, CertTypes, PoolLimits, + ProtocolConfig, + ProtocolTypes, SSLConfig, TimeoutConfig, TimeoutTypes, diff --git a/httpx/concurrency.py b/httpx/concurrency.py index f1bf5854..307b5dcc 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -14,7 +14,7 @@ import ssl import typing from types import TracebackType -from .config import PoolLimits, TimeoutConfig +from .config import PoolLimits, ProtocolConfig, TimeoutConfig from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .interfaces import ( BaseBackgroundManager, @@ -202,6 +202,7 @@ class AsyncioBackend(ConcurrencyBackend): port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, + protocols: ProtocolConfig ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore diff --git a/httpx/config.py b/httpx/config.py index 796da350..4cbe76b5 100644 --- a/httpx/config.py +++ b/httpx/config.py @@ -9,6 +9,7 @@ from .__version__ import __version__ CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]] VerifyTypes = typing.Union[str, bool, ssl.SSLContext] TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] +ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"] USER_AGENT = f"python-httpx/{__version__}" @@ -72,27 +73,29 @@ class SSLConfig: return self return SSLConfig(cert=cert, verify=verify) - def load_ssl_context(self) -> ssl.SSLContext: + def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext: + protocols = ProtocolConfig() if protocols is None else protocols + if self.ssl_context is None: self.ssl_context = ( - self.load_ssl_context_verify() + self.load_ssl_context_verify(protocols=protocols) if self.verify - else self.load_ssl_context_no_verify() + else self.load_ssl_context_no_verify(protocols=protocols) ) assert self.ssl_context is not None return self.ssl_context - def load_ssl_context_no_verify(self) -> ssl.SSLContext: + def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext: """ Return an SSL context for unverified connections. """ - context = self._create_default_ssl_context() + context = self._create_default_ssl_context(protocols=protocols) context.verify_mode = ssl.CERT_NONE context.check_hostname = False return context - def load_ssl_context_verify(self) -> ssl.SSLContext: + def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext: """ Return an SSL context for verified connections. """ @@ -106,7 +109,7 @@ class SSLConfig: "invalid path: {}".format(self.verify) ) - context = self._create_default_ssl_context() + context = self._create_default_ssl_context(protocols=protocols) context.verify_mode = ssl.CERT_REQUIRED context.check_hostname = True @@ -133,7 +136,7 @@ class SSLConfig: return context - def _create_default_ssl_context(self) -> ssl.SSLContext: + def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext: """ Creates the default SSLContext object that's used for both verified and unverified connections. @@ -147,9 +150,9 @@ class SSLConfig: context.set_ciphers(DEFAULT_CIPHERS) if ssl.HAS_ALPN: - context.set_alpn_protocols(["h2", "http/1.1"]) + context.set_alpn_protocols(protocols.protocol_ident_strings) if ssl.HAS_NPN: # pragma: no cover - context.set_npn_protocols(["h2", "http/1.1"]) + context.set_npn_protocols(protocols.protocol_ident_strings) return context @@ -223,6 +226,40 @@ class TimeoutConfig: ) +class ProtocolConfig: + """ + Configure which HTTP protocol versions are supported. + """ + + def __init__(self, protocols: ProtocolTypes = None): + if protocols is None: + protocols = ['HTTP/1.1', 'HTTP/2'] + + if isinstance(protocols, str): + self.protocols = set([protocol]) + elif isinstance(protocols, ProtocolConfig): + self.protocols = protocols.protocols + else: + self.protocols = set(sorted(protocols)) + + for protocol in self.protocols: + if protocol not in ('HTTP/1.1', 'HTTP/2'): + raise ValueError(f"Unsupported protocol value {protocol!r}") + + @property + def protocol_ident_strings(self) -> typing.List[str]: + mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"} + return [mapping[protocol] for protocol in self.protocols] + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + if len(self.protocols) == 1: + value = self.protocols[0] + return f"{class_name}(protocols={value!r})" + value = list(self.protocols) + return f"{class_name}(protocols={value!r})" + + class PoolLimits: """ Limits on the number of connections in a connection pool. diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index b51fec68..48271763 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -5,6 +5,8 @@ from ..concurrency import AsyncioBackend from ..config import ( DEFAULT_TIMEOUT_CONFIG, CertTypes, + ProtocolTypes, + ProtocolConfig, SSLConfig, TimeoutConfig, TimeoutTypes, @@ -28,10 +30,12 @@ class HTTPConnection(AsyncDispatcher): timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, + protocols: ProtocolTypes = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = SSLConfig(cert=cert, verify=verify) self.timeout = TimeoutConfig(timeout) + self.protocols = ProtocolConfig(protocols) self.backend = AsyncioBackend() if backend is None else backend self.release_func = release_func self.h11_connection = None # type: typing.Optional[HTTP11Connection] @@ -43,9 +47,10 @@ class HTTPConnection(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + protocols: ProtocolTypes = None, ) -> AsyncResponse: if self.h11_connection is None and self.h2_connection is None: - await self.connect(verify=verify, cert=cert, timeout=timeout) + await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols) if self.h2_connection is not None: response = await self.h2_connection.send(request, timeout=timeout) @@ -60,9 +65,11 @@ class HTTPConnection(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + protocols: ProtocolTypes = None, ) -> None: ssl = self.ssl.with_overrides(verify=verify, cert=cert) timeout = self.timeout if timeout is None else TimeoutConfig(timeout) + protocols = self.protocols if protocols is None else ProtocolConfig(protocols) host = self.origin.host port = self.origin.port @@ -79,7 +86,7 @@ class HTTPConnection(AsyncDispatcher): on_release = functools.partial(self.release_func, self) reader, writer, protocol = await self.backend.connect( - host, port, ssl_context, timeout + host, port, ssl_context, timeout, protocols ) if protocol == Protocol.HTTP_2: self.h2_connection = HTTP2Connection( diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 2b4edf4d..ca16cec7 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -3,7 +3,7 @@ import ssl import typing from types import TracebackType -from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes +from .config import CertTypes, PoolLimits, ProtocolConfig, TimeoutConfig, TimeoutTypes, VerifyTypes from .models import ( AsyncRequest, AsyncRequestData, @@ -172,6 +172,7 @@ class ConcurrencyBackend: port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, + protocols: ProtocolConfig ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: raise NotImplementedError() # pragma: no cover diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index c92fa7a3..1bc70cff 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -13,6 +13,7 @@ from httpx import ( Protocol, Request, TimeoutConfig, + ProtocolConfig ) @@ -27,6 +28,7 @@ class MockHTTP2Backend(AsyncioBackend): port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, + protocols: ProtocolConfig ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: self.server = MockHTTP2Server(self.app) return self.server, self.server, Protocol.HTTP_2