]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Enforce that sync client uses asyncio-based backend (#232)
authorFlorimond Manca <florimond.manca@gmail.com>
Sun, 18 Aug 2019 14:41:37 +0000 (16:41 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Sun, 18 Aug 2019 14:41:37 +0000 (09:41 -0500)
httpx/client.py
tests/client/test_client.py

index 4a2fb88ed3fddc6795b100872a219104fb7e1968..40bf0f18a9ded5a2332997c6acdbe721edd07a7d 100644 (file)
@@ -70,6 +70,8 @@ class BaseClient:
         if backend is None:
             backend = AsyncioBackend()
 
+        self.check_concurrency_backend(backend)
+
         if app is not None:
             param_count = len(inspect.signature(app).parameters)
             assert param_count in (2, 3)
@@ -108,6 +110,9 @@ class BaseClient:
         self.concurrency_backend = backend
         self.trust_env = True if trust_env is None else trust_env
 
+    def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
+        pass  # pragma: no cover
+
     def merge_url(self, url: URLTypes) -> URL:
         url = self.base_url.join(relative_url=url)
         if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
@@ -623,6 +628,19 @@ class AsyncClient(BaseClient):
 
 
 class Client(BaseClient):
+    def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
+        # Iterating over response content allocates an async environment on each step.
+        # This is relatively cheap on asyncio, but cannot be guaranteed for all
+        # concurrency backends.
+        # The sync client performs I/O on its own, so it doesn't need to support
+        # arbitrary concurrency backends.
+        # Therefore, we kept the `backend` parameter (for testing/mocking), but enforce
+        # that the concurrency backend derives from the asyncio one.
+        if not isinstance(backend, AsyncioBackend):
+            raise ValueError(
+                "'Client' only supports asyncio-based concurrency backends"
+            )
+
     def _async_request_data(
         self, data: RequestData = None
     ) -> typing.Optional[AsyncRequestData]:
index 97ae0277dd5d6f462f191a1686eb6ddbbb75a199..6ae4633a349e0503bdd0a1347aac01a7250b912a 100644 (file)
@@ -158,3 +158,19 @@ def test_merge_url():
 
     assert url.scheme == "https"
     assert url.is_ssl
+
+
+class DerivedFromAsyncioBackend(httpx.AsyncioBackend):
+    pass
+
+
+class AnyBackend:
+    pass
+
+
+def test_client_backend_must_be_asyncio_based():
+    httpx.Client(backend=httpx.AsyncioBackend())
+    httpx.Client(backend=DerivedFromAsyncioBackend())
+
+    with pytest.raises(ValueError):
+        httpx.Client(backend=AnyBackend())