From f0fd91925bc8663b4ec3635c302aa07fe5f8e72e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 9 Feb 2023 08:05:07 -0800 Subject: [PATCH] fix type annotation for MockTransport (#2581) * fix type annotation for MockTransport * add type ignore * better type checks * better type checks * add pragma --------- Co-authored-by: Tom Christie --- httpx/_transports/mock.py | 14 +++++++++----- tests/client/test_async_client.py | 2 +- tests/client/test_auth.py | 2 +- tests/client/test_redirects.py | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/httpx/_transports/mock.py b/httpx/_transports/mock.py index 1434166d..82043da2 100644 --- a/httpx/_transports/mock.py +++ b/httpx/_transports/mock.py @@ -1,12 +1,14 @@ -import asyncio import typing from .._models import Request, Response from .base import AsyncBaseTransport, BaseTransport +SyncHandler = typing.Callable[[Request], Response] +AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]] + class MockTransport(AsyncBaseTransport, BaseTransport): - def __init__(self, handler: typing.Callable[[Request], Response]) -> None: + def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None: self.handler = handler def handle_request( @@ -14,7 +16,10 @@ class MockTransport(AsyncBaseTransport, BaseTransport): request: Request, ) -> Response: request.read() - return self.handler(request) + response = self.handler(request) + if not isinstance(response, Response): # pragma: no cover + raise TypeError("Cannot use an async handler in a sync Client") + return response async def handle_async_request( self, @@ -27,8 +32,7 @@ class MockTransport(AsyncBaseTransport, BaseTransport): # If it is, then the `response` variable need to be awaited to actually # return the result. - # https://simonwillison.net/2020/Sep/2/await-me-maybe/ - if asyncio.iscoroutine(response): + if not isinstance(response, Response): response = await response return response diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 61ad5101..5be0de3b 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -313,7 +313,7 @@ async def test_mounted_transport(): @pytest.mark.anyio async def test_async_mock_transport(): - async def hello_world(request): + async def hello_world(request: httpx.Request) -> httpx.Response: return httpx.Response(200, text="Hello, world!") transport = httpx.MockTransport(hello_world) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 6d49f845..fee51505 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -710,7 +710,7 @@ class ConsumeBodyTransport(httpx.MockTransport): async def handle_async_request(self, request: httpx.Request) -> httpx.Response: assert isinstance(request.stream, httpx.AsyncByteStream) [_ async for _ in request.stream] - return self.handler(request) + return self.handler(request) # type: ignore[return-value] @pytest.mark.anyio diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index b83e6678..6155df14 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -346,7 +346,7 @@ class ConsumeBodyTransport(httpx.MockTransport): def handle_request(self, request: httpx.Request) -> httpx.Response: assert isinstance(request.stream, httpx.SyncByteStream) [_ for _ in request.stream] - return self.handler(request) + return self.handler(request) # type: ignore[return-value] def test_cannot_redirect_streaming_body(): -- 2.47.3