From: Tom Christie Date: Tue, 20 Aug 2019 11:21:07 +0000 (+0100) Subject: 'Protocols' -> 'HTTPVersions' X-Git-Tag: 0.7.2~19^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f72fa688027ca3312d52ffd809eb2312778633a6;p=thirdparty%2Fhttpx.git 'Protocols' -> 'HTTPVersions' --- diff --git a/httpx/__init__.py b/httpx/__init__.py index 0f1a7b40..5ef0f638 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -6,8 +6,8 @@ from .config import ( USER_AGENT, CertTypes, PoolLimits, - ProtocolConfig, - ProtocolTypes, + HTTPVersionConfig, + HTTPVersionTypes, SSLConfig, TimeoutConfig, TimeoutTypes, diff --git a/httpx/concurrency.py b/httpx/concurrency.py index 307b5dcc..a87c9a1d 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -14,7 +14,7 @@ import ssl import typing from types import TracebackType -from .config import PoolLimits, ProtocolConfig, TimeoutConfig +from .config import PoolLimits, HTTPVersionConfig, TimeoutConfig from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .interfaces import ( BaseBackgroundManager, @@ -202,7 +202,6 @@ class AsyncioBackend(ConcurrencyBackend): port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - protocols: ProtocolConfig ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore diff --git a/httpx/config.py b/httpx/config.py index 4cbe76b5..3804effc 100644 --- a/httpx/config.py +++ b/httpx/config.py @@ -9,7 +9,7 @@ from .__version__ import __version__ CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]] VerifyTypes = typing.Union[str, bool, ssl.SSLContext] TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] -ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"] +HTTPVersionTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"] USER_AGENT = f"python-httpx/{__version__}" @@ -73,29 +73,29 @@ class SSLConfig: return self return SSLConfig(cert=cert, verify=verify) - def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext: - protocols = ProtocolConfig() if protocols is None else protocols + def load_ssl_context(self, http_versions: 'HTTPVersionConfig'=None) -> ssl.SSLContext: + http_versions = HTTPVersionConfig() if http_versions is None else http_versions if self.ssl_context is None: self.ssl_context = ( - self.load_ssl_context_verify(protocols=protocols) + self.load_ssl_context_verify(http_versions=http_versions) if self.verify - else self.load_ssl_context_no_verify(protocols=protocols) + else self.load_ssl_context_no_verify(http_versions=http_versions) ) assert self.ssl_context is not None return self.ssl_context - def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext: + def load_ssl_context_no_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext: """ Return an SSL context for unverified connections. """ - context = self._create_default_ssl_context(protocols=protocols) + context = self._create_default_ssl_context(http_versions=http_versions) context.verify_mode = ssl.CERT_NONE context.check_hostname = False return context - def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext: + def load_ssl_context_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext: """ Return an SSL context for verified connections. """ @@ -109,7 +109,7 @@ class SSLConfig: "invalid path: {}".format(self.verify) ) - context = self._create_default_ssl_context(protocols=protocols) + context = self._create_default_ssl_context(http_versions=http_versions) context.verify_mode = ssl.CERT_REQUIRED context.check_hostname = True @@ -136,7 +136,7 @@ class SSLConfig: return context - def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext: + def _create_default_ssl_context(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext: """ Creates the default SSLContext object that's used for both verified and unverified connections. @@ -150,9 +150,9 @@ class SSLConfig: context.set_ciphers(DEFAULT_CIPHERS) if ssl.HAS_ALPN: - context.set_alpn_protocols(protocols.protocol_ident_strings) + context.set_alpn_protocols(http_versions.alpn_strings) if ssl.HAS_NPN: # pragma: no cover - context.set_npn_protocols(protocols.protocol_ident_strings) + context.set_npn_protocols(http_versions.alpn_strings) return context @@ -226,38 +226,38 @@ class TimeoutConfig: ) -class ProtocolConfig: +class HTTPVersionConfig: """ Configure which HTTP protocol versions are supported. """ - def __init__(self, protocols: ProtocolTypes = None): - if protocols is None: - protocols = ['HTTP/1.1', 'HTTP/2'] + def __init__(self, http_versions: HTTPVersionTypes = None): + if http_versions is None: + http_versions = ['HTTP/1.1', 'HTTP/2'] - if isinstance(protocols, str): - self.protocols = set([protocol]) - elif isinstance(protocols, ProtocolConfig): - self.protocols = protocols.protocols + if isinstance(http_versions, str): + self.http_versions = set([http_versions]) + elif isinstance(http_versions, HTTPVersionConfig): + self.http_versions = http_versions.http_versions else: - self.protocols = set(sorted(protocols)) + self.http_versions = set(sorted(http_versions)) - for protocol in self.protocols: - if protocol not in ('HTTP/1.1', 'HTTP/2'): + for version in self.http_versions: + if version not in ('HTTP/1.1', 'HTTP/2'): raise ValueError(f"Unsupported protocol value {protocol!r}") @property - def protocol_ident_strings(self) -> typing.List[str]: + def alpn_strings(self) -> typing.List[str]: + """ + Returns a list of supported ALPN identifiers. (One or more of "http/1.1", "h2"). + """ mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"} - return [mapping[protocol] for protocol in self.protocols] + return [mapping[version] for version in self.http_versions] def __repr__(self) -> str: class_name = self.__class__.__name__ - if len(self.protocols) == 1: - value = self.protocols[0] - return f"{class_name}(protocols={value!r})" - value = list(self.protocols) - return f"{class_name}(protocols={value!r})" + value = list(self.http_versions) + return f"{class_name}({value!r})" class PoolLimits: diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index ac06d62d..b8e263aa 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -6,8 +6,8 @@ from ..concurrency import AsyncioBackend from ..config import ( DEFAULT_TIMEOUT_CONFIG, CertTypes, - ProtocolTypes, - ProtocolConfig, + HTTPVersionTypes, + HTTPVersionConfig, SSLConfig, TimeoutConfig, TimeoutTypes, @@ -31,12 +31,12 @@ class HTTPConnection(AsyncDispatcher): timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, - protocols: ProtocolTypes = None, + http_versions: HTTPVersionTypes = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = SSLConfig(cert=cert, verify=verify) self.timeout = TimeoutConfig(timeout) - self.protocols = ProtocolConfig(protocols) + self.http_versions = HTTPVersionConfig(http_versions) self.backend = AsyncioBackend() if backend is None else backend self.release_func = release_func self.h11_connection = None # type: typing.Optional[HTTP11Connection] @@ -48,10 +48,10 @@ class HTTPConnection(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - protocols: ProtocolTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: if self.h11_connection is None and self.h2_connection is None: - await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols) + await self.connect(verify=verify, cert=cert, timeout=timeout, http_versions=http_versions) if self.h2_connection is not None: response = await self.h2_connection.send(request, timeout=timeout) @@ -66,15 +66,15 @@ class HTTPConnection(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - protocols: ProtocolTypes = None, + http_versions: HTTPVersionTypes = None, ) -> None: ssl = self.ssl.with_overrides(verify=verify, cert=cert) timeout = self.timeout if timeout is None else TimeoutConfig(timeout) - protocols = self.protocols if protocols is None else ProtocolConfig(protocols) + http_versions = self.http_versions if http_versions is None else HTTPVersionConfig(http_versions) host = self.origin.host port = self.origin.port - ssl_context = await self.get_ssl_context(ssl, protocols) + ssl_context = await self.get_ssl_context(ssl, http_versions) if self.release_func is None: on_release = None @@ -82,7 +82,7 @@ class HTTPConnection(AsyncDispatcher): on_release = functools.partial(self.release_func, self) reader, writer, protocol = await self.backend.connect( - host, port, ssl_context, timeout, protocols + host, port, ssl_context, timeout ) if protocol == Protocol.HTTP_2: self.h2_connection = HTTP2Connection( @@ -93,12 +93,12 @@ class HTTPConnection(AsyncDispatcher): reader, writer, self.backend, on_release=on_release ) - async def get_ssl_context(self, ssl: SSLConfig, protocols: ProtocolConfig) -> typing.Optional[ssl.SSLContext]: + async def get_ssl_context(self, ssl: SSLConfig, http_versions: HTTPVersionConfig) -> typing.Optional[ssl.SSLContext]: if not self.origin.is_ssl: return None # Run the SSL loading in a threadpool, since it may makes disk accesses. - return await self.backend.run_in_threadpool(ssl.load_ssl_context, protocols) + return await self.backend.run_in_threadpool(ssl.load_ssl_context, http_versions) async def close(self) -> None: if self.h2_connection is not None: diff --git a/httpx/interfaces.py b/httpx/interfaces.py index ca16cec7..a8758536 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -3,7 +3,7 @@ import ssl import typing from types import TracebackType -from .config import CertTypes, PoolLimits, ProtocolConfig, TimeoutConfig, TimeoutTypes, VerifyTypes +from .config import CertTypes, PoolLimits, HTTPVersionConfig, TimeoutConfig, TimeoutTypes, VerifyTypes from .models import ( AsyncRequest, AsyncRequestData, @@ -171,8 +171,7 @@ class ConcurrencyBackend: hostname: str, port: int, ssl_context: typing.Optional[ssl.SSLContext], - timeout: TimeoutConfig, - protocols: ProtocolConfig + timeout: TimeoutConfig ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: raise NotImplementedError() # pragma: no cover diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 1bc70cff..fb2b913a 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -13,7 +13,7 @@ from httpx import ( Protocol, Request, TimeoutConfig, - ProtocolConfig + HTTPVersionConfig ) @@ -27,8 +27,7 @@ class MockHTTP2Backend(AsyncioBackend): hostname: str, port: int, ssl_context: typing.Optional[ssl.SSLContext], - timeout: TimeoutConfig, - protocols: ProtocolConfig + timeout: TimeoutConfig ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: self.server = MockHTTP2Server(self.app) return self.server, self.server, Protocol.HTTP_2