]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
'Protocols' -> 'HTTPVersions'
authorTom Christie <tom@tomchristie.com>
Tue, 20 Aug 2019 11:21:07 +0000 (12:21 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 20 Aug 2019 11:21:07 +0000 (12:21 +0100)
httpx/__init__.py
httpx/concurrency.py
httpx/config.py
httpx/dispatch/connection.py
httpx/interfaces.py
tests/dispatch/utils.py

index 0f1a7b4002804b052ed4dda54e30cc981aeedc97..5ef0f6386f8145dcf2e6bb8716f134aaf6bd72c1 100644 (file)
@@ -6,8 +6,8 @@ from .config import (
     USER_AGENT,
     CertTypes,
     PoolLimits,
-    ProtocolConfig,
-    ProtocolTypes,
+    HTTPVersionConfig,
+    HTTPVersionTypes,
     SSLConfig,
     TimeoutConfig,
     TimeoutTypes,
index 307b5dcc1179c6d251edcc7211d98aadf90149d4..a87c9a1d8b8a341020ecdc6e85ca01b8f17cd4cb 100644 (file)
@@ -14,7 +14,7 @@ import ssl
 import typing
 from types import TracebackType
 
-from .config import PoolLimits, ProtocolConfig, TimeoutConfig
+from .config import PoolLimits, HTTPVersionConfig, TimeoutConfig
 from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .interfaces import (
     BaseBackgroundManager,
@@ -202,7 +202,6 @@ class AsyncioBackend(ConcurrencyBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-        protocols: ProtocolConfig
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         try:
             stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
index 4cbe76b510c3c8388486e4271c99a273689a66f2..3804effc09bb6fed7a0f2b3f5fef8a69e1e98100 100644 (file)
@@ -9,7 +9,7 @@ from .__version__ import __version__
 CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
 VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
 TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
-ProtocolTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "ProtocolConfig"]
+HTTPVersionTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"]
 
 
 USER_AGENT = f"python-httpx/{__version__}"
@@ -73,29 +73,29 @@ class SSLConfig:
             return self
         return SSLConfig(cert=cert, verify=verify)
 
-    def load_ssl_context(self, protocols: 'ProtocolConfig'=None) -> ssl.SSLContext:
-        protocols = ProtocolConfig() if protocols is None else protocols
+    def load_ssl_context(self, http_versions: 'HTTPVersionConfig'=None) -> ssl.SSLContext:
+        http_versions = HTTPVersionConfig() if http_versions is None else http_versions
 
         if self.ssl_context is None:
             self.ssl_context = (
-                self.load_ssl_context_verify(protocols=protocols)
+                self.load_ssl_context_verify(http_versions=http_versions)
                 if self.verify
-                else self.load_ssl_context_no_verify(protocols=protocols)
+                else self.load_ssl_context_no_verify(http_versions=http_versions)
             )
 
         assert self.ssl_context is not None
         return self.ssl_context
 
-    def load_ssl_context_no_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
+    def load_ssl_context_no_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
         """
         Return an SSL context for unverified connections.
         """
-        context = self._create_default_ssl_context(protocols=protocols)
+        context = self._create_default_ssl_context(http_versions=http_versions)
         context.verify_mode = ssl.CERT_NONE
         context.check_hostname = False
         return context
 
-    def load_ssl_context_verify(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
+    def load_ssl_context_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
         """
         Return an SSL context for verified connections.
         """
@@ -109,7 +109,7 @@ class SSLConfig:
                 "invalid path: {}".format(self.verify)
             )
 
-        context = self._create_default_ssl_context(protocols=protocols)
+        context = self._create_default_ssl_context(http_versions=http_versions)
         context.verify_mode = ssl.CERT_REQUIRED
         context.check_hostname = True
 
@@ -136,7 +136,7 @@ class SSLConfig:
 
         return context
 
-    def _create_default_ssl_context(self, protocols: 'ProtocolConfig') -> ssl.SSLContext:
+    def _create_default_ssl_context(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
         """
         Creates the default SSLContext object that's used for both verified
         and unverified connections.
@@ -150,9 +150,9 @@ class SSLConfig:
         context.set_ciphers(DEFAULT_CIPHERS)
 
         if ssl.HAS_ALPN:
-            context.set_alpn_protocols(protocols.protocol_ident_strings)
+            context.set_alpn_protocols(http_versions.alpn_strings)
         if ssl.HAS_NPN:  # pragma: no cover
-            context.set_npn_protocols(protocols.protocol_ident_strings)
+            context.set_npn_protocols(http_versions.alpn_strings)
 
         return context
 
@@ -226,38 +226,38 @@ class TimeoutConfig:
         )
 
 
-class ProtocolConfig:
+class HTTPVersionConfig:
     """
     Configure which HTTP protocol versions are supported.
     """
 
-    def __init__(self, protocols: ProtocolTypes = None):
-        if protocols is None:
-            protocols = ['HTTP/1.1', 'HTTP/2']
+    def __init__(self, http_versions: HTTPVersionTypes = None):
+        if http_versions is None:
+            http_versions = ['HTTP/1.1', 'HTTP/2']
 
-        if isinstance(protocols, str):
-            self.protocols = set([protocol])
-        elif isinstance(protocols, ProtocolConfig):
-            self.protocols = protocols.protocols
+        if isinstance(http_versions, str):
+            self.http_versions = set([http_versions])
+        elif isinstance(http_versions, HTTPVersionConfig):
+            self.http_versions = http_versions.http_versions
         else:
-            self.protocols = set(sorted(protocols))
+            self.http_versions = set(sorted(http_versions))
 
-        for protocol in self.protocols:
-            if protocol not in ('HTTP/1.1', 'HTTP/2'):
+        for version in self.http_versions:
+            if version not in ('HTTP/1.1', 'HTTP/2'):
                 raise ValueError(f"Unsupported protocol value {protocol!r}")
 
     @property
-    def protocol_ident_strings(self) -> typing.List[str]:
+    def alpn_strings(self) -> typing.List[str]:
+        """
+        Returns a list of supported ALPN identifiers. (One or more of "http/1.1", "h2").
+        """
         mapping = {"HTTP/1.1": "http/1.1", "HTTP/2": "h2"}
-        return [mapping[protocol] for protocol in self.protocols]
+        return [mapping[version] for version in self.http_versions]
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
-        if len(self.protocols) == 1:
-            value = self.protocols[0]
-            return f"{class_name}(protocols={value!r})"
-        value = list(self.protocols)
-        return f"{class_name}(protocols={value!r})"
+        value = list(self.http_versions)
+        return f"{class_name}({value!r})"
 
 
 class PoolLimits:
index ac06d62d60f8862de939d3bd181e034b34405771..b8e263aa685eda091eaf44e274a8c2f8a82c7f22 100644 (file)
@@ -6,8 +6,8 @@ from ..concurrency import AsyncioBackend
 from ..config import (
     DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
-    ProtocolTypes,
-    ProtocolConfig,
+    HTTPVersionTypes,
+    HTTPVersionConfig,
     SSLConfig,
     TimeoutConfig,
     TimeoutTypes,
@@ -31,12 +31,12 @@ class HTTPConnection(AsyncDispatcher):
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         backend: ConcurrencyBackend = None,
         release_func: typing.Optional[ReleaseCallback] = None,
-        protocols: ProtocolTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = SSLConfig(cert=cert, verify=verify)
         self.timeout = TimeoutConfig(timeout)
-        self.protocols = ProtocolConfig(protocols)
+        self.http_versions = HTTPVersionConfig(http_versions)
         self.backend = AsyncioBackend() if backend is None else backend
         self.release_func = release_func
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
@@ -48,10 +48,10 @@ class HTTPConnection(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-        protocols: ProtocolTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         if self.h11_connection is None and self.h2_connection is None:
-            await self.connect(verify=verify, cert=cert, timeout=timeout, protocols=protocols)
+            await self.connect(verify=verify, cert=cert, timeout=timeout, http_versions=http_versions)
 
         if self.h2_connection is not None:
             response = await self.h2_connection.send(request, timeout=timeout)
@@ -66,15 +66,15 @@ class HTTPConnection(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-        protocols: ProtocolTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> None:
         ssl = self.ssl.with_overrides(verify=verify, cert=cert)
         timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
-        protocols = self.protocols if protocols is None else ProtocolConfig(protocols)
+        http_versions = self.http_versions if http_versions is None else HTTPVersionConfig(http_versions)
 
         host = self.origin.host
         port = self.origin.port
-        ssl_context = await self.get_ssl_context(ssl, protocols)
+        ssl_context = await self.get_ssl_context(ssl, http_versions)
 
         if self.release_func is None:
             on_release = None
@@ -82,7 +82,7 @@ class HTTPConnection(AsyncDispatcher):
             on_release = functools.partial(self.release_func, self)
 
         reader, writer, protocol = await self.backend.connect(
-            host, port, ssl_context, timeout, protocols
+            host, port, ssl_context, timeout
         )
         if protocol == Protocol.HTTP_2:
             self.h2_connection = HTTP2Connection(
@@ -93,12 +93,12 @@ class HTTPConnection(AsyncDispatcher):
                 reader, writer, self.backend, on_release=on_release
             )
 
-    async def get_ssl_context(self, ssl: SSLConfig, protocols: ProtocolConfig) -> typing.Optional[ssl.SSLContext]:
+    async def get_ssl_context(self, ssl: SSLConfig, http_versions: HTTPVersionConfig) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
             return None
 
         # Run the SSL loading in a threadpool, since it may makes disk accesses.
-        return await self.backend.run_in_threadpool(ssl.load_ssl_context, protocols)
+        return await self.backend.run_in_threadpool(ssl.load_ssl_context, http_versions)
 
     async def close(self) -> None:
         if self.h2_connection is not None:
index ca16cec7ca30528fc269ea38611abf563ad745e7..a8758536027e262a296e6686061e7a98e3d82630 100644 (file)
@@ -3,7 +3,7 @@ import ssl
 import typing
 from types import TracebackType
 
-from .config import CertTypes, PoolLimits, ProtocolConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
+from .config import CertTypes, PoolLimits, HTTPVersionConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
 from .models import (
     AsyncRequest,
     AsyncRequestData,
@@ -171,8 +171,7 @@ class ConcurrencyBackend:
         hostname: str,
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
-        timeout: TimeoutConfig,
-        protocols: ProtocolConfig
+        timeout: TimeoutConfig
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         raise NotImplementedError()  # pragma: no cover
 
index 1bc70cffcbf042ca34ab6d93b2b9d763521fa02d..fb2b913a9fabbb409bc16ee2aeb2a3ac5cceddd1 100644 (file)
@@ -13,7 +13,7 @@ from httpx import (
     Protocol,
     Request,
     TimeoutConfig,
-    ProtocolConfig
+    HTTPVersionConfig
 )
 
 
@@ -27,8 +27,7 @@ class MockHTTP2Backend(AsyncioBackend):
         hostname: str,
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
-        timeout: TimeoutConfig,
-        protocols: ProtocolConfig
+        timeout: TimeoutConfig
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         self.server = MockHTTP2Server(self.app)
         return self.server, self.server, Protocol.HTTP_2