]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Added type annotations to test_endpoints.py (#2478)
authorScirlat Danut <danut.scirlat@gmail.com>
Sun, 4 Feb 2024 21:06:00 +0000 (23:06 +0200)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 21:06:00 +0000 (14:06 -0700)
* added type annotations to test_endpoints.py

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_endpoints.py

index 9895a455904e6c093457e0c2a4567c25bd14bb8e..eeb0f2322f86d7c2159161e1adc39accdda2c622 100644 (file)
@@ -1,12 +1,19 @@
+from typing import Callable, Iterator
+
 import pytest
 
 from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
+from starlette.requests import Request
 from starlette.responses import PlainTextResponse
 from starlette.routing import Route, Router
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocket
+
+TestClientFactory = Callable[..., TestClient]
 
 
 class Homepage(HTTPEndpoint):
-    async def get(self, request):
+    async def get(self, request: Request) -> PlainTextResponse:
         username = request.path_params.get("username")
         if username is None:
             return PlainTextResponse("Hello, world!")
@@ -19,33 +26,33 @@ app = Router(
 
 
 @pytest.fixture
-def client(test_client_factory):
+def client(test_client_factory: TestClientFactory) -> Iterator[TestClient]:
     with test_client_factory(app) as client:
         yield client
 
 
-def test_http_endpoint_route(client):
+def test_http_endpoint_route(client: TestClient) -> None:
     response = client.get("/")
     assert response.status_code == 200
     assert response.text == "Hello, world!"
 
 
-def test_http_endpoint_route_path_params(client):
+def test_http_endpoint_route_path_params(client: TestClient) -> None:
     response = client.get("/tomchristie")
     assert response.status_code == 200
     assert response.text == "Hello, tomchristie!"
 
 
-def test_http_endpoint_route_method(client):
+def test_http_endpoint_route_method(client: TestClient) -> None:
     response = client.post("/")
     assert response.status_code == 405
     assert response.text == "Method Not Allowed"
     assert response.headers["allow"] == "GET"
 
 
-def test_websocket_endpoint_on_connect(test_client_factory):
+def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None:
     class WebSocketApp(WebSocketEndpoint):
-        async def on_connect(self, websocket):
+        async def on_connect(self, websocket: WebSocket) -> None:
             assert websocket["subprotocols"] == ["soap", "wamp"]
             await websocket.accept(subprotocol="wamp")
 
@@ -54,11 +61,13 @@ def test_websocket_endpoint_on_connect(test_client_factory):
         assert websocket.accepted_subprotocol == "wamp"
 
 
-def test_websocket_endpoint_on_receive_bytes(test_client_factory):
+def test_websocket_endpoint_on_receive_bytes(
+    test_client_factory: TestClientFactory,
+) -> None:
     class WebSocketApp(WebSocketEndpoint):
         encoding = "bytes"
 
-        async def on_receive(self, websocket, data):
+        async def on_receive(self, websocket: WebSocket, data: bytes) -> None:
             await websocket.send_bytes(b"Message bytes was: " + data)
 
     client = test_client_factory(WebSocketApp)
@@ -72,11 +81,13 @@ def test_websocket_endpoint_on_receive_bytes(test_client_factory):
             websocket.send_text("Hello world")
 
 
-def test_websocket_endpoint_on_receive_json(test_client_factory):
+def test_websocket_endpoint_on_receive_json(
+    test_client_factory: TestClientFactory,
+) -> None:
     class WebSocketApp(WebSocketEndpoint):
         encoding = "json"
 
-        async def on_receive(self, websocket, data):
+        async def on_receive(self, websocket: WebSocket, data: str) -> None:
             await websocket.send_json({"message": data})
 
     client = test_client_factory(WebSocketApp)
@@ -90,11 +101,13 @@ def test_websocket_endpoint_on_receive_json(test_client_factory):
             websocket.send_text("Hello world")
 
 
-def test_websocket_endpoint_on_receive_json_binary(test_client_factory):
+def test_websocket_endpoint_on_receive_json_binary(
+    test_client_factory: TestClientFactory,
+) -> None:
     class WebSocketApp(WebSocketEndpoint):
         encoding = "json"
 
-        async def on_receive(self, websocket, data):
+        async def on_receive(self, websocket: WebSocket, data: str) -> None:
             await websocket.send_json({"message": data}, mode="binary")
 
     client = test_client_factory(WebSocketApp)
@@ -104,11 +117,13 @@ def test_websocket_endpoint_on_receive_json_binary(test_client_factory):
         assert data == {"message": {"hello": "world"}}
 
 
-def test_websocket_endpoint_on_receive_text(test_client_factory):
+def test_websocket_endpoint_on_receive_text(
+    test_client_factory: TestClientFactory,
+) -> None:
     class WebSocketApp(WebSocketEndpoint):
         encoding = "text"
 
-        async def on_receive(self, websocket, data):
+        async def on_receive(self, websocket: WebSocket, data: str) -> None:
             await websocket.send_text(f"Message text was: {data}")
 
     client = test_client_factory(WebSocketApp)
@@ -122,11 +137,11 @@ def test_websocket_endpoint_on_receive_text(test_client_factory):
             websocket.send_bytes(b"Hello world")
 
 
-def test_websocket_endpoint_on_default(test_client_factory):
+def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None:
     class WebSocketApp(WebSocketEndpoint):
         encoding = None
 
-        async def on_receive(self, websocket, data):
+        async def on_receive(self, websocket: WebSocket, data: str) -> None:
             await websocket.send_text(f"Message text was: {data}")
 
     client = test_client_factory(WebSocketApp)
@@ -136,9 +151,11 @@ def test_websocket_endpoint_on_default(test_client_factory):
         assert _text == "Message text was: Hello, world!"
 
 
-def test_websocket_endpoint_on_disconnect(test_client_factory):
+def test_websocket_endpoint_on_disconnect(
+    test_client_factory: TestClientFactory,
+) -> None:
     class WebSocketApp(WebSocketEndpoint):
-        async def on_disconnect(self, websocket, close_code):
+        async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
             assert close_code == 1001
             await websocket.close(code=close_code)