From: Tom Christie Date: Tue, 20 Aug 2019 11:56:19 +0000 (+0100) Subject: Added http_versions to dispatch interface X-Git-Tag: 0.7.2~19^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a770102fd2d6c2410c028522ab04fe34c2662a66;p=thirdparty%2Fhttpx.git Added http_versions to dispatch interface --- diff --git a/httpx/config.py b/httpx/config.py index 73a5eb34..db55d3cb 100644 --- a/httpx/config.py +++ b/httpx/config.py @@ -9,7 +9,9 @@ from .__version__ import __version__ CertTypes = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]] VerifyTypes = typing.Union[str, bool, ssl.SSLContext] TimeoutTypes = typing.Union[float, typing.Tuple[float, float, float], "TimeoutConfig"] -HTTPVersionTypes = typing.Union[str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig"] +HTTPVersionTypes = typing.Union[ + str, typing.List[str], typing.Tuple[str], "HTTPVersionConfig" +] USER_AGENT = f"python-httpx/{__version__}" @@ -73,7 +75,9 @@ class SSLConfig: return self return SSLConfig(cert=cert, verify=verify) - def load_ssl_context(self, http_versions: 'HTTPVersionConfig'=None) -> ssl.SSLContext: + def load_ssl_context( + self, http_versions: "HTTPVersionConfig" = None + ) -> ssl.SSLContext: http_versions = HTTPVersionConfig() if http_versions is None else http_versions if self.ssl_context is None: @@ -86,7 +90,9 @@ class SSLConfig: assert self.ssl_context is not None return self.ssl_context - def load_ssl_context_no_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext: + def load_ssl_context_no_verify( + self, http_versions: "HTTPVersionConfig" + ) -> ssl.SSLContext: """ Return an SSL context for unverified connections. """ @@ -95,7 +101,9 @@ class SSLConfig: context.check_hostname = False return context - def load_ssl_context_verify(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext: + def load_ssl_context_verify( + self, http_versions: "HTTPVersionConfig" + ) -> ssl.SSLContext: """ Return an SSL context for verified connections. """ @@ -136,7 +144,9 @@ class SSLConfig: return context - def _create_default_ssl_context(self, http_versions: 'HTTPVersionConfig') -> ssl.SSLContext: + def _create_default_ssl_context( + self, http_versions: "HTTPVersionConfig" + ) -> ssl.SSLContext: """ Creates the default SSLContext object that's used for both verified and unverified connections. @@ -233,17 +243,19 @@ class HTTPVersionConfig: def __init__(self, http_versions: HTTPVersionTypes = None): if http_versions is None: - http_versions = ['HTTP/1.1', 'HTTP/2'] + http_versions = ["HTTP/1.1", "HTTP/2"] if isinstance(http_versions, str): self.http_versions = set([http_versions.upper()]) elif isinstance(http_versions, HTTPVersionConfig): self.http_versions = http_versions.http_versions else: - self.http_versions = set(sorted([version.upper() for version in http_versions])) + self.http_versions = set( + sorted([version.upper() for version in http_versions]) + ) for version in self.http_versions: - if version not in ('HTTP/1.1', 'HTTP/2'): + if version not in ("HTTP/1.1", "HTTP/2"): raise ValueError(f"Unsupported HTTP version {version!r}.") if not self.http_versions: diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py index e3b4b0a0..e3a38ad8 100644 --- a/httpx/dispatch/asgi.py +++ b/httpx/dispatch/asgi.py @@ -1,7 +1,7 @@ import asyncio import typing -from ..config import CertTypes, TimeoutTypes, VerifyTypes +from ..config import CertTypes, HTTPVersionTypes, TimeoutTypes, VerifyTypes from ..interfaces import AsyncDispatcher from ..models import AsyncRequest, AsyncResponse @@ -59,6 +59,7 @@ class ASGIDispatch(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: scope = { diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index b8e263aa..230afc06 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -29,9 +29,9 @@ class HTTPConnection(AsyncDispatcher): verify: VerifyTypes = True, cert: CertTypes = None, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + http_versions: HTTPVersionTypes = None, backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, - http_versions: HTTPVersionTypes = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = SSLConfig(cert=cert, verify=verify) @@ -51,7 +51,9 @@ class HTTPConnection(AsyncDispatcher): http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: if self.h11_connection is None and self.h2_connection is None: - await self.connect(verify=verify, cert=cert, timeout=timeout, http_versions=http_versions) + await self.connect( + verify=verify, cert=cert, timeout=timeout, http_versions=http_versions + ) if self.h2_connection is not None: response = await self.h2_connection.send(request, timeout=timeout) @@ -70,7 +72,11 @@ class HTTPConnection(AsyncDispatcher): ) -> None: ssl = self.ssl.with_overrides(verify=verify, cert=cert) timeout = self.timeout if timeout is None else TimeoutConfig(timeout) - http_versions = self.http_versions if http_versions is None else HTTPVersionConfig(http_versions) + http_versions = ( + self.http_versions + if http_versions is None + else HTTPVersionConfig(http_versions) + ) host = self.origin.host port = self.origin.port @@ -93,7 +99,9 @@ class HTTPConnection(AsyncDispatcher): reader, writer, self.backend, on_release=on_release ) - async def get_ssl_context(self, ssl: SSLConfig, http_versions: HTTPVersionConfig) -> typing.Optional[ssl.SSLContext]: + async def get_ssl_context( + self, ssl: SSLConfig, http_versions: HTTPVersionConfig + ) -> typing.Optional[ssl.SSLContext]: if not self.origin.is_ssl: return None diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 3090a9a5..a646d2a3 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -5,6 +5,7 @@ from ..config import ( DEFAULT_POOL_LIMITS, DEFAULT_TIMEOUT_CONFIG, CertTypes, + HTTPVersionTypes, PoolLimits, TimeoutTypes, VerifyTypes, @@ -80,12 +81,14 @@ class ConnectionPool(AsyncDispatcher): cert: CertTypes = None, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, + http_versions: HTTPVersionTypes = None, backend: ConcurrencyBackend = None, ): self.verify = verify self.cert = cert self.timeout = timeout self.pool_limits = pool_limits + self.http_versions = http_versions self.is_closed = False self.keepalive_connections = ConnectionStore() @@ -104,11 +107,16 @@ class ConnectionPool(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: connection = await self.acquire_connection(origin=request.url.origin) try: response = await connection.send( - request, verify=verify, cert=cert, timeout=timeout + request, + verify=verify, + cert=cert, + timeout=timeout, + http_versions=http_versions, ) except BaseException as exc: self.active_connections.remove(connection) @@ -133,6 +141,7 @@ class ConnectionPool(AsyncDispatcher): verify=self.verify, cert=self.cert, timeout=self.timeout, + http_versions=self.http_versions, backend=self.backend, release_func=self.release_connection, ) diff --git a/httpx/dispatch/threaded.py b/httpx/dispatch/threaded.py index e4fdfd5e..902e0652 100644 --- a/httpx/dispatch/threaded.py +++ b/httpx/dispatch/threaded.py @@ -1,4 +1,4 @@ -from ..config import CertTypes, TimeoutTypes, VerifyTypes +from ..config import CertTypes, HTTPVersionTypes, TimeoutTypes, VerifyTypes from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher from ..models import ( AsyncRequest, @@ -29,6 +29,7 @@ class ThreadedDispatcher(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: concurrency_backend = self.backend @@ -48,6 +49,7 @@ class ThreadedDispatcher(AsyncDispatcher): "verify": verify, "cert": cert, "timeout": timeout, + "http_versions": http_versions, } sync_response = await self.backend.run_in_threadpool(func, **kwargs) assert isinstance(sync_response, Response) diff --git a/httpx/dispatch/wsgi.py b/httpx/dispatch/wsgi.py index 60e0a18c..d8f56d22 100644 --- a/httpx/dispatch/wsgi.py +++ b/httpx/dispatch/wsgi.py @@ -1,7 +1,7 @@ import io import typing -from ..config import CertTypes, TimeoutTypes, VerifyTypes +from ..config import CertTypes, HTTPVersionTypes, TimeoutTypes, VerifyTypes from ..interfaces import Dispatcher from ..models import Request, Response @@ -60,6 +60,7 @@ class WSGIDispatch(Dispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> Response: environ = { "wsgi.version": (1, 0), diff --git a/httpx/interfaces.py b/httpx/interfaces.py index a8758536..db6edaf4 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -3,7 +3,15 @@ import ssl import typing from types import TracebackType -from .config import CertTypes, PoolLimits, HTTPVersionConfig, TimeoutConfig, TimeoutTypes, VerifyTypes +from .config import ( + CertTypes, + PoolLimits, + HTTPVersionConfig, + HTTPVersionTypes, + TimeoutConfig, + TimeoutTypes, + VerifyTypes, +) from .models import ( AsyncRequest, AsyncRequestData, @@ -42,9 +50,16 @@ class AsyncDispatcher: verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: request = AsyncRequest(method, url, data=data, params=params, headers=headers) - return await self.send(request, verify=verify, cert=cert, timeout=timeout) + return await self.send( + request, + verify=verify, + cert=cert, + timeout=timeout, + http_versions=http_versions, + ) async def send( self, @@ -52,6 +67,7 @@ class AsyncDispatcher: verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: raise NotImplementedError() # pragma: nocover @@ -90,9 +106,16 @@ class Dispatcher: verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> Response: request = Request(method, url, data=data, params=params, headers=headers) - return self.send(request, verify=verify, cert=cert, timeout=timeout) + return self.send( + request, + verify=verify, + cert=cert, + timeout=timeout, + http_versions=http_versions, + ) def send( self, @@ -100,6 +123,7 @@ class Dispatcher: verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> Response: raise NotImplementedError() # pragma: nocover @@ -171,7 +195,7 @@ class ConcurrencyBackend: hostname: str, port: int, ssl_context: typing.Optional[ssl.SSLContext], - timeout: TimeoutConfig + timeout: TimeoutConfig, ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: raise NotImplementedError() # pragma: no cover diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 725ea56c..ff5b8e36 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -8,6 +8,7 @@ from httpx import ( AsyncResponse, CertTypes, Client, + HTTPVersionTypes, TimeoutTypes, VerifyTypes, ) @@ -20,6 +21,7 @@ class MockDispatch(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: body = json.dumps({"auth": request.headers.get("Authorization")}).encode() return AsyncResponse(200, content=body, request=request) diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index 280f44bc..9585d21a 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -8,6 +8,7 @@ from httpx import ( CertTypes, Client, Cookies, + HTTPVersionTypes, TimeoutTypes, VerifyTypes, ) @@ -20,6 +21,7 @@ class MockDispatch(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: if request.url.path.startswith("/echo_cookies"): body = json.dumps({"cookies": request.headers.get("Cookie")}).encode() diff --git a/tests/client/test_headers.py b/tests/client/test_headers.py index 2d17c125..5410e397 100755 --- a/tests/client/test_headers.py +++ b/tests/client/test_headers.py @@ -8,6 +8,7 @@ from httpx import ( AsyncResponse, CertTypes, Client, + HTTPVersionTypes, TimeoutTypes, VerifyTypes, __version__, @@ -21,6 +22,7 @@ class MockDispatch(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: if request.url.path.startswith("/echo_headers"): request_headers = dict(request.headers.items()) diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 3062733a..3792ee85 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -12,6 +12,7 @@ from httpx import ( CertTypes, RedirectBodyUnavailable, RedirectLoop, + HTTPVersionTypes, TimeoutTypes, TooManyRedirects, VerifyTypes, @@ -26,6 +27,7 @@ class MockDispatch(AsyncDispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> AsyncResponse: if request.url.path == "/redirect_301": status_code = codes.MOVED_PERMANENTLY diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py index ac90bdd2..488e13d9 100644 --- a/tests/dispatch/test_threaded.py +++ b/tests/dispatch/test_threaded.py @@ -4,6 +4,7 @@ from httpx import ( CertTypes, Client, Dispatcher, + HTTPVersionTypes, Request, Response, TimeoutTypes, @@ -23,6 +24,7 @@ class MockDispatch(Dispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> Response: if request.url.path == "/streaming_response": return Response(200, content=streaming_body(), request=request) diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index fb2b913a..de44f711 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -13,7 +13,7 @@ from httpx import ( Protocol, Request, TimeoutConfig, - HTTPVersionConfig + HTTPVersionConfig, ) @@ -27,7 +27,7 @@ class MockHTTP2Backend(AsyncioBackend): hostname: str, port: int, ssl_context: typing.Optional[ssl.SSLContext], - timeout: TimeoutConfig + timeout: TimeoutConfig, ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: self.server = MockHTTP2Server(self.app) return self.server, self.server, Protocol.HTTP_2 diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 097adbdb..ba346c8a 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -12,6 +12,7 @@ from httpx import ( Dispatcher, Request, Response, + HTTPVersionTypes, TimeoutTypes, VerifyTypes, multipart, @@ -25,6 +26,7 @@ class MockDispatch(Dispatcher): verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, + http_versions: HTTPVersionTypes = None, ) -> Response: return Response(200, content=request.read())