From: Tom Christie Date: Mon, 2 Dec 2019 19:26:16 +0000 (+0000) Subject: Concurrency autodetection (#585) X-Git-Tag: 0.9.0~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3cbe7315e8e768371eff398ebb15a73a31fe601d;p=thirdparty%2Fhttpx.git Concurrency autodetection (#585) * Simplify HTTP version config, and switch HTTP/2 off by default * HTTP/2 docs * HTTP/2 interlinking in docs * Add concurrency auto-detection * Add sniffio --- diff --git a/README.md b/README.md index 66393ff9..431a9594 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,7 @@ The httpx project relies on these excellent libraries: * `hstspreload` - determines whether IDNA-encoded host should be only accessed via HTTPS. * `idna` - Internationalized domain name support. * `rfc3986` - URL parsing & normalization. +* `sniffio` - Async library autodetection. * `brotlipy` - Decoding for "brotli" compressed responses. *(Optional)* A huge amount of credit is due to `requests` for the API layout that diff --git a/httpx/client.py b/httpx/client.py index 28aa1767..2593eddf 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -5,7 +5,6 @@ from types import TracebackType import hstspreload from .auth import BasicAuth -from .concurrency.asyncio import AsyncioBackend from .concurrency.base import ConcurrencyBackend from .config import ( DEFAULT_MAX_REDIRECTS, @@ -95,7 +94,8 @@ class Client: * **app** - *(optional)* An ASGI application to send requests to, rather than sending actual network requests. * **backend** - *(optional)* A concurrency backend to use when issuing - async requests. + async requests. Either 'auto', 'asyncio', 'trio', or a `ConcurrencyBackend` + instance. Defaults to 'auto', for autodetection. * **trust_env** - *(optional)* Enables or disables usage of environment variables for configuration. * **uds** - *(optional)* A path to a Unix domain socket to connect through. @@ -118,15 +118,12 @@ class Client: base_url: URLTypes = None, dispatch: Dispatcher = None, app: typing.Callable = None, - backend: ConcurrencyBackend = None, + backend: typing.Union[str, ConcurrencyBackend] = "auto", trust_env: bool = True, uds: str = None, ): - if backend is None: - backend = AsyncioBackend() - if app is not None: - dispatch = ASGIDispatch(app=app, backend=backend) + dispatch = ASGIDispatch(app=app) if dispatch is None: dispatch = ConnectionPool( @@ -155,7 +152,6 @@ class Client: self.max_redirects = max_redirects self.trust_env = trust_env self.dispatch = dispatch - self.concurrency_backend = backend self.netrc = NetRCInfo() if proxies is None and trust_env: @@ -834,7 +830,7 @@ def _proxies_to_dispatchers( timeout: TimeoutTypes, http_2: bool, pool_limits: PoolLimits, - backend: ConcurrencyBackend, + backend: typing.Union[str, ConcurrencyBackend], trust_env: bool, ) -> typing.Dict[str, Dispatcher]: def _proxy_from_url(url: URLTypes) -> Dispatcher: diff --git a/httpx/concurrency/auto.py b/httpx/concurrency/auto.py new file mode 100644 index 00000000..9b3518de --- /dev/null +++ b/httpx/concurrency/auto.py @@ -0,0 +1,59 @@ +import ssl +import typing + +import sniffio + +from ..config import PoolLimits, TimeoutConfig +from .base import ( + BaseBackgroundManager, + BaseEvent, + BasePoolSemaphore, + BaseSocketStream, + ConcurrencyBackend, + lookup_backend, +) + + +class AutoBackend(ConcurrencyBackend): + @property + def backend(self) -> ConcurrencyBackend: + if not hasattr(self, "_backend_implementation"): + backend = sniffio.current_async_library() + if backend not in ("asyncio", "trio"): + raise RuntimeError(f"Unsupported concurrency backend {backend!r}") + self._backend_implementation = lookup_backend(backend) + return self._backend_implementation + + async def open_tcp_stream( + self, + hostname: str, + port: int, + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> BaseSocketStream: + return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout) + + async def open_uds_stream( + self, + path: str, + hostname: typing.Optional[str], + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> BaseSocketStream: + return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout) + + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: + return self.backend.get_semaphore(limits) + + async def run_in_threadpool( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + return await self.backend.run_in_threadpool(func, *args, **kwargs) + + def create_event(self) -> BaseEvent: + return self.backend.create_event() + + def background_manager( + self, coroutine: typing.Callable, *args: typing.Any + ) -> BaseBackgroundManager: + return self.backend.background_manager(coroutine, *args) diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 6bbeb071..33aca070 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -5,6 +5,28 @@ from types import TracebackType from ..config import PoolLimits, TimeoutConfig +def lookup_backend( + backend: typing.Union[str, "ConcurrencyBackend"] = "auto" +) -> "ConcurrencyBackend": + if not isinstance(backend, str): + return backend + + if backend == "auto": + from .auto import AutoBackend + + return AutoBackend() + elif backend == "asyncio": + from .asyncio import AsyncioBackend + + return AsyncioBackend() + elif backend == "trio": + from .trio import TrioBackend + + return TrioBackend() + + raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}") + + class TimeoutFlag: """ A timeout flag holds a state of either read-timeout or write-timeout mode. diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py index ee75debb..a1e83255 100644 --- a/httpx/dispatch/asgi.py +++ b/httpx/dispatch/asgi.py @@ -1,7 +1,5 @@ import typing -from ..concurrency.asyncio import AsyncioBackend -from ..concurrency.base import ConcurrencyBackend from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..models import Request, Response from .base import Dispatcher @@ -49,13 +47,11 @@ class ASGIDispatch(Dispatcher): raise_app_exceptions: bool = True, root_path: str = "", client: typing.Tuple[str, int] = ("127.0.0.1", 123), - backend: ConcurrencyBackend = None, ) -> None: self.app = app self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path self.client = client - self.backend = AsyncioBackend() if backend is None else backend async def send( self, diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 0f05b2e8..e030770d 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -2,8 +2,7 @@ import functools import ssl import typing -from ..concurrency.asyncio import AsyncioBackend -from ..concurrency.base import ConcurrencyBackend +from ..concurrency.base import ConcurrencyBackend, lookup_backend from ..config import ( DEFAULT_TIMEOUT_CONFIG, CertTypes, @@ -34,7 +33,7 @@ class HTTPConnection(Dispatcher): trust_env: bool = None, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, http_2: bool = False, - backend: ConcurrencyBackend = None, + backend: typing.Union[str, ConcurrencyBackend] = "auto", release_func: typing.Optional[ReleaseCallback] = None, uds: typing.Optional[str] = None, ): @@ -42,7 +41,7 @@ class HTTPConnection(Dispatcher): self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env) self.timeout = TimeoutConfig(timeout) self.http_2 = http_2 - self.backend = AsyncioBackend() if backend is None else backend + self.backend = lookup_backend(backend) self.release_func = release_func self.uds = uds self.h11_connection = None # type: typing.Optional[HTTP11Connection] @@ -104,13 +103,11 @@ class HTTPConnection(Dispatcher): if http_version == "HTTP/2": self.h2_connection = HTTP2Connection( - stream, self.backend, on_release=on_release + stream, backend=self.backend, on_release=on_release ) else: assert http_version == "HTTP/1.1" - self.h11_connection = HTTP11Connection( - stream, self.backend, on_release=on_release - ) + self.h11_connection = HTTP11Connection(stream, on_release=on_release) async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]: if not self.origin.is_ssl: diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index fc44c413..3069568b 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -1,7 +1,6 @@ import typing -from ..concurrency.asyncio import AsyncioBackend -from ..concurrency.base import ConcurrencyBackend +from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend from ..config import ( DEFAULT_POOL_LIMITS, DEFAULT_TIMEOUT_CONFIG, @@ -88,7 +87,7 @@ class ConnectionPool(Dispatcher): timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, http_2: bool = False, - backend: ConcurrencyBackend = None, + backend: typing.Union[str, ConcurrencyBackend] = "auto", uds: typing.Optional[str] = None, ): self.verify = verify @@ -103,8 +102,15 @@ class ConnectionPool(Dispatcher): self.keepalive_connections = ConnectionStore() self.active_connections = ConnectionStore() - self.backend = AsyncioBackend() if backend is None else backend - self.max_connections = self.backend.get_semaphore(pool_limits) + self.backend = lookup_backend(backend) + + @property + def max_connections(self) -> BasePoolSemaphore: + # We do this lazily, to make sure backend autodetection always + # runs within an async context. + if not hasattr(self, "_max_connections"): + self._max_connections = self.backend.get_semaphore(self.pool_limits) + return self._max_connections @property def num_connections(self) -> int: diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index 8202781b..e4426295 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -2,7 +2,7 @@ import typing import h11 -from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import BaseSocketStream, TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..exceptions import ConnectionClosed, ProtocolError from ..models import Request, Response @@ -33,11 +33,9 @@ class HTTP11Connection: def __init__( self, stream: BaseSocketStream, - backend: ConcurrencyBackend, on_release: typing.Optional[OnReleaseCallback] = None, ): self.stream = stream - self.backend = backend self.on_release = on_release self.h11_state = h11.Connection(our_role=h11.CLIENT) self.timeout_flag = TimeoutFlag() diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 9947155a..7bfd519c 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -10,6 +10,7 @@ from ..concurrency.base import ( BaseSocketStream, ConcurrencyBackend, TimeoutFlag, + lookup_backend, ) from ..config import TimeoutConfig, TimeoutTypes from ..exceptions import ProtocolError @@ -25,11 +26,11 @@ class HTTP2Connection: def __init__( self, stream: BaseSocketStream, - backend: ConcurrencyBackend, + backend: typing.Union[str, ConcurrencyBackend] = "auto", on_release: typing.Callable = None, ): self.stream = stream - self.backend = backend + self.backend = lookup_backend(backend) self.on_release = on_release self.h2_state = h2.connection.H2Connection() self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py index c51c009e..effbd79e 100644 --- a/httpx/dispatch/proxy_http.py +++ b/httpx/dispatch/proxy_http.py @@ -1,4 +1,5 @@ import enum +import typing from base64 import b64encode import h11 @@ -47,7 +48,7 @@ class HTTPProxy(ConnectionPool): timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, http_2: bool = False, - backend: ConcurrencyBackend = None, + backend: typing.Union[str, ConcurrencyBackend] = "auto", ): super(HTTPProxy, self).__init__( @@ -207,9 +208,7 @@ class HTTPProxy(ConnectionPool): ) else: assert http_version == "HTTP/1.1" - connection.h11_connection = HTTP11Connection( - stream, self.backend, on_release=on_release - ) + connection.h11_connection = HTTP11Connection(stream, on_release=on_release) def should_forward_origin(self, origin: Origin) -> bool: """Determines if the given origin should diff --git a/setup.cfg b/setup.cfg index d75cf544..a1064ada 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,7 +14,7 @@ combine_as_imports = True force_grid_wrap = 0 include_trailing_comma = True known_first_party = httpx,httpxprof,tests -known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,tqdm,trio,trustme,uvicorn +known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,tqdm,trio,trustme,uvicorn line_length = 88 multi_line_output = 3 diff --git a/setup.py b/setup.py index bb16e963..e58fd8ce 100644 --- a/setup.py +++ b/setup.py @@ -52,12 +52,13 @@ setup( zip_safe=False, install_requires=[ "certifi", + "hstspreload", "chardet==3.*", "h11==0.8.*", "h2==3.*", - "hstspreload>=2019.8.27", "idna==2.*", "rfc3986==1.*", + "sniffio==1.*", ], classifiers=[ "Development Status :: 3 - Alpha", diff --git a/tests/client/test_proxies.py b/tests/client/test_proxies.py index ede15a8f..c0c1ae0e 100644 --- a/tests/client/test_proxies.py +++ b/tests/client/test_proxies.py @@ -48,7 +48,6 @@ def test_proxies_has_same_properties_as_dispatch(): "cert", "timeout", "pool_limits", - "backend", ]: assert getattr(pool, prop) == getattr(proxy, prop)