+from typing import Any, Callable
+
import pytest
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
+from starlette.testclient import TestClient
+from starlette.types import Receive, Scope, Send
+
+TestClientFactory = Callable[..., TestClient]
-def test_handler(test_client_factory):
- async def app(scope, receive, send):
+def test_handler(
+ test_client_factory: TestClientFactory,
+) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
- def error_500(request, exc):
+ def error_500(request: Request, exc: Exception) -> JSONResponse:
return JSONResponse({"detail": "Server Error"}, status_code=500)
app = ServerErrorMiddleware(app, handler=error_500)
assert response.json() == {"detail": "Server Error"}
-def test_debug_text(test_client_factory):
- async def app(scope, receive, send):
+def test_debug_text(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
assert "RuntimeError: Something went wrong" in response.text
-def test_debug_html(test_client_factory):
- async def app(scope, receive, send):
+def test_debug_html(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
assert "RuntimeError" in response.text
-def test_debug_after_response_sent(test_client_factory):
- async def app(scope, receive, send):
+def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"", status_code=204)
await response(scope, receive, send)
raise RuntimeError("Something went wrong")
client.get("/")
-def test_debug_not_http(test_client_factory):
+def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
"""
DebugMiddleware should just pass through any non-http messages as-is.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app)
pass # pragma: nocover
-def test_background_task(test_client_factory):
+def test_background_task(test_client_factory: TestClientFactory) -> None:
accessed_error_handler = False
- def error_handler(request, exc):
+ def error_handler(request: Request, exc: Exception) -> Any:
nonlocal accessed_error_handler
accessed_error_handler = True
- def raise_exception():
+ def raise_exception() -> None:
raise Exception("Something went wrong")
- async def endpoint(request):
+ async def endpoint(request: Request) -> Response:
task = BackgroundTask(raise_exception)
return Response(status_code=204, background=task)