]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_applications.py` (#2471)
authorScirlat Danut <danut.scirlat@gmail.com>
Tue, 6 Feb 2024 20:30:47 +0000 (22:30 +0200)
committerGitHub <noreply@github.com>
Tue, 6 Feb 2024 20:30:47 +0000 (20:30 +0000)
* added type annotations to test_applications.py

* requested changes

* Apply suggestions from code review

* Apply suggestions from code review

* Update tests/test_applications.py

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_applications.py

index 6d0118b535a16de36c9dda8a6ad6e5ae5b2dc991..5b6c9d54564c16d422d29be7d70a1c16be805470 100644 (file)
@@ -1,9 +1,9 @@
 import os
 from contextlib import asynccontextmanager
-from typing import AsyncIterator, Callable
+from pathlib import Path
+from typing import AsyncGenerator, AsyncIterator, Callable, Generator
 
 import anyio
-import httpx
 import pytest
 
 from starlette import status
@@ -11,63 +11,68 @@ from starlette.applications import Starlette
 from starlette.endpoints import HTTPEndpoint
 from starlette.exceptions import HTTPException, WebSocketException
 from starlette.middleware import Middleware
+from starlette.middleware.base import RequestResponseEndpoint
 from starlette.middleware.trustedhost import TrustedHostMiddleware
+from starlette.requests import Request
 from starlette.responses import JSONResponse, PlainTextResponse
 from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
 from starlette.staticfiles import StaticFiles
+from starlette.testclient import TestClient
 from starlette.types import ASGIApp, Receive, Scope, Send
 from starlette.websockets import WebSocket
 
+TestClientFactory = Callable[..., TestClient]
 
-async def error_500(request, exc):
+
+async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
     return JSONResponse({"detail": "Server Error"}, status_code=500)
 
 
-async def method_not_allowed(request, exc):
+async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
     return JSONResponse({"detail": "Custom message"}, status_code=405)
 
 
-async def http_exception(request, exc):
+async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
     return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
 
 
-def func_homepage(request):
+def func_homepage(request: Request) -> PlainTextResponse:
     return PlainTextResponse("Hello, world!")
 
 
-async def async_homepage(request):
+async def async_homepage(request: Request) -> PlainTextResponse:
     return PlainTextResponse("Hello, world!")
 
 
 class Homepage(HTTPEndpoint):
-    def get(self, request):
+    def get(self, request: Request) -> PlainTextResponse:
         return PlainTextResponse("Hello, world!")
 
 
-def all_users_page(request):
+def all_users_page(request: Request) -> PlainTextResponse:
     return PlainTextResponse("Hello, everyone!")
 
 
-def user_page(request):
+def user_page(request: Request) -> PlainTextResponse:
     username = request.path_params["username"]
     return PlainTextResponse(f"Hello, {username}!")
 
 
-def custom_subdomain(request):
+def custom_subdomain(request: Request) -> PlainTextResponse:
     return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])
 
 
-def runtime_error(request):
+def runtime_error(request: Request) -> None:
     raise RuntimeError()
 
 
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket) -> None:
     await session.accept()
     await session.send_text("Hello, world!")
     await session.close()
 
 
-async def websocket_raise_websocket(websocket: WebSocket):
+async def websocket_raise_websocket(websocket: WebSocket) -> None:
     await websocket.accept()
     raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
 
@@ -76,12 +81,12 @@ class CustomWSException(Exception):
     pass
 
 
-async def websocket_raise_custom(websocket: WebSocket):
+async def websocket_raise_custom(websocket: WebSocket) -> None:
     await websocket.accept()
     raise CustomWSException()
 
 
-def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
+def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
     anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
 
 
@@ -121,22 +126,22 @@ app = Starlette(
         Mount("/users", app=users),
         Host("{subdomain}.example.org", app=subdomain),
     ],
-    exception_handlers=exception_handlers,
+    exception_handlers=exception_handlers,  # type: ignore
     middleware=middleware,
 )
 
 
 @pytest.fixture
-def client(test_client_factory):
+def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
     with test_client_factory(app) as client:
         yield client
 
 
-def test_url_path_for():
+def test_url_path_for() -> None:
     assert app.url_path_for("func_homepage") == "/func"
 
 
-def test_func_route(client):
+def test_func_route(client: TestClient) -> None:
     response = client.get("/func")
     assert response.status_code == 200
     assert response.text == "Hello, world!"
@@ -146,31 +151,31 @@ def test_func_route(client):
     assert response.text == ""
 
 
-def test_async_route(client):
+def test_async_route(client: TestClient) -> None:
     response = client.get("/async")
     assert response.status_code == 200
     assert response.text == "Hello, world!"
 
 
-def test_class_route(client):
+def test_class_route(client: TestClient) -> None:
     response = client.get("/class")
     assert response.status_code == 200
     assert response.text == "Hello, world!"
 
 
-def test_mounted_route(client):
+def test_mounted_route(client: TestClient) -> None:
     response = client.get("/users/")
     assert response.status_code == 200
     assert response.text == "Hello, everyone!"
 
 
-def test_mounted_route_path_params(client):
+def test_mounted_route_path_params(client: TestClient) -> None:
     response = client.get("/users/tomchristie")
     assert response.status_code == 200
     assert response.text == "Hello, tomchristie!"
 
 
-def test_subdomain_route(test_client_factory):
+def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app, base_url="https://foo.example.org/")
 
     response = client.get("/")
@@ -178,19 +183,19 @@ def test_subdomain_route(test_client_factory):
     assert response.text == "Subdomain: foo"
 
 
-def test_websocket_route(client):
+def test_websocket_route(client: TestClient) -> None:
     with client.websocket_connect("/ws") as session:
         text = session.receive_text()
         assert text == "Hello, world!"
 
 
-def test_400(client):
+def test_400(client: TestClient) -> None:
     response = client.get("/404")
     assert response.status_code == 404
     assert response.json() == {"detail": "Not Found"}
 
 
-def test_405(client):
+def test_405(client: TestClient) -> None:
     response = client.post("/func")
     assert response.status_code == 405
     assert response.json() == {"detail": "Custom message"}
@@ -200,14 +205,14 @@ def test_405(client):
     assert response.json() == {"detail": "Custom message"}
 
 
-def test_500(test_client_factory):
+def test_500(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app, raise_server_exceptions=False)
     response = client.get("/500")
     assert response.status_code == 500
     assert response.json() == {"detail": "Server Error"}
 
 
-def test_websocket_raise_websocket_exception(client):
+def test_websocket_raise_websocket_exception(client: TestClient) -> None:
     with client.websocket_connect("/ws-raise-websocket") as session:
         response = session.receive()
         assert response == {
@@ -217,7 +222,7 @@ def test_websocket_raise_websocket_exception(client):
         }
 
 
-def test_websocket_raise_custom_exception(client):
+def test_websocket_raise_custom_exception(client: TestClient) -> None:
     with client.websocket_connect("/ws-raise-custom") as session:
         response = session.receive()
         assert response == {
@@ -227,14 +232,14 @@ def test_websocket_raise_custom_exception(client):
         }
 
 
-def test_middleware(test_client_factory):
+def test_middleware(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app, base_url="http://incorrecthost")
     response = client.get("/func")
     assert response.status_code == 400
     assert response.text == "Invalid host header"
 
 
-def test_routes():
+def test_routes() -> None:
     assert app.routes == [
         Route("/func", endpoint=func_homepage, methods=["GET"]),
         Route("/async", endpoint=async_homepage, methods=["GET"]),
@@ -259,7 +264,7 @@ def test_routes():
     ]
 
 
-def test_app_mount(tmpdir, test_client_factory):
+def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -281,8 +286,8 @@ def test_app_mount(tmpdir, test_client_factory):
     assert response.text == "Method Not Allowed"
 
 
-def test_app_debug(test_client_factory):
-    async def homepage(request):
+def test_app_debug(test_client_factory: TestClientFactory) -> None:
+    async def homepage(request: Request) -> None:
         raise RuntimeError()
 
     app = Starlette(
@@ -299,8 +304,8 @@ def test_app_debug(test_client_factory):
     assert app.debug
 
 
-def test_app_add_route(test_client_factory):
-    async def homepage(request):
+def test_app_add_route(test_client_factory: TestClientFactory) -> None:
+    async def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Hello, World!")
 
     app = Starlette(
@@ -315,8 +320,8 @@ def test_app_add_route(test_client_factory):
     assert response.text == "Hello, World!"
 
 
-def test_app_add_websocket_route(test_client_factory):
-    async def websocket_endpoint(session):
+def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
+    async def websocket_endpoint(session: WebSocket) -> None:
         await session.accept()
         await session.send_text("Hello, world!")
         await session.close()
@@ -333,15 +338,15 @@ def test_app_add_websocket_route(test_client_factory):
         assert text == "Hello, world!"
 
 
-def test_app_add_event_handler(test_client_factory):
+def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     cleanup_complete = False
 
-    def run_startup():
+    def run_startup() -> None:
         nonlocal startup_complete
         startup_complete = True
 
-    def run_cleanup():
+    def run_cleanup() -> None:
         nonlocal cleanup_complete
         cleanup_complete = True
 
@@ -362,12 +367,12 @@ def test_app_add_event_handler(test_client_factory):
     assert cleanup_complete
 
 
-def test_app_async_cm_lifespan(test_client_factory):
+def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     cleanup_complete = False
 
     @asynccontextmanager
-    async def lifespan(app):
+    async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
         nonlocal startup_complete, cleanup_complete
         startup_complete = True
         yield
@@ -394,17 +399,17 @@ deprecated_lifespan = pytest.mark.filterwarnings(
 
 
 @deprecated_lifespan
-def test_app_async_gen_lifespan(test_client_factory):
+def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     cleanup_complete = False
 
-    async def lifespan(app):
+    async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
         nonlocal startup_complete, cleanup_complete
         startup_complete = True
         yield
         cleanup_complete = True
 
-    app = Starlette(lifespan=lifespan)
+    app = Starlette(lifespan=lifespan)  # type: ignore
 
     assert not startup_complete
     assert not cleanup_complete
@@ -416,17 +421,17 @@ def test_app_async_gen_lifespan(test_client_factory):
 
 
 @deprecated_lifespan
-def test_app_sync_gen_lifespan(test_client_factory):
+def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     cleanup_complete = False
 
-    def lifespan(app):
+    def lifespan(app: ASGIApp) -> Generator[None, None, None]:
         nonlocal startup_complete, cleanup_complete
         startup_complete = True
         yield
         cleanup_complete = True
 
-    app = Starlette(lifespan=lifespan)
+    app = Starlette(lifespan=lifespan)  # type: ignore
 
     assert not startup_complete
     assert not cleanup_complete
@@ -456,7 +461,9 @@ def test_decorator_deprecations() -> None:
         )
     ) as record:
 
-        async def middleware(request, call_next):
+        async def middleware(
+            request: Request, call_next: RequestResponseEndpoint
+        ) -> None:
             ...  # pragma: no cover
 
         app.middleware("http")(middleware)
@@ -487,14 +494,14 @@ def test_decorator_deprecations() -> None:
         )
     ) as record:
 
-        async def startup():
+        async def startup() -> None:
             ...  # pragma: no cover
 
         app.on_event("startup")(startup)
         assert len(record) == 1
 
 
-def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
+def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
     class NoOpMiddleware:
         def __init__(self, app: ASGIApp):
             self.app = app
@@ -536,7 +543,7 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
     assert SimpleInitializableMiddleware.counter == 2
 
 
-def test_lifespan_app_subclass():
+def test_lifespan_app_subclass() -> None:
     # This test exists to make sure that subclasses of Starlette
     # (like FastAPI) are compatible with the types hints for Lifespan