CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
+ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"]
USER_AGENT = f"python-httpx/{__version__}"
return self
return SSLConfig(cert=cert, verify=verify)
- def load_ssl_context(self) -> ssl.SSLContext:
+ def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext:
+ protocols = ProtocolConfig() if protocols is None else protocols
+
if self.ssl_context is None:
self.ssl_context = (
- self.load_ssl_context_verify()
+ self.load_ssl_context_verify(protocols=protocols)
if self.verify
- else self.load_ssl_context_no_verify()
+ else self.load_ssl_context_no_verify(protocols=protocols)
)
assert self.ssl_context is not None
return self.ssl_context
- def load_ssl_context_no_verify(self) -> ssl.SSLContext:
+ def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
"""
Return an SSL context for unverified connections.
"""
- context = self._create_default_ssl_context()
+ context = self._create_default_ssl_context(protocols=protocols)
context.verify_mode = ssl.CERT_NONE
context.check_hostname = False
return context
- def load_ssl_context_verify(self) -> ssl.SSLContext:
+ def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
"""
Return an SSL context for verified connections.
"""
"invalid path: {}".format(self.verify)
)
- context = self._create_default_ssl_context()
+ context = self._create_default_ssl_context(protocols=protocols)
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
return context
- def _create_default_ssl_context(self) -> ssl.SSLContext:
+ def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
"""
Creates the default SSLContext object that's used for both verified
and unverified connections.
context.set_ciphers(DEFAULT_CIPHERS)
if ssl.HAS_ALPN:
- context.set_alpn_protocols(["h2", "http/1.1"])
+ context.set_alpn_protocols(protocols.protocol_ident_strings)
if ssl.HAS_NPN: # pragma: no cover
- context.set_npn_protocols(["h2", "http/1.1"])
+ context.set_npn_protocols(protocols.protocol_ident_strings)
return context
)
+class ProtocolConfig:
+ """
+ Configure which HTTP protocol versions are supported.
+ """
+
+ def __init__(self, protocols: ProtocolTypes = None):
+ if protocols is None:
+ protocols = ['HTTP/1.1', 'HTTP/2']
+
+ if isinstance(protocols, str):
+ self.protocols = set([protocol])
+ elif isinstance(protocols, ProtocolConfig):
+ self.protocols = protocols.protocols
+ else:
+ self.protocols = set(sorted(protocols))
+
+ for protocol in self.protocols:
+ if protocol not in ('HTTP/1.1', 'HTTP/2'):
+ raise ValueError(f"Unsupported protocol value {protocol!r}")
+
+ @property
+ def protocol_ident_strings(self) -> typing.List[str]:
+ mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"}
+ return [mapping[protocol] for protocol in self.protocols]
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ if len(self.protocols) == 1:
+ value = self.protocols[0]
+ return f"{class_name}(protocols={value!r})"
+ value = list(self.protocols)
+ return f"{class_name}(protocols={value!r})"
+
+
class PoolLimits:
"""
Limits on the number of connections in a connection pool.
from ..config import (
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
+ ProtocolTypes,
+ ProtocolConfig,
SSLConfig,
TimeoutConfig,
TimeoutTypes,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
+ protocols: ProtocolTypes = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig(cert=cert, verify=verify)
self.timeout = TimeoutConfig(timeout)
+ self.protocols = ProtocolConfig(protocols)
self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
+ protocols: ProtocolTypes = None,
) -> AsyncResponse:
if self.h11_connection is None and self.h2_connection is None:
- await self.connect(verify=verify, cert=cert, timeout=timeout)
+ await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols)
if self.h2_connection is not None:
response = await self.h2_connection.send(request, timeout=timeout)
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
+ protocols: ProtocolTypes = None,
) -> None:
ssl = self.ssl.with_overrides(verify=verify, cert=cert)
timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
+ protocols = self.protocols if protocols is None else ProtocolConfig(protocols)
host = self.origin.host
port = self.origin.port
on_release = functools.partial(self.release_func, self)
reader, writer, protocol = await self.backend.connect(
- host, port, ssl_context, timeout
+ host, port, ssl_context, timeout, protocols
)
if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(