+from typing import Callable, Iterator
+
import pytest
from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
+from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocket
+
+TestClientFactory = Callable[..., TestClient]
class Homepage(HTTPEndpoint):
- async def get(self, request):
+ async def get(self, request: Request) -> PlainTextResponse:
username = request.path_params.get("username")
if username is None:
return PlainTextResponse("Hello, world!")
@pytest.fixture
-def client(test_client_factory):
+def client(test_client_factory: TestClientFactory) -> Iterator[TestClient]:
with test_client_factory(app) as client:
yield client
-def test_http_endpoint_route(client):
+def test_http_endpoint_route(client: TestClient) -> None:
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_http_endpoint_route_path_params(client):
+def test_http_endpoint_route_path_params(client: TestClient) -> None:
response = client.get("/tomchristie")
assert response.status_code == 200
assert response.text == "Hello, tomchristie!"
-def test_http_endpoint_route_method(client):
+def test_http_endpoint_route_method(client: TestClient) -> None:
response = client.post("/")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
assert response.headers["allow"] == "GET"
-def test_websocket_endpoint_on_connect(test_client_factory):
+def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None:
class WebSocketApp(WebSocketEndpoint):
- async def on_connect(self, websocket):
+ async def on_connect(self, websocket: WebSocket) -> None:
assert websocket["subprotocols"] == ["soap", "wamp"]
await websocket.accept(subprotocol="wamp")
assert websocket.accepted_subprotocol == "wamp"
-def test_websocket_endpoint_on_receive_bytes(test_client_factory):
+def test_websocket_endpoint_on_receive_bytes(
+ test_client_factory: TestClientFactory,
+) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "bytes"
- async def on_receive(self, websocket, data):
+ async def on_receive(self, websocket: WebSocket, data: bytes) -> None:
await websocket.send_bytes(b"Message bytes was: " + data)
client = test_client_factory(WebSocketApp)
websocket.send_text("Hello world")
-def test_websocket_endpoint_on_receive_json(test_client_factory):
+def test_websocket_endpoint_on_receive_json(
+ test_client_factory: TestClientFactory,
+) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "json"
- async def on_receive(self, websocket, data):
+ async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_json({"message": data})
client = test_client_factory(WebSocketApp)
websocket.send_text("Hello world")
-def test_websocket_endpoint_on_receive_json_binary(test_client_factory):
+def test_websocket_endpoint_on_receive_json_binary(
+ test_client_factory: TestClientFactory,
+) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "json"
- async def on_receive(self, websocket, data):
+ async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_json({"message": data}, mode="binary")
client = test_client_factory(WebSocketApp)
assert data == {"message": {"hello": "world"}}
-def test_websocket_endpoint_on_receive_text(test_client_factory):
+def test_websocket_endpoint_on_receive_text(
+ test_client_factory: TestClientFactory,
+) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "text"
- async def on_receive(self, websocket, data):
+ async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_text(f"Message text was: {data}")
client = test_client_factory(WebSocketApp)
websocket.send_bytes(b"Hello world")
-def test_websocket_endpoint_on_default(test_client_factory):
+def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = None
- async def on_receive(self, websocket, data):
+ async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_text(f"Message text was: {data}")
client = test_client_factory(WebSocketApp)
assert _text == "Message text was: Hello, world!"
-def test_websocket_endpoint_on_disconnect(test_client_factory):
+def test_websocket_endpoint_on_disconnect(
+ test_client_factory: TestClientFactory,
+) -> None:
class WebSocketApp(WebSocketEndpoint):
- async def on_disconnect(self, websocket, close_code):
+ async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
assert close_code == 1001
await websocket.close(code=close_code)