]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_base.py` (#2445)
authorScirlat Danut <danut.scirlat@gmail.com>
Sat, 3 Feb 2024 16:54:21 +0000 (18:54 +0200)
committerGitHub <noreply@github.com>
Sat, 3 Feb 2024 16:54:21 +0000 (09:54 -0700)
* added type annotations to test-base.py

* deleted unused imports

* fixed import order

* conditional import

* conditional import TestClient on types.py

* using string literals when importing in types

* added missing imports

* deleted starlette/types, refactored test_base types

* deleted types

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/middleware/test_base.py

index 012e2c1ff65794d11de30d5a12a12acd44bf2112..6e5e42b94423bbbab8e708e410069fc8fb6ca398 100644 (file)
@@ -1,9 +1,18 @@
 import contextvars
 from contextlib import AsyncExitStack
-from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Generator,
+    List,
+    Type,
+    Union,
+)
 
 import anyio
 import pytest
+from anyio.abc import TaskStatus
 
 from starlette.applications import Starlette
 from starlette.background import BackgroundTask
@@ -14,44 +23,56 @@ from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
 from starlette.testclient import TestClient
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.websockets import WebSocket
+
+TestClientFactory = Callable[[ASGIApp], TestClient]
 
 
 class CustomMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request, call_next):
+    async def dispatch(
+        self,
+        request: Request,
+        call_next: RequestResponseEndpoint,
+    ) -> Response:
         response = await call_next(request)
         response.headers["Custom-Header"] = "Example"
         return response
 
 
-def homepage(request):
+def homepage(request: Request) -> PlainTextResponse:
     return PlainTextResponse("Homepage")
 
 
-def exc(request):
+def exc(request: Request) -> None:
     raise Exception("Exc")
 
 
-def exc_stream(request):
+def exc_stream(request: Request) -> StreamingResponse:
     return StreamingResponse(_generate_faulty_stream())
 
 
-def _generate_faulty_stream():
+def _generate_faulty_stream() -> Generator[bytes, None, None]:
     yield b"Ok"
     raise Exception("Faulty Stream")
 
 
 class NoResponse:
-    def __init__(self, scope, receive, send):
+    def __init__(
+        self,
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ):
         pass
 
-    def __await__(self):
+    def __await__(self) -> Generator[Any, None, None]:
         return self.dispatch().__await__()
 
-    async def dispatch(self):
+    async def dispatch(self) -> None:
         pass
 
 
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket) -> None:
     await session.accept()
     await session.send_text("Hello, world!")
     await session.close()
@@ -69,7 +90,7 @@ app = Starlette(
 )
 
 
-def test_custom_middleware(test_client_factory):
+def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.get("/")
     assert response.headers["Custom-Header"] == "Example"
@@ -90,30 +111,44 @@ def test_custom_middleware(test_client_factory):
         assert text == "Hello, world!"
 
 
-def test_state_data_across_multiple_middlewares(test_client_factory):
+def test_state_data_across_multiple_middlewares(
+    test_client_factory: TestClientFactory,
+) -> None:
     expected_value1 = "foo"
     expected_value2 = "bar"
 
     class aMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request, call_next):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             request.state.foo = expected_value1
             response = await call_next(request)
             return response
 
     class bMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request, call_next):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             request.state.bar = expected_value2
             response = await call_next(request)
             response.headers["X-State-Foo"] = request.state.foo
             return response
 
     class cMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request, call_next):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             response = await call_next(request)
             response.headers["X-State-Bar"] = request.state.bar
             return response
 
-    def homepage(request):
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("OK")
 
     app = Starlette(
@@ -132,8 +167,8 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
     assert response.headers["X-State-Bar"] == expected_value2
 
 
-def test_app_middleware_argument(test_client_factory):
-    def homepage(request):
+def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage")
 
     app = Starlette(
@@ -145,10 +180,14 @@ def test_app_middleware_argument(test_client_factory):
     assert response.headers["Custom-Header"] == "Example"
 
 
-def test_fully_evaluated_response(test_client_factory):
+def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None:
     # Test for https://github.com/encode/starlette/issues/1022
     class CustomMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request, call_next):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> PlainTextResponse:
             await call_next(request)
             return PlainTextResponse("Custom")
 
@@ -173,7 +212,11 @@ class CustomMiddlewareWithoutBaseHTTPMiddleware:
 
 
 class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request, call_next):
+    async def dispatch(
+        self,
+        request: Request,
+        call_next: RequestResponseEndpoint,
+    ) -> Response:
         ctxvar.set("set by middleware")
         resp = await call_next(request)
         assert ctxvar.get() == "set by endpoint"
@@ -196,11 +239,14 @@ class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
         ),
     ],
 )
-def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
+def test_contextvars(
+    test_client_factory: TestClientFactory,
+    middleware_cls: Type[_MiddlewareClass[Any]],
+) -> None:
     # this has to be an async endpoint because Starlette calls run_in_threadpool
     # on sync endpoints which has it's own set of peculiarities w.r.t propagating
     # contextvars (it propagates them forwards but not backwards)
-    async def homepage(request):
+    async def homepage(request: Request) -> PlainTextResponse:
         assert ctxvar.get() == "set by middleware"
         ctxvar.set("set by endpoint")
         return PlainTextResponse("Homepage")
@@ -215,23 +261,26 @@ def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[
 
 
 @pytest.mark.anyio
-async def test_run_background_tasks_even_if_client_disconnects():
+async def test_run_background_tasks_even_if_client_disconnects() -> None:
     # test for https://github.com/encode/starlette/issues/1438
     request_body_sent = False
     response_complete = anyio.Event()
     background_task_run = anyio.Event()
 
-    async def sleep_and_set():
+    async def sleep_and_set() -> None:
         # small delay to give BaseHTTPMiddleware a chance to cancel us
         # this is required to make the test fail prior to fixing the issue
         # so do not be surprised if you remove it and the test still passes
         await anyio.sleep(0.1)
         background_task_run.set()
 
-    async def endpoint_with_background_task(_):
+    async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
         return PlainTextResponse(background=BackgroundTask(sleep_and_set))
 
-    async def passthrough(request, call_next):
+    async def passthrough(
+        request: Request,
+        call_next: RequestResponseEndpoint,
+    ) -> Response:
         return await call_next(request)
 
     app = Starlette(
@@ -246,7 +295,7 @@ async def test_run_background_tasks_even_if_client_disconnects():
         "path": "/",
     }
 
-    async def receive():
+    async def receive() -> Message:
         nonlocal request_body_sent
         if not request_body_sent:
             request_body_sent = True
@@ -255,7 +304,7 @@ async def test_run_background_tasks_even_if_client_disconnects():
         await response_complete.wait()
         return {"type": "http.disconnect"}
 
-    async def send(message):
+    async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
             if not message.get("more_body", False):
                 response_complete.set()
@@ -266,23 +315,23 @@ async def test_run_background_tasks_even_if_client_disconnects():
 
 
 @pytest.mark.anyio
-async def test_do_not_block_on_background_tasks():
+async def test_do_not_block_on_background_tasks() -> None:
     request_body_sent = False
     response_complete = anyio.Event()
     events: List[Union[str, Message]] = []
 
-    async def sleep_and_set():
+    async def sleep_and_set() -> None:
         events.append("Background task started")
         await anyio.sleep(0.1)
         events.append("Background task finished")
 
-    async def endpoint_with_background_task(_):
+    async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
         return PlainTextResponse(
             content="Hello", background=BackgroundTask(sleep_and_set)
         )
 
     async def passthrough(
-        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+        request: Request, call_next: RequestResponseEndpoint
     ) -> Response:
         return await call_next(request)
 
@@ -306,7 +355,7 @@ async def test_do_not_block_on_background_tasks():
         await response_complete.wait()
         return {"type": "http.disconnect"}
 
-    async def send(message: Message):
+    async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
             events.append(message)
             if not message.get("more_body", False):
@@ -331,13 +380,13 @@ async def test_do_not_block_on_background_tasks():
 
 
 @pytest.mark.anyio
-async def test_run_context_manager_exit_even_if_client_disconnects():
+async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
     # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
     request_body_sent = False
     response_complete = anyio.Event()
     context_manager_exited = anyio.Event()
 
-    async def sleep_and_set():
+    async def sleep_and_set() -> None:
         # small delay to give BaseHTTPMiddleware a chance to cancel us
         # this is required to make the test fail prior to fixing the issue
         # so do not be surprised if you remove it and the test still passes
@@ -345,18 +394,21 @@ async def test_run_context_manager_exit_even_if_client_disconnects():
         context_manager_exited.set()
 
     class ContextManagerMiddleware:
-        def __init__(self, app):
+        def __init__(self, app: ASGIApp):
             self.app = app
 
-        async def __call__(self, scope: Scope, receive: Receive, send: Send):
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
             async with AsyncExitStack() as stack:
                 stack.push_async_callback(sleep_and_set)
                 await self.app(scope, receive, send)
 
-    async def simple_endpoint(_):
+    async def simple_endpoint(_: Request) -> PlainTextResponse:
         return PlainTextResponse(background=BackgroundTask(sleep_and_set))
 
-    async def passthrough(request, call_next):
+    async def passthrough(
+        request: Request,
+        call_next: RequestResponseEndpoint,
+    ) -> Response:
         return await call_next(request)
 
     app = Starlette(
@@ -374,7 +426,7 @@ async def test_run_context_manager_exit_even_if_client_disconnects():
         "path": "/",
     }
 
-    async def receive():
+    async def receive() -> Message:
         nonlocal request_body_sent
         if not request_body_sent:
             request_body_sent = True
@@ -383,7 +435,7 @@ async def test_run_context_manager_exit_even_if_client_disconnects():
         await response_complete.wait()
         return {"type": "http.disconnect"}
 
-    async def send(message):
+    async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
             if not message.get("more_body", False):
                 response_complete.set()
@@ -393,9 +445,15 @@ async def test_run_context_manager_exit_even_if_client_disconnects():
     assert context_manager_exited.is_set()
 
 
-def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
+def test_app_receives_http_disconnect_while_sending_if_discarded(
+    test_client_factory: TestClientFactory,
+) -> None:
     class DiscardingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request, call_next):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: Any,
+        ) -> PlainTextResponse:
             # As a matter of ordering, this test targets the case where the downstream
             # app response is discarded while it is sending a response body.
             # We need to wait for the downstream app to begin sending a response body
@@ -410,7 +468,11 @@ def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_fac
 
             return PlainTextResponse("Custom")
 
-    async def downstream_app(scope, receive, send):
+    async def downstream_app(
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ) -> None:
         await send(
             {
                 "type": "http.response.start",
@@ -422,7 +484,10 @@ def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_fac
         )
         async with anyio.create_task_group() as task_group:
 
-            async def cancel_on_disconnect(*, task_status=anyio.TASK_STATUS_IGNORED):
+            async def cancel_on_disconnect(
+                *,
+                task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
+            ) -> None:
                 task_status.started()
                 while True:
                     message = await receive()
@@ -458,13 +523,23 @@ def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_fac
     assert response.text == "Custom"
 
 
-def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
+def test_app_receives_http_disconnect_after_sending_if_discarded(
+    test_client_factory: TestClientFactory,
+) -> None:
     class DiscardingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request, call_next):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> PlainTextResponse:
             await call_next(request)
             return PlainTextResponse("Custom")
 
-    async def downstream_app(scope, receive, send):
+    async def downstream_app(
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ) -> None:
         await send(
             {
                 "type": "http.response.start",
@@ -499,9 +574,9 @@ def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_fac
 
 
 def test_read_request_stream_in_app_after_middleware_calls_stream(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         expected = [b""]
         async for chunk in request.stream():
             assert chunk == expected.pop(0)
@@ -509,7 +584,11 @@ def test_read_request_stream_in_app_after_middleware_calls_stream(
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             expected = [b"a", b""]
             async for chunk in request.stream():
                 assert chunk == expected.pop(0)
@@ -527,9 +606,9 @@ def test_read_request_stream_in_app_after_middleware_calls_stream(
 
 
 def test_read_request_stream_in_app_after_middleware_calls_body(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         expected = [b"a", b""]
         async for chunk in request.stream():
             assert chunk == expected.pop(0)
@@ -537,7 +616,11 @@ def test_read_request_stream_in_app_after_middleware_calls_body(
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             assert await request.body() == b"a"
             return await call_next(request)
 
@@ -552,14 +635,18 @@ def test_read_request_stream_in_app_after_middleware_calls_body(
 
 
 def test_read_request_body_in_app_after_middleware_calls_stream(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         assert await request.body() == b""
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             expected = [b"a", b""]
             async for chunk in request.stream():
                 assert chunk == expected.pop(0)
@@ -577,14 +664,18 @@ def test_read_request_body_in_app_after_middleware_calls_stream(
 
 
 def test_read_request_body_in_app_after_middleware_calls_body(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         assert await request.body() == b"a"
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             assert await request.body() == b"a"
             return await call_next(request)
 
@@ -599,9 +690,9 @@ def test_read_request_body_in_app_after_middleware_calls_body(
 
 
 def test_read_request_stream_in_dispatch_after_app_calls_stream(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         expected = [b"a", b""]
         async for chunk in request.stream():
             assert chunk == expected.pop(0)
@@ -609,7 +700,11 @@ def test_read_request_stream_in_dispatch_after_app_calls_stream(
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             resp = await call_next(request)
             with pytest.raises(RuntimeError, match="Stream consumed"):
                 async for _ in request.stream():
@@ -627,14 +722,18 @@ def test_read_request_stream_in_dispatch_after_app_calls_stream(
 
 
 def test_read_request_stream_in_dispatch_after_app_calls_body(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         assert await request.body() == b"a"
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             resp = await call_next(request)
             with pytest.raises(RuntimeError, match="Stream consumed"):
                 async for _ in request.stream():
@@ -661,7 +760,11 @@ async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None
         await Response()(scope, receive, send)
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             expected = b"1"
             response: Union[Response, None] = None
             async for chunk in request.stream():
@@ -705,14 +808,18 @@ async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None
 
 
 def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(  # noqa: E501
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         assert await request.body() == b"a"
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             assert (
                 await request.body() == b"a"
             )  # this buffers the request body in memory
@@ -733,14 +840,18 @@ def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_ca
 
 
 def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(  # noqa: E501
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
-    async def homepage(request: Request):
+    async def homepage(request: Request) -> PlainTextResponse:
         assert await request.body() == b"a"
         return PlainTextResponse("Homepage")
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             assert (
                 await request.body() == b"a"
             )  # this buffers the request body in memory
@@ -773,7 +884,11 @@ async def test_read_request_disconnected_client() -> None:
         await Response()(scope, receive, send)
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             response = await call_next(request)
             disconnected = await request.is_disconnected()
             assert disconnected is True
@@ -785,7 +900,7 @@ async def test_read_request_disconnected_client() -> None:
         yield {"type": "http.disconnect"}
         raise AssertionError("Should not be called, would hang")  # pragma: no cover
 
-    async def send(msg: Message):
+    async def send(msg: Message) -> None:
         if msg["type"] == "http.response.start":
             assert msg["status"] == 200
 
@@ -809,7 +924,11 @@ async def test_read_request_disconnected_after_consuming_steam() -> None:
         await Response()(scope, receive, send)
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             await request.body()
             disconnected = await request.is_disconnected()
             assert disconnected is True
@@ -823,7 +942,7 @@ async def test_read_request_disconnected_after_consuming_steam() -> None:
         yield {"type": "http.disconnect"}
         raise AssertionError("Should not be called, would hang")  # pragma: no cover
 
-    async def send(msg: Message):
+    async def send(msg: Message) -> None:
         if msg["type"] == "http.response.start":
             assert msg["status"] == 200
 
@@ -837,7 +956,7 @@ async def test_read_request_disconnected_after_consuming_steam() -> None:
 
 
 def test_downstream_middleware_modifies_receive(
-    test_client_factory: Callable[[ASGIApp], TestClient],
+    test_client_factory: TestClientFactory,
 ) -> None:
     """If a downstream middleware modifies receive() the final ASGI app
     should see the modified version.
@@ -850,7 +969,11 @@ def test_downstream_middleware_modifies_receive(
         await Response()(scope, receive, send)
 
     class ConsumingMiddleware(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             body = await request.body()
             assert body == b"foo "
             return await call_next(request)
@@ -873,9 +996,6 @@ def test_downstream_middleware_modifies_receive(
     assert resp.status_code == 200
 
 
-CallNext = Callable[[Request], Awaitable[Response]]
-
-
 def test_pr_1519_comment_1236166180_example() -> None:
     """
     https://github.com/encode/starlette/pull/1519#issuecomment-1236166180
@@ -883,7 +1003,11 @@ def test_pr_1519_comment_1236166180_example() -> None:
     bodies: List[bytes] = []
 
     class LogRequestBodySize(BaseHTTPMiddleware):
-        async def dispatch(self, request: Request, call_next: CallNext) -> Response:
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
             print(len(await request.body()))
             return await call_next(request)