]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Added type annotations to test_exceptions.py (#2479)
authorScirlat Danut <danut.scirlat@gmail.com>
Sun, 4 Feb 2024 18:21:41 +0000 (20:21 +0200)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 18:21:41 +0000 (18:21 +0000)
Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
tests/test_exceptions.py

index 735f2401ed41fe7a783232b34b0744f5c82d4f9d..401ad82126dc7b2ace18a4f67717b1ae536305aa 100644 (file)
@@ -1,4 +1,5 @@
 import warnings
+from typing import Callable, Generator
 
 import pytest
 
@@ -7,25 +8,29 @@ from starlette.middleware.exceptions import ExceptionMiddleware
 from starlette.requests import Request
 from starlette.responses import JSONResponse, PlainTextResponse
 from starlette.routing import Route, Router, WebSocketRoute
+from starlette.testclient import TestClient
+from starlette.types import Receive, Scope, Send
 
+TestClientFactory = Callable[..., TestClient]
 
-def raise_runtime_error(request):
+
+def raise_runtime_error(request: Request) -> None:
     raise RuntimeError("Yikes")
 
 
-def not_acceptable(request):
+def not_acceptable(request: Request) -> None:
     raise HTTPException(status_code=406)
 
 
-def no_content(request):
+def no_content(request: Request) -> None:
     raise HTTPException(status_code=204)
 
 
-def not_modified(request):
+def not_modified(request: Request) -> None:
     raise HTTPException(status_code=304)
 
 
-def with_headers(request):
+def with_headers(request: Request) -> None:
     raise HTTPException(status_code=200, headers={"x-potato": "always"})
 
 
@@ -33,7 +38,7 @@ class BadBodyException(HTTPException):
     pass
 
 
-async def read_body_and_raise_exc(request: Request):
+async def read_body_and_raise_exc(request: Request) -> None:
     await request.body()
     raise BadBodyException(422)
 
@@ -46,7 +51,7 @@ async def handler_that_reads_body(
 
 
 class HandledExcAfterResponse:
-    async def __call__(self, scope, receive, send):
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         response = PlainTextResponse("OK", status_code=200)
         await response(scope, receive, send)
         raise HTTPException(status_code=406)
@@ -77,42 +82,45 @@ app = ExceptionMiddleware(
 
 
 @pytest.fixture
-def client(test_client_factory):
+def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
     with test_client_factory(app) as client:
         yield client
 
 
-def test_not_acceptable(client):
+def test_not_acceptable(client: TestClient) -> None:
     response = client.get("/not_acceptable")
     assert response.status_code == 406
     assert response.text == "Not Acceptable"
 
 
-def test_no_content(client):
+def test_no_content(client: TestClient) -> None:
     response = client.get("/no_content")
     assert response.status_code == 204
     assert "content-length" not in response.headers
 
 
-def test_not_modified(client):
+def test_not_modified(client: TestClient) -> None:
     response = client.get("/not_modified")
     assert response.status_code == 304
     assert response.text == ""
 
 
-def test_with_headers(client):
+def test_with_headers(client: TestClient) -> None:
     response = client.get("/with_headers")
     assert response.status_code == 200
     assert response.headers["x-potato"] == "always"
 
 
-def test_websockets_should_raise(client):
+def test_websockets_should_raise(client: TestClient) -> None:
     with pytest.raises(RuntimeError):
         with client.websocket_connect("/runtime_error"):
             pass  # pragma: nocover
 
 
-def test_handled_exc_after_response(test_client_factory, client):
+def test_handled_exc_after_response(
+    test_client_factory: TestClientFactory,
+    client: TestClient,
+) -> None:
     # A 406 HttpException is raised *after* the response has already been sent.
     # The exception middleware should raise a RuntimeError.
     with pytest.raises(RuntimeError):
@@ -126,13 +134,13 @@ def test_handled_exc_after_response(test_client_factory, client):
     assert response.text == "OK"
 
 
-def test_force_500_response(test_client_factory):
+def test_force_500_response(test_client_factory: TestClientFactory) -> None:
     # use a sentinal variable to make sure we actually
     # make it into the endpoint and don't get a 500
     # from an incorrect ASGI app signature or something
     called = False
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         nonlocal called
         called = True
         raise RuntimeError()
@@ -144,13 +152,13 @@ def test_force_500_response(test_client_factory):
     assert response.text == ""
 
 
-def test_http_str():
+def test_http_str() -> None:
     assert str(HTTPException(status_code=404)) == "404: Not Found"
     assert str(HTTPException(404, "Not Found: foo")) == "404: Not Found: foo"
     assert str(HTTPException(404, headers={"key": "value"})) == "404: Not Found"
 
 
-def test_http_repr():
+def test_http_repr() -> None:
     assert repr(HTTPException(404)) == (
         "HTTPException(status_code=404, detail='Not Found')"
     )
@@ -166,12 +174,12 @@ def test_http_repr():
     )
 
 
-def test_websocket_str():
+def test_websocket_str() -> None:
     assert str(WebSocketException(1008)) == "1008: "
     assert str(WebSocketException(1008, "Policy Violation")) == "1008: Policy Violation"
 
 
-def test_websocket_repr():
+def test_websocket_repr() -> None:
     assert repr(WebSocketException(1008, reason="Policy Violation")) == (
         "WebSocketException(code=1008, reason='Policy Violation')"
     )
@@ -198,7 +206,7 @@ def test_exception_middleware_deprecation() -> None:
         starlette.exceptions.ExceptionMiddleware
 
 
-def test_request_in_app_and_handler_is_the_same_object(client) -> None:
+def test_request_in_app_and_handler_is_the_same_object(client: TestClient) -> None:
     response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
     assert response.status_code == 422
     assert response.json() == {"body": "Hello!"}