]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
fix type annotation for MockTransport (#2581)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Thu, 9 Feb 2023 16:05:07 +0000 (08:05 -0800)
committerGitHub <noreply@github.com>
Thu, 9 Feb 2023 16:05:07 +0000 (16:05 +0000)
* fix type annotation for MockTransport

* add type ignore

* better type checks

* better type checks

* add pragma

---------

Co-authored-by: Tom Christie <tom@tomchristie.com>
httpx/_transports/mock.py
tests/client/test_async_client.py
tests/client/test_auth.py
tests/client/test_redirects.py

index 1434166dd5ef6670a0f4d81c9c44555e07f93d80..82043da2d908f7575097f14b08c1a8a60fa1f8a4 100644 (file)
@@ -1,12 +1,14 @@
-import asyncio
 import typing
 
 from .._models import Request, Response
 from .base import AsyncBaseTransport, BaseTransport
 
+SyncHandler = typing.Callable[[Request], Response]
+AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]]
+
 
 class MockTransport(AsyncBaseTransport, BaseTransport):
-    def __init__(self, handler: typing.Callable[[Request], Response]) -> None:
+    def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None:
         self.handler = handler
 
     def handle_request(
@@ -14,7 +16,10 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
         request: Request,
     ) -> Response:
         request.read()
-        return self.handler(request)
+        response = self.handler(request)
+        if not isinstance(response, Response):  # pragma: no cover
+            raise TypeError("Cannot use an async handler in a sync Client")
+        return response
 
     async def handle_async_request(
         self,
@@ -27,8 +32,7 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
         # If it is, then the `response` variable need to be awaited to actually
         # return the result.
 
-        # https://simonwillison.net/2020/Sep/2/await-me-maybe/
-        if asyncio.iscoroutine(response):
+        if not isinstance(response, Response):
             response = await response
 
         return response
index 61ad5101ddf946f48227b741192c0739d8ffec12..5be0de3b12d83ffdf3515165fa789d57134d8565 100644 (file)
@@ -313,7 +313,7 @@ async def test_mounted_transport():
 
 @pytest.mark.anyio
 async def test_async_mock_transport():
-    async def hello_world(request):
+    async def hello_world(request: httpx.Request) -> httpx.Response:
         return httpx.Response(200, text="Hello, world!")
 
     transport = httpx.MockTransport(hello_world)
index 6d49f845b9852e6e95574f64084557cc01967494..fee515058b92a3f09738feaffc4b3b511024fda8 100644 (file)
@@ -710,7 +710,7 @@ class ConsumeBodyTransport(httpx.MockTransport):
     async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
         assert isinstance(request.stream, httpx.AsyncByteStream)
         [_ async for _ in request.stream]
-        return self.handler(request)
+        return self.handler(request)  # type: ignore[return-value]
 
 
 @pytest.mark.anyio
index b83e66787939148dab099101d1c8aaed2fca53a4..6155df1447a293816fcac2fa2d71075f362933ec 100644 (file)
@@ -346,7 +346,7 @@ class ConsumeBodyTransport(httpx.MockTransport):
     def handle_request(self, request: httpx.Request) -> httpx.Response:
         assert isinstance(request.stream, httpx.SyncByteStream)
         [_ for _ in request.stream]
-        return self.handler(request)
+        return self.handler(request)  # type: ignore[return-value]
 
 
 def test_cannot_redirect_streaming_body():