]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Concurrency autodetection (#585)
authorTom Christie <tom@tomchristie.com>
Mon, 2 Dec 2019 19:26:16 +0000 (19:26 +0000)
committerGitHub <noreply@github.com>
Mon, 2 Dec 2019 19:26:16 +0000 (19:26 +0000)
* Simplify HTTP version config, and switch HTTP/2 off by default

* HTTP/2 docs

* HTTP/2 interlinking in docs

* Add concurrency auto-detection

* Add sniffio

13 files changed:
README.md
httpx/client.py
httpx/concurrency/auto.py [new file with mode: 0644]
httpx/concurrency/base.py
httpx/dispatch/asgi.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
httpx/dispatch/proxy_http.py
setup.cfg
setup.py
tests/client/test_proxies.py

index 66393ff909f04090034ebe17292b9689ee37b2d8..431a959498d2b55be6c6fb6c2bea1183c5b5afa3 100644 (file)
--- a/README.md
+++ b/README.md
@@ -113,6 +113,7 @@ The httpx project relies on these excellent libraries:
 * `hstspreload` - determines whether IDNA-encoded host should be only accessed via HTTPS.
 * `idna` - Internationalized domain name support.
 * `rfc3986` - URL parsing & normalization.
+* `sniffio` - Async library autodetection.
 * `brotlipy` - Decoding for "brotli" compressed responses. *(Optional)*
 
 A huge amount of credit is due to `requests` for the API layout that
index 28aa1767ecaf2f8e27b4839061598a0cde6b795d..2593eddf9d17ca2c12ce771d594c9bcd23323d14 100644 (file)
@@ -5,7 +5,6 @@ from types import TracebackType
 import hstspreload
 
 from .auth import BasicAuth
-from .concurrency.asyncio import AsyncioBackend
 from .concurrency.base import ConcurrencyBackend
 from .config import (
     DEFAULT_MAX_REDIRECTS,
@@ -95,7 +94,8 @@ class Client:
     * **app** - *(optional)* An ASGI application to send requests to,
     rather than sending actual network requests.
     * **backend** - *(optional)* A concurrency backend to use when issuing
-    async requests.
+    async requests. Either 'auto', 'asyncio', 'trio', or a `ConcurrencyBackend`
+    instance. Defaults to 'auto', for autodetection.
     * **trust_env** - *(optional)* Enables or disables usage of environment
     variables for configuration.
     * **uds** - *(optional)* A path to a Unix domain socket to connect through.
@@ -118,15 +118,12 @@ class Client:
         base_url: URLTypes = None,
         dispatch: Dispatcher = None,
         app: typing.Callable = None,
-        backend: ConcurrencyBackend = None,
+        backend: typing.Union[str, ConcurrencyBackend] = "auto",
         trust_env: bool = True,
         uds: str = None,
     ):
-        if backend is None:
-            backend = AsyncioBackend()
-
         if app is not None:
-            dispatch = ASGIDispatch(app=app, backend=backend)
+            dispatch = ASGIDispatch(app=app)
 
         if dispatch is None:
             dispatch = ConnectionPool(
@@ -155,7 +152,6 @@ class Client:
         self.max_redirects = max_redirects
         self.trust_env = trust_env
         self.dispatch = dispatch
-        self.concurrency_backend = backend
         self.netrc = NetRCInfo()
 
         if proxies is None and trust_env:
@@ -834,7 +830,7 @@ def _proxies_to_dispatchers(
     timeout: TimeoutTypes,
     http_2: bool,
     pool_limits: PoolLimits,
-    backend: ConcurrencyBackend,
+    backend: typing.Union[str, ConcurrencyBackend],
     trust_env: bool,
 ) -> typing.Dict[str, Dispatcher]:
     def _proxy_from_url(url: URLTypes) -> Dispatcher:
diff --git a/httpx/concurrency/auto.py b/httpx/concurrency/auto.py
new file mode 100644 (file)
index 0000000..9b3518d
--- /dev/null
@@ -0,0 +1,59 @@
+import ssl
+import typing
+
+import sniffio
+
+from ..config import PoolLimits, TimeoutConfig
+from .base import (
+    BaseBackgroundManager,
+    BaseEvent,
+    BasePoolSemaphore,
+    BaseSocketStream,
+    ConcurrencyBackend,
+    lookup_backend,
+)
+
+
+class AutoBackend(ConcurrencyBackend):
+    @property
+    def backend(self) -> ConcurrencyBackend:
+        if not hasattr(self, "_backend_implementation"):
+            backend = sniffio.current_async_library()
+            if backend not in ("asyncio", "trio"):
+                raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
+            self._backend_implementation = lookup_backend(backend)
+        return self._backend_implementation
+
+    async def open_tcp_stream(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> BaseSocketStream:
+        return await self.backend.open_tcp_stream(hostname, port, ssl_context, timeout)
+
+    async def open_uds_stream(
+        self,
+        path: str,
+        hostname: typing.Optional[str],
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> BaseSocketStream:
+        return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
+
+    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+        return self.backend.get_semaphore(limits)
+
+    async def run_in_threadpool(
+        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        return await self.backend.run_in_threadpool(func, *args, **kwargs)
+
+    def create_event(self) -> BaseEvent:
+        return self.backend.create_event()
+
+    def background_manager(
+        self, coroutine: typing.Callable, *args: typing.Any
+    ) -> BaseBackgroundManager:
+        return self.backend.background_manager(coroutine, *args)
index 6bbeb071ed541d651b8dd967ab4e9eb3cd7d78e5..33aca070f12d53ec62b8a53172ff97a3c4264d05 100644 (file)
@@ -5,6 +5,28 @@ from types import TracebackType
 from ..config import PoolLimits, TimeoutConfig
 
 
+def lookup_backend(
+    backend: typing.Union[str, "ConcurrencyBackend"] = "auto"
+) -> "ConcurrencyBackend":
+    if not isinstance(backend, str):
+        return backend
+
+    if backend == "auto":
+        from .auto import AutoBackend
+
+        return AutoBackend()
+    elif backend == "asyncio":
+        from .asyncio import AsyncioBackend
+
+        return AsyncioBackend()
+    elif backend == "trio":
+        from .trio import TrioBackend
+
+        return TrioBackend()
+
+    raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}")
+
+
 class TimeoutFlag:
     """
     A timeout flag holds a state of either read-timeout or write-timeout mode.
index ee75debb5f40aa9f5171ac049e4408422a5963dc..a1e83255951dce4b427ec529e7119474d6ac87af 100644 (file)
@@ -1,7 +1,5 @@
 import typing
 
-from ..concurrency.asyncio import AsyncioBackend
-from ..concurrency.base import ConcurrencyBackend
 from ..config import CertTypes, TimeoutTypes, VerifyTypes
 from ..models import Request, Response
 from .base import Dispatcher
@@ -49,13 +47,11 @@ class ASGIDispatch(Dispatcher):
         raise_app_exceptions: bool = True,
         root_path: str = "",
         client: typing.Tuple[str, int] = ("127.0.0.1", 123),
-        backend: ConcurrencyBackend = None,
     ) -> None:
         self.app = app
         self.raise_app_exceptions = raise_app_exceptions
         self.root_path = root_path
         self.client = client
-        self.backend = AsyncioBackend() if backend is None else backend
 
     async def send(
         self,
index 0f05b2e8b211994c42c60ba61d138f4bf5166920..e030770d36b2d29334fb9487dfbb9f891ee0c468 100644 (file)
@@ -2,8 +2,7 @@ import functools
 import ssl
 import typing
 
-from ..concurrency.asyncio import AsyncioBackend
-from ..concurrency.base import ConcurrencyBackend
+from ..concurrency.base import ConcurrencyBackend, lookup_backend
 from ..config import (
     DEFAULT_TIMEOUT_CONFIG,
     CertTypes,
@@ -34,7 +33,7 @@ class HTTPConnection(Dispatcher):
         trust_env: bool = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         http_2: bool = False,
-        backend: ConcurrencyBackend = None,
+        backend: typing.Union[str, ConcurrencyBackend] = "auto",
         release_func: typing.Optional[ReleaseCallback] = None,
         uds: typing.Optional[str] = None,
     ):
@@ -42,7 +41,7 @@ class HTTPConnection(Dispatcher):
         self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
         self.timeout = TimeoutConfig(timeout)
         self.http_2 = http_2
-        self.backend = AsyncioBackend() if backend is None else backend
+        self.backend = lookup_backend(backend)
         self.release_func = release_func
         self.uds = uds
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
@@ -104,13 +103,11 @@ class HTTPConnection(Dispatcher):
 
         if http_version == "HTTP/2":
             self.h2_connection = HTTP2Connection(
-                stream, self.backend, on_release=on_release
+                stream, backend=self.backend, on_release=on_release
             )
         else:
             assert http_version == "HTTP/1.1"
-            self.h11_connection = HTTP11Connection(
-                stream, self.backend, on_release=on_release
-            )
+            self.h11_connection = HTTP11Connection(stream, on_release=on_release)
 
     async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
index fc44c413016e1f69579c768f0906172c1edcd79d..3069568b0ba0fda2f0d233a142ca83bc0b82d279 100644 (file)
@@ -1,7 +1,6 @@
 import typing
 
-from ..concurrency.asyncio import AsyncioBackend
-from ..concurrency.base import ConcurrencyBackend
+from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
 from ..config import (
     DEFAULT_POOL_LIMITS,
     DEFAULT_TIMEOUT_CONFIG,
@@ -88,7 +87,7 @@ class ConnectionPool(Dispatcher):
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         http_2: bool = False,
-        backend: ConcurrencyBackend = None,
+        backend: typing.Union[str, ConcurrencyBackend] = "auto",
         uds: typing.Optional[str] = None,
     ):
         self.verify = verify
@@ -103,8 +102,15 @@ class ConnectionPool(Dispatcher):
         self.keepalive_connections = ConnectionStore()
         self.active_connections = ConnectionStore()
 
-        self.backend = AsyncioBackend() if backend is None else backend
-        self.max_connections = self.backend.get_semaphore(pool_limits)
+        self.backend = lookup_backend(backend)
+
+    @property
+    def max_connections(self) -> BasePoolSemaphore:
+        # We do this lazily, to make sure backend autodetection always
+        # runs within an async context.
+        if not hasattr(self, "_max_connections"):
+            self._max_connections = self.backend.get_semaphore(self.pool_limits)
+        return self._max_connections
 
     @property
     def num_connections(self) -> int:
index 8202781b6cffe699c2c5fef521d815f0d6eb0edf..e4426295010a9e94d2b1b9b27967f50c1281683c 100644 (file)
@@ -2,7 +2,7 @@ import typing
 
 import h11
 
-from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseSocketStream, TimeoutFlag
 from ..config import TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectionClosed, ProtocolError
 from ..models import Request, Response
@@ -33,11 +33,9 @@ class HTTP11Connection:
     def __init__(
         self,
         stream: BaseSocketStream,
-        backend: ConcurrencyBackend,
         on_release: typing.Optional[OnReleaseCallback] = None,
     ):
         self.stream = stream
-        self.backend = backend
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
         self.timeout_flag = TimeoutFlag()
index 9947155a0c4bc9bcf6c7b34b8b8f8e13f12924e6..7bfd519cb9aa0d509211c19a71c4e750f46843a0 100644 (file)
@@ -10,6 +10,7 @@ from ..concurrency.base import (
     BaseSocketStream,
     ConcurrencyBackend,
     TimeoutFlag,
+    lookup_backend,
 )
 from ..config import TimeoutConfig, TimeoutTypes
 from ..exceptions import ProtocolError
@@ -25,11 +26,11 @@ class HTTP2Connection:
     def __init__(
         self,
         stream: BaseSocketStream,
-        backend: ConcurrencyBackend,
+        backend: typing.Union[str, ConcurrencyBackend] = "auto",
         on_release: typing.Callable = None,
     ):
         self.stream = stream
-        self.backend = backend
+        self.backend = lookup_backend(backend)
         self.on_release = on_release
         self.h2_state = h2.connection.H2Connection()
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
index c51c009e120f807c0500770b4cf89866c23fb609..effbd79ee2375a0df5635fc3ed60f18e984fc40f 100644 (file)
@@ -1,4 +1,5 @@
 import enum
+import typing
 from base64 import b64encode
 
 import h11
@@ -47,7 +48,7 @@ class HTTPProxy(ConnectionPool):
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         http_2: bool = False,
-        backend: ConcurrencyBackend = None,
+        backend: typing.Union[str, ConcurrencyBackend] = "auto",
     ):
 
         super(HTTPProxy, self).__init__(
@@ -207,9 +208,7 @@ class HTTPProxy(ConnectionPool):
             )
         else:
             assert http_version == "HTTP/1.1"
-            connection.h11_connection = HTTP11Connection(
-                stream, self.backend, on_release=on_release
-            )
+            connection.h11_connection = HTTP11Connection(stream, on_release=on_release)
 
     def should_forward_origin(self, origin: Origin) -> bool:
         """Determines if the given origin should
index d75cf544004dbb4260226b40f608e849b15f0c14..a1064adaea0ee8e1f9e0d1c2f66837d166f0e393 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,7 +14,7 @@ combine_as_imports = True
 force_grid_wrap = 0
 include_trailing_comma = True
 known_first_party = httpx,httpxprof,tests
-known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,tqdm,trio,trustme,uvicorn
+known_third_party = brotli,certifi,chardet,click,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,tqdm,trio,trustme,uvicorn
 line_length = 88
 multi_line_output = 3
 
index bb16e9631b403e3feaa0dcc01cd4e1dd9cba79ba..e58fd8ce68b322cf0024476923b491e88ed782af 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -52,12 +52,13 @@ setup(
     zip_safe=False,
     install_requires=[
         "certifi",
+        "hstspreload",
         "chardet==3.*",
         "h11==0.8.*",
         "h2==3.*",
-        "hstspreload>=2019.8.27",
         "idna==2.*",
         "rfc3986==1.*",
+        "sniffio==1.*",
     ],
     classifiers=[
         "Development Status :: 3 - Alpha",
index ede15a8fce1ced4d304f0d20c12153c980a34400..c0c1ae0e83a169b73e32cee6ec34299b8f510a42 100644 (file)
@@ -48,7 +48,6 @@ def test_proxies_has_same_properties_as_dispatch():
         "cert",
         "timeout",
         "pool_limits",
-        "backend",
     ]:
         assert getattr(pool, prop) == getattr(proxy, prop)