]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Added http_versions to dispatch interface
authorTom Christie <tom@tomchristie.com>
Tue, 20 Aug 2019 11:56:19 +0000 (12:56 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 20 Aug 2019 11:56:19 +0000 (12:56 +0100)
14 files changed:
httpx/config.py
httpx/dispatch/asgi.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
httpx/dispatch/threaded.py
httpx/dispatch/wsgi.py
httpx/interfaces.py
tests/client/test_auth.py
tests/client/test_cookies.py
tests/client/test_headers.py
tests/client/test_redirects.py
tests/dispatch/test_threaded.py
tests/dispatch/utils.py
tests/test_multipart.py

index 73a5eb34392f491775586be9b75369c3de2c7072..db55d3cbf1683779d489bbef488fec0372b30176 100644 (file)
@@ -9,7 +9,9 @@ from .__version__ import __version__
 CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
 VerifyTypes = typing.Union[str, bool, ssl.SSLContext]
 TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"]
-HTTPVersionTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"]
+HTTPVersionTypes = typing.Union[
+    str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"
+]
 
 
 USER_AGENT = f"python-httpx/{__version__}"
@@ -73,7 +75,9 @@ class SSLConfig:
             return self
         return SSLConfig(cert=cert, verify=verify)
 
-    def load_ssl_context(self, http_versions: 'HTTPVersionConfig'=None) -> ssl.SSLContext:
+    def load_ssl_context(
+        self, http_versions: "HTTPVersionConfig" = None
+    ) -> ssl.SSLContext:
         http_versions = HTTPVersionConfig() if http_versions is None else http_versions
 
         if self.ssl_context is None:
@@ -86,7 +90,9 @@ class SSLConfig:
         assert self.ssl_context is not None
         return self.ssl_context
 
-    def load_ssl_context_no_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
+    def load_ssl_context_no_verify(
+        self, http_versions: "HTTPVersionConfig"
+    ) -> ssl.SSLContext:
         """
         Return an SSL context for unverified connections.
         """
@@ -95,7 +101,9 @@ class SSLConfig:
         context.check_hostname = False
         return context
 
-    def load_ssl_context_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
+    def load_ssl_context_verify(
+        self, http_versions: "HTTPVersionConfig"
+    ) -> ssl.SSLContext:
         """
         Return an SSL context for verified connections.
         """
@@ -136,7 +144,9 @@ class SSLConfig:
 
         return context
 
-    def _create_default_ssl_context(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext:
+    def _create_default_ssl_context(
+        self, http_versions: "HTTPVersionConfig"
+    ) -> ssl.SSLContext:
         """
         Creates the default SSLContext object that's used for both verified
         and unverified connections.
@@ -233,17 +243,19 @@ class HTTPVersionConfig:
 
     def __init__(self, http_versions: HTTPVersionTypes = None):
         if http_versions is None:
-            http_versions = ['HTTP/1.1', 'HTTP/2']
+            http_versions = ["HTTP/1.1", "HTTP/2"]
 
         if isinstance(http_versions, str):
             self.http_versions = set([http_versions.upper()])
         elif isinstance(http_versions, HTTPVersionConfig):
             self.http_versions = http_versions.http_versions
         else:
-            self.http_versions = set(sorted([version.upper() for version in http_versions]))
+            self.http_versions = set(
+                sorted([version.upper() for version in http_versions])
+            )
 
         for version in self.http_versions:
-            if version not in ('HTTP/1.1', 'HTTP/2'):
+            if version not in ("HTTP/1.1", "HTTP/2"):
                 raise ValueError(f"Unsupported HTTP version {version!r}.")
 
         if not self.http_versions:
index e3b4b0a0dfb4f9973f35891ba881699fc3f2a93f..e3a38ad88257caaed3d4329103a93108a67e9324 100644 (file)
@@ -1,7 +1,7 @@
 import asyncio
 import typing
 
-from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..config import CertTypes, HTTPVersionTypes, TimeoutTypes, VerifyTypes
 from ..interfaces import AsyncDispatcher
 from ..models import AsyncRequest, AsyncResponse
 
@@ -59,6 +59,7 @@ class ASGIDispatch(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
 
         scope = {
index b8e263aa685eda091eaf44e274a8c2f8a82c7f22..230afc06acea63f72f97b1be03fe9634594f016c 100644 (file)
@@ -29,9 +29,9 @@ class HTTPConnection(AsyncDispatcher):
         verify: VerifyTypes = True,
         cert: CertTypes = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
+        http_versions: HTTPVersionTypes = None,
         backend: ConcurrencyBackend = None,
         release_func: typing.Optional[ReleaseCallback] = None,
-        http_versions: HTTPVersionTypes = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = SSLConfig(cert=cert, verify=verify)
@@ -51,7 +51,9 @@ class HTTPConnection(AsyncDispatcher):
         http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         if self.h11_connection is None and self.h2_connection is None:
-            await self.connect(verify=verify, cert=cert, timeout=timeout, http_versions=http_versions)
+            await self.connect(
+                verify=verify, cert=cert, timeout=timeout, http_versions=http_versions
+            )
 
         if self.h2_connection is not None:
             response = await self.h2_connection.send(request, timeout=timeout)
@@ -70,7 +72,11 @@ class HTTPConnection(AsyncDispatcher):
     ) -> None:
         ssl = self.ssl.with_overrides(verify=verify, cert=cert)
         timeout = self.timeout if timeout is None else TimeoutConfig(timeout)
-        http_versions = self.http_versions if http_versions is None else HTTPVersionConfig(http_versions)
+        http_versions = (
+            self.http_versions
+            if http_versions is None
+            else HTTPVersionConfig(http_versions)
+        )
 
         host = self.origin.host
         port = self.origin.port
@@ -93,7 +99,9 @@ class HTTPConnection(AsyncDispatcher):
                 reader, writer, self.backend, on_release=on_release
             )
 
-    async def get_ssl_context(self, ssl: SSLConfig, http_versions: HTTPVersionConfig) -> typing.Optional[ssl.SSLContext]:
+    async def get_ssl_context(
+        self, ssl: SSLConfig, http_versions: HTTPVersionConfig
+    ) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
             return None
 
index 3090a9a514dddca61b247c66a148f859bfd5acdf..a646d2a3885feafcdabcb56ac82ebe444006daff 100644 (file)
@@ -5,6 +5,7 @@ from ..config import (
     DEFAULT_POOL_LIMITS,
     DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
+    HTTPVersionTypes,
     PoolLimits,
     TimeoutTypes,
     VerifyTypes,
@@ -80,12 +81,14 @@ class ConnectionPool(AsyncDispatcher):
         cert: CertTypes = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
+        http_versions: HTTPVersionTypes = None,
         backend: ConcurrencyBackend = None,
     ):
         self.verify = verify
         self.cert = cert
         self.timeout = timeout
         self.pool_limits = pool_limits
+        self.http_versions = http_versions
         self.is_closed = False
 
         self.keepalive_connections = ConnectionStore()
@@ -104,11 +107,16 @@ class ConnectionPool(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         connection = await self.acquire_connection(origin=request.url.origin)
         try:
             response = await connection.send(
-                request, verify=verify, cert=cert, timeout=timeout
+                request,
+                verify=verify,
+                cert=cert,
+                timeout=timeout,
+                http_versions=http_versions,
             )
         except BaseException as exc:
             self.active_connections.remove(connection)
@@ -133,6 +141,7 @@ class ConnectionPool(AsyncDispatcher):
                 verify=self.verify,
                 cert=self.cert,
                 timeout=self.timeout,
+                http_versions=self.http_versions,
                 backend=self.backend,
                 release_func=self.release_connection,
             )
index e4fdfd5ed152dcf1e2a98f6d4b75caff2f23733e..902e0652bdf8d57024793f8519f61aeb9c60f166 100644 (file)
@@ -1,4 +1,4 @@
-from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..config import CertTypes, HTTPVersionTypes, TimeoutTypes, VerifyTypes
 from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
 from ..models import (
     AsyncRequest,
@@ -29,6 +29,7 @@ class ThreadedDispatcher(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         concurrency_backend = self.backend
 
@@ -48,6 +49,7 @@ class ThreadedDispatcher(AsyncDispatcher):
             "verify": verify,
             "cert": cert,
             "timeout": timeout,
+            "http_versions": http_versions,
         }
         sync_response = await self.backend.run_in_threadpool(func, **kwargs)
         assert isinstance(sync_response, Response)
index 60e0a18c7942030a3dcaf19b690767c139450123..d8f56d22c9cdaebd8bf29277f1640b29ae53f2ae 100644 (file)
@@ -1,7 +1,7 @@
 import io
 import typing
 
-from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..config import CertTypes, HTTPVersionTypes, TimeoutTypes, VerifyTypes
 from ..interfaces import Dispatcher
 from ..models import Request, Response
 
@@ -60,6 +60,7 @@ class WSGIDispatch(Dispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> Response:
         environ = {
             "wsgi.version": (1, 0),
index a8758536027e262a296e6686061e7a98e3d82630..db6edaf4cbf16e07ef577de64b927eca26170ac8 100644 (file)
@@ -3,7 +3,15 @@ import ssl
 import typing
 from types import TracebackType
 
-from .config import CertTypes, PoolLimits, HTTPVersionConfig, TimeoutConfig, TimeoutTypes, VerifyTypes
+from .config import (
+    CertTypes,
+    PoolLimits,
+    HTTPVersionConfig,
+    HTTPVersionTypes,
+    TimeoutConfig,
+    TimeoutTypes,
+    VerifyTypes,
+)
 from .models import (
     AsyncRequest,
     AsyncRequestData,
@@ -42,9 +50,16 @@ class AsyncDispatcher:
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         request = AsyncRequest(method, url, data=data, params=params, headers=headers)
-        return await self.send(request, verify=verify, cert=cert, timeout=timeout)
+        return await self.send(
+            request,
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            http_versions=http_versions,
+        )
 
     async def send(
         self,
@@ -52,6 +67,7 @@ class AsyncDispatcher:
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         raise NotImplementedError()  # pragma: nocover
 
@@ -90,9 +106,16 @@ class Dispatcher:
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> Response:
         request = Request(method, url, data=data, params=params, headers=headers)
-        return self.send(request, verify=verify, cert=cert, timeout=timeout)
+        return self.send(
+            request,
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            http_versions=http_versions,
+        )
 
     def send(
         self,
@@ -100,6 +123,7 @@ class Dispatcher:
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> Response:
         raise NotImplementedError()  # pragma: nocover
 
@@ -171,7 +195,7 @@ class ConcurrencyBackend:
         hostname: str,
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
-        timeout: TimeoutConfig
+        timeout: TimeoutConfig,
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         raise NotImplementedError()  # pragma: no cover
 
index 725ea56c00491aec0819ea66d17be5acabbb28ac..ff5b8e36060de33d3c366b56b00e6c0325c5dbfc 100644 (file)
@@ -8,6 +8,7 @@ from httpx import (
     AsyncResponse,
     CertTypes,
     Client,
+    HTTPVersionTypes,
     TimeoutTypes,
     VerifyTypes,
 )
@@ -20,6 +21,7 @@ class MockDispatch(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
         return AsyncResponse(200, content=body, request=request)
index 280f44bc2b1176f3b6171f7e08ec38ebbb4de16b..9585d21a727e7e02ff33bf76752bbed0226a0d00 100644 (file)
@@ -8,6 +8,7 @@ from httpx import (
     CertTypes,
     Client,
     Cookies,
+    HTTPVersionTypes,
     TimeoutTypes,
     VerifyTypes,
 )
@@ -20,6 +21,7 @@ class MockDispatch(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         if request.url.path.startswith("/echo_cookies"):
             body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
index 2d17c1259e5aa270da010319037c75d84c64d576..5410e397c8e194ce820c5310e2c2c064a88d0c89 100755 (executable)
@@ -8,6 +8,7 @@ from httpx import (
     AsyncResponse,
     CertTypes,
     Client,
+    HTTPVersionTypes,
     TimeoutTypes,
     VerifyTypes,
     __version__,
@@ -21,6 +22,7 @@ class MockDispatch(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         if request.url.path.startswith("/echo_headers"):
             request_headers = dict(request.headers.items())
index 3062733a73b2d7146f0a6eb7813100cea8652e24..3792ee85af76e71ecc152e811db859666b52b48a 100644 (file)
@@ -12,6 +12,7 @@ from httpx import (
     CertTypes,
     RedirectBodyUnavailable,
     RedirectLoop,
+    HTTPVersionTypes,
     TimeoutTypes,
     TooManyRedirects,
     VerifyTypes,
@@ -26,6 +27,7 @@ class MockDispatch(AsyncDispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> AsyncResponse:
         if request.url.path == "/redirect_301":
             status_code = codes.MOVED_PERMANENTLY
index ac90bdd2d6ba861b6f0833a014e6267fd8443774..488e13d9a70e5e4a0e57da948146c3b9e2c38ddb 100644 (file)
@@ -4,6 +4,7 @@ from httpx import (
     CertTypes,
     Client,
     Dispatcher,
+    HTTPVersionTypes,
     Request,
     Response,
     TimeoutTypes,
@@ -23,6 +24,7 @@ class MockDispatch(Dispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> Response:
         if request.url.path == "/streaming_response":
             return Response(200, content=streaming_body(), request=request)
index fb2b913a9fabbb409bc16ee2aeb2a3ac5cceddd1..de44f71166e1c9b5c1f4cc4eb8404d893831fb01 100644 (file)
@@ -13,7 +13,7 @@ from httpx import (
     Protocol,
     Request,
     TimeoutConfig,
-    HTTPVersionConfig
+    HTTPVersionConfig,
 )
 
 
@@ -27,7 +27,7 @@ class MockHTTP2Backend(AsyncioBackend):
         hostname: str,
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
-        timeout: TimeoutConfig
+        timeout: TimeoutConfig,
     ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
         self.server = MockHTTP2Server(self.app)
         return self.server, self.server, Protocol.HTTP_2
index 097adbdb26e392876b012892b7e6ea7b6dc79811..ba346c8adb9e31547c0276e3ec1c0a00ad474760 100644 (file)
@@ -12,6 +12,7 @@ from httpx import (
     Dispatcher,
     Request,
     Response,
+    HTTPVersionTypes,
     TimeoutTypes,
     VerifyTypes,
     multipart,
@@ -25,6 +26,7 @@ class MockDispatch(Dispatcher):
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
+        http_versions: HTTPVersionTypes = None,
     ) -> Response:
         return Response(200, content=request.read())