]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_testclient.py` (#2493)
authorScirlat Danut <danut.scirlat@gmail.com>
Fri, 9 Feb 2024 09:13:38 +0000 (11:13 +0200)
committerGitHub <noreply@github.com>
Fri, 9 Feb 2024 09:13:38 +0000 (09:13 +0000)
* Add type hints to test_testclient.py

* Fix check errors

* Apply suggestions from code review

* Use ASGIInstance instead

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_testclient.py

index eeccf25a7f029fb96665ccde41c93cac4afaf8e5..e8956cd30a407ca22f4a52616a56a01d6883ef1f 100644 (file)
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
 import itertools
 import sys
-from asyncio import current_task as asyncio_current_task
+from asyncio import Task, current_task as asyncio_current_task
 from contextlib import asynccontextmanager
-from typing import Callable
+from typing import Any, AsyncGenerator, Callable
 
 import anyio
 import anyio.lowlevel
@@ -15,19 +17,21 @@ from starlette.middleware import Middleware
 from starlette.requests import Request
 from starlette.responses import JSONResponse, RedirectResponse, Response
 from starlette.routing import Route
-from starlette.testclient import TestClient
+from starlette.testclient import ASGIInstance, TestClient
 from starlette.types import ASGIApp, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
+TestClientFactory = Callable[..., TestClient]
+
 
-def mock_service_endpoint(request: Request):
+def mock_service_endpoint(request: Request) -> JSONResponse:
     return JSONResponse({"mock": "example"})
 
 
 mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)])
 
 
-def current_task():
+def current_task() -> Task[Any] | trio.lowlevel.Task:
     # anyio's TaskInfo comparisons are invalid after their associated native
     # task object is GC'd https://github.com/agronholm/anyio/issues/324
     asynclib_name = sniffio.current_async_library()
@@ -42,11 +46,11 @@ def current_task():
     raise RuntimeError(f"unsupported asynclib={asynclib_name}")  # pragma: no cover
 
 
-def startup():
+def startup() -> None:
     raise RuntimeError()
 
 
-def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]):
+def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> None:
     """
     We should be able to use the test client within applications.
 
@@ -54,7 +58,7 @@ def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClien
     during tests or in development.
     """
 
-    def homepage(request: Request):
+    def homepage(request: Request) -> JSONResponse:
         client = test_client_factory(mock_service)
         response = client.get("/")
         return JSONResponse(response.json())
@@ -66,7 +70,7 @@ def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClien
     assert response.json() == {"mock": "example"}
 
 
-def test_testclient_headers_behavior():
+def test_testclient_headers_behavior() -> None:
     """
     We should be able to use the test client with user defined headers.
 
@@ -86,8 +90,8 @@ def test_testclient_headers_behavior():
 
 
 def test_use_testclient_as_contextmanager(
-    test_client_factory: Callable[..., TestClient], anyio_backend_name: str
-):
+    test_client_factory: TestClientFactory, anyio_backend_name: str
+) -> None:
     """
     This test asserts a number of properties that are important for an
     app level task_group
@@ -95,7 +99,7 @@ def test_use_testclient_as_contextmanager(
     counter = itertools.count()
     identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")
 
-    def get_identity():
+    def get_identity() -> int:
         try:
             return identity_runvar.get()
         except LookupError:
@@ -109,7 +113,7 @@ def test_use_testclient_as_contextmanager(
     shutdown_loop = None
 
     @asynccontextmanager
-    async def lifespan_context(app: Starlette):
+    async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]:
         nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
 
         startup_task = current_task()
@@ -119,7 +123,7 @@ def test_use_testclient_as_contextmanager(
         shutdown_task = current_task()
         shutdown_loop = get_identity()
 
-    async def loop_id(request: Request):
+    async def loop_id(request: Request) -> JSONResponse:
         return JSONResponse(get_identity())
 
     app = Starlette(
@@ -143,7 +147,7 @@ def test_use_testclient_as_contextmanager(
     assert startup_task is shutdown_task
 
     # outside the TestClient context, new requests continue to spawn in new
-    # eventloops in new threads
+    # event loops in new threads
     assert client.get("/loop_id").json() == 1
     assert client.get("/loop_id").json() == 2
 
@@ -165,7 +169,7 @@ def test_use_testclient_as_contextmanager(
     assert first_task is not startup_task
 
 
-def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
+def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
     with pytest.deprecated_call(
         match="The on_startup and on_shutdown parameters are deprecated"
     ):
@@ -176,7 +180,7 @@ def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
             pass  # pragma: no cover
 
 
-def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]):
+def test_exception_in_middleware(test_client_factory: TestClientFactory) -> None:
     class MiddlewareException(Exception):
         pass
 
@@ -184,7 +188,7 @@ def test_exception_in_middleware(test_client_factory: Callable[..., TestClient])
         def __init__(self, app: ASGIApp):
             self.app = app
 
-        async def __call__(self, scope: Scope, receive: Receive, send: Send):
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
             raise MiddlewareException()
 
     broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
@@ -194,9 +198,9 @@ def test_exception_in_middleware(test_client_factory: Callable[..., TestClient])
             pass  # pragma: no cover
 
 
-def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
-    def app(scope: Scope):
-        async def inner(receive: Receive, send: Send):
+def test_testclient_asgi2(test_client_factory: TestClientFactory) -> None:
+    def app(scope: Scope) -> ASGIInstance:
+        async def inner(receive: Receive, send: Send) -> None:
             await send(
                 {
                     "type": "http.response.start",
@@ -213,8 +217,8 @@ def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
     assert response.text == "Hello, world!"
 
 
-def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
-    async def app(scope: Scope, receive: Receive, send: Send):
+def test_testclient_asgi3(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         await send(
             {
                 "type": "http.response.start",
@@ -229,12 +233,12 @@ def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
     assert response.text == "Hello, world!"
 
 
-def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]):
-    def app(scope: Scope):
-        async def respond(websocket: WebSocket):
+def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> None:
+    def app(scope: Scope) -> ASGIInstance:
+        async def respond(websocket: WebSocket) -> None:
             await websocket.send_json({"message": "test"})
 
-        async def asgi(receive: Receive, send: Send):
+        async def asgi(receive: Receive, send: Send) -> None:
             websocket = WebSocket(scope, receive=receive, send=send)
             await websocket.accept()
             async with anyio.create_task_group() as task_group:
@@ -254,9 +258,9 @@ def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClien
         assert data == {"message": "test"}
 
 
-def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
-    def app(scope: Scope):
-        async def asgi(receive: Receive, send: Send):
+def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
+    def app(scope: Scope) -> ASGIInstance:
+        async def asgi(receive: Receive, send: Send) -> None:
             websocket = WebSocket(scope, receive=receive, send=send)
             await websocket.accept()
             while True:
@@ -271,8 +275,8 @@ def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestCli
 
 
 @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
-def test_query_params(test_client_factory: Callable[..., TestClient], param: str):
-    def homepage(request: Request):
+def test_query_params(test_client_factory: TestClientFactory, param: str) -> None:
+    def homepage(request: Request) -> Response:
         return Response(request.query_params["param"])
 
     app = Starlette(routes=[Route("/", endpoint=homepage)])
@@ -301,8 +305,8 @@ def test_query_params(test_client_factory: Callable[..., TestClient], param: str
     ],
 )
 def test_domain_restricted_cookies(
-    test_client_factory: Callable[..., TestClient], domain: str, ok: bool
-):
+    test_client_factory: TestClientFactory, domain: str, ok: bool
+) -> None:
     """
     Test that test client discards domain restricted cookies which do not match the
     base_url of the testclient (`http://testserver` by default).
@@ -312,7 +316,7 @@ def test_domain_restricted_cookies(
     in accordance with RFC 2965.
     """
 
-    async def app(scope: Scope, receive: Receive, send: Send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response("Hello, world!", media_type="text/plain")
         response.set_cookie(
             "mycookie",
@@ -328,8 +332,8 @@ def test_domain_restricted_cookies(
     assert cookie_set == ok
 
 
-def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]):
-    async def app(scope: Scope, receive: Receive, send: Send):
+def test_forward_follow_redirects(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         if "/ok" in scope["path"]:
             response = Response("ok")
         else:
@@ -341,8 +345,8 @@ def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]
     assert response.status_code == 200
 
 
-def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]):
-    async def app(scope: Scope, receive: Receive, send: Send):
+def test_forward_nofollow_redirects(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = RedirectResponse("/ok")
         await response(scope, receive, send)
 
@@ -351,7 +355,7 @@ def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClien
     assert response.status_code == 307
 
 
-def test_with_duplicate_headers(test_client_factory: Callable[[Starlette], TestClient]):
+def test_with_duplicate_headers(test_client_factory: TestClientFactory) -> None:
     def homepage(request: Request) -> JSONResponse:
         return JSONResponse({"x-token": request.headers.getlist("x-token")})
 
@@ -361,7 +365,7 @@ def test_with_duplicate_headers(test_client_factory: Callable[[Starlette], TestC
     assert response.json() == {"x-token": ["foo", "bar"]}
 
 
-def test_merge_url(test_client_factory: Callable[..., TestClient]):
+def test_merge_url(test_client_factory: TestClientFactory) -> None:
     def homepage(request: Request) -> Response:
         return Response(request.url.path)