]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Added type annotations to test_error.py (#2462)
authorScirlat Danut <danut.scirlat@gmail.com>
Sun, 4 Feb 2024 20:54:16 +0000 (22:54 +0200)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 20:54:16 +0000 (20:54 +0000)
* added type annotations to test_error.py

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/middleware/test_errors.py

index 392c2ba16a6d8074314aa4267c6f4071047b1366..a2dbabd8a15ec5f984075728bab28f717141a48a 100644 (file)
@@ -1,17 +1,26 @@
+from typing import Any, Callable
+
 import pytest
 
 from starlette.applications import Starlette
 from starlette.background import BackgroundTask
 from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
 from starlette.routing import Route
+from starlette.testclient import TestClient
+from starlette.types import Receive, Scope, Send
+
+TestClientFactory = Callable[..., TestClient]
 
 
-def test_handler(test_client_factory):
-    async def app(scope, receive, send):
+def test_handler(
+    test_client_factory: TestClientFactory,
+) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         raise RuntimeError("Something went wrong")
 
-    def error_500(request, exc):
+    def error_500(request: Request, exc: Exception) -> JSONResponse:
         return JSONResponse({"detail": "Server Error"}, status_code=500)
 
     app = ServerErrorMiddleware(app, handler=error_500)
@@ -21,8 +30,8 @@ def test_handler(test_client_factory):
     assert response.json() == {"detail": "Server Error"}
 
 
-def test_debug_text(test_client_factory):
-    async def app(scope, receive, send):
+def test_debug_text(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         raise RuntimeError("Something went wrong")
 
     app = ServerErrorMiddleware(app, debug=True)
@@ -33,8 +42,8 @@ def test_debug_text(test_client_factory):
     assert "RuntimeError: Something went wrong" in response.text
 
 
-def test_debug_html(test_client_factory):
-    async def app(scope, receive, send):
+def test_debug_html(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         raise RuntimeError("Something went wrong")
 
     app = ServerErrorMiddleware(app, debug=True)
@@ -45,8 +54,8 @@ def test_debug_html(test_client_factory):
     assert "RuntimeError" in response.text
 
 
-def test_debug_after_response_sent(test_client_factory):
-    async def app(scope, receive, send):
+def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         response = Response(b"", status_code=204)
         await response(scope, receive, send)
         raise RuntimeError("Something went wrong")
@@ -57,12 +66,12 @@ def test_debug_after_response_sent(test_client_factory):
         client.get("/")
 
 
-def test_debug_not_http(test_client_factory):
+def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
     """
     DebugMiddleware should just pass through any non-http messages as-is.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         raise RuntimeError("Something went wrong")
 
     app = ServerErrorMiddleware(app)
@@ -73,17 +82,17 @@ def test_debug_not_http(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_background_task(test_client_factory):
+def test_background_task(test_client_factory: TestClientFactory) -> None:
     accessed_error_handler = False
 
-    def error_handler(request, exc):
+    def error_handler(request: Request, exc: Exception) -> Any:
         nonlocal accessed_error_handler
         accessed_error_handler = True
 
-    def raise_exception():
+    def raise_exception() -> None:
         raise Exception("Something went wrong")
 
-    async def endpoint(request):
+    async def endpoint(request: Request) -> Response:
         task = BackgroundTask(raise_exception)
         return Response(status_code=204, background=task)