import os
from contextlib import asynccontextmanager
-from typing import AsyncIterator, Callable
+from pathlib import Path
+from typing import AsyncGenerator, AsyncIterator, Callable, Generator
import anyio
-import httpx
import pytest
from starlette import status
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
+from starlette.middleware.base import RequestResponseEndpoint
from starlette.middleware.trustedhost import TrustedHostMiddleware
+from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
+from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket
+TestClientFactory = Callable[..., TestClient]
-async def error_500(request, exc):
+
+async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": "Server Error"}, status_code=500)
-async def method_not_allowed(request, exc):
+async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": "Custom message"}, status_code=405)
-async def http_exception(request, exc):
+async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
-def func_homepage(request):
+def func_homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, world!")
-async def async_homepage(request):
+async def async_homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, world!")
class Homepage(HTTPEndpoint):
- def get(self, request):
+ def get(self, request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, world!")
-def all_users_page(request):
+def all_users_page(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, everyone!")
-def user_page(request):
+def user_page(request: Request) -> PlainTextResponse:
username = request.path_params["username"]
return PlainTextResponse(f"Hello, {username}!")
-def custom_subdomain(request):
+def custom_subdomain(request: Request) -> PlainTextResponse:
return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])
-def runtime_error(request):
+def runtime_error(request: Request) -> None:
raise RuntimeError()
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
-async def websocket_raise_websocket(websocket: WebSocket):
+async def websocket_raise_websocket(websocket: WebSocket) -> None:
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
pass
-async def websocket_raise_custom(websocket: WebSocket):
+async def websocket_raise_custom(websocket: WebSocket) -> None:
await websocket.accept()
raise CustomWSException()
-def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
+def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain),
],
- exception_handlers=exception_handlers,
+ exception_handlers=exception_handlers, # type: ignore
middleware=middleware,
)
@pytest.fixture
-def client(test_client_factory):
+def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client
-def test_url_path_for():
+def test_url_path_for() -> None:
assert app.url_path_for("func_homepage") == "/func"
-def test_func_route(client):
+def test_func_route(client: TestClient) -> None:
response = client.get("/func")
assert response.status_code == 200
assert response.text == "Hello, world!"
assert response.text == ""
-def test_async_route(client):
+def test_async_route(client: TestClient) -> None:
response = client.get("/async")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_class_route(client):
+def test_class_route(client: TestClient) -> None:
response = client.get("/class")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_mounted_route(client):
+def test_mounted_route(client: TestClient) -> None:
response = client.get("/users/")
assert response.status_code == 200
assert response.text == "Hello, everyone!"
-def test_mounted_route_path_params(client):
+def test_mounted_route_path_params(client: TestClient) -> None:
response = client.get("/users/tomchristie")
assert response.status_code == 200
assert response.text == "Hello, tomchristie!"
-def test_subdomain_route(test_client_factory):
+def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app, base_url="https://foo.example.org/")
response = client.get("/")
assert response.text == "Subdomain: foo"
-def test_websocket_route(client):
+def test_websocket_route(client: TestClient) -> None:
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
-def test_400(client):
+def test_400(client: TestClient) -> None:
response = client.get("/404")
assert response.status_code == 404
assert response.json() == {"detail": "Not Found"}
-def test_405(client):
+def test_405(client: TestClient) -> None:
response = client.post("/func")
assert response.status_code == 405
assert response.json() == {"detail": "Custom message"}
assert response.json() == {"detail": "Custom message"}
-def test_500(test_client_factory):
+def test_500(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/500")
assert response.status_code == 500
assert response.json() == {"detail": "Server Error"}
-def test_websocket_raise_websocket_exception(client):
+def test_websocket_raise_websocket_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
assert response == {
}
-def test_websocket_raise_custom_exception(client):
+def test_websocket_raise_custom_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
assert response == {
}
-def test_middleware(test_client_factory):
+def test_middleware(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func")
assert response.status_code == 400
assert response.text == "Invalid host header"
-def test_routes():
+def test_routes() -> None:
assert app.routes == [
Route("/func", endpoint=func_homepage, methods=["GET"]),
Route("/async", endpoint=async_homepage, methods=["GET"]),
]
-def test_app_mount(tmpdir, test_client_factory):
+def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
assert response.text == "Method Not Allowed"
-def test_app_debug(test_client_factory):
- async def homepage(request):
+def test_app_debug(test_client_factory: TestClientFactory) -> None:
+ async def homepage(request: Request) -> None:
raise RuntimeError()
app = Starlette(
assert app.debug
-def test_app_add_route(test_client_factory):
- async def homepage(request):
+def test_app_add_route(test_client_factory: TestClientFactory) -> None:
+ async def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, World!")
app = Starlette(
assert response.text == "Hello, World!"
-def test_app_add_websocket_route(test_client_factory):
- async def websocket_endpoint(session):
+def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
+ async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
assert text == "Hello, world!"
-def test_app_add_event_handler(test_client_factory):
+def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
startup_complete = False
cleanup_complete = False
- def run_startup():
+ def run_startup() -> None:
nonlocal startup_complete
startup_complete = True
- def run_cleanup():
+ def run_cleanup() -> None:
nonlocal cleanup_complete
cleanup_complete = True
assert cleanup_complete
-def test_app_async_cm_lifespan(test_client_factory):
+def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
startup_complete = False
cleanup_complete = False
@asynccontextmanager
- async def lifespan(app):
+ async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
@deprecated_lifespan
-def test_app_async_gen_lifespan(test_client_factory):
+def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
startup_complete = False
cleanup_complete = False
- async def lifespan(app):
+ async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True
- app = Starlette(lifespan=lifespan)
+ app = Starlette(lifespan=lifespan) # type: ignore
assert not startup_complete
assert not cleanup_complete
@deprecated_lifespan
-def test_app_sync_gen_lifespan(test_client_factory):
+def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
startup_complete = False
cleanup_complete = False
- def lifespan(app):
+ def lifespan(app: ASGIApp) -> Generator[None, None, None]:
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True
- app = Starlette(lifespan=lifespan)
+ app = Starlette(lifespan=lifespan) # type: ignore
assert not startup_complete
assert not cleanup_complete
)
) as record:
- async def middleware(request, call_next):
+ async def middleware(
+ request: Request, call_next: RequestResponseEndpoint
+ ) -> None:
... # pragma: no cover
app.middleware("http")(middleware)
)
) as record:
- async def startup():
+ async def startup() -> None:
... # pragma: no cover
app.on_event("startup")(startup)
assert len(record) == 1
-def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
+def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
assert SimpleInitializableMiddleware.counter == 2
-def test_lifespan_app_subclass():
+def test_lifespan_app_subclass() -> None:
# This test exists to make sure that subclasses of Starlette
# (like FastAPI) are compatible with the types hints for Lifespan