from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect
+TestClientFactory = typing.Callable[..., TestClient]
-def homepage(request):
+
+def homepage(request: Request) -> Response:
return Response("Hello, world", media_type="text/plain")
-def users(request):
+def users(request: Request) -> Response:
return Response("All users", media_type="text/plain")
-def user(request):
+def user(request: Request) -> Response:
content = "User " + request.path_params["username"]
return Response(content, media_type="text/plain")
-def user_me(request):
+def user_me(request: Request) -> Response:
content = "User fixed me"
return Response(content, media_type="text/plain")
-def disable_user(request):
+def disable_user(request: Request) -> Response:
content = "User " + request.path_params["username"] + " disabled"
return Response(content, media_type="text/plain")
-def user_no_match(request): # pragma: no cover
+def user_no_match(request: Request) -> Response: # pragma: no cover
content = "User fixed no match"
return Response(content, media_type="text/plain")
-async def partial_endpoint(arg, request):
+async def partial_endpoint(arg: str, request: Request) -> JSONResponse:
return JSONResponse({"arg": arg})
-async def partial_ws_endpoint(websocket: WebSocket):
+async def partial_ws_endpoint(websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()
class PartialRoutes:
@classmethod
- async def async_endpoint(cls, arg, request):
+ async def async_endpoint(cls, arg: str, request: Request) -> JSONResponse:
return JSONResponse({"arg": arg})
@classmethod
- async def async_ws_endpoint(cls, websocket: WebSocket):
+ async def async_ws_endpoint(cls, websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()
-def func_homepage(request):
+def func_homepage(request: Request) -> Response:
return Response("Hello, world!", media_type="text/plain")
-def contact(request):
+def contact(request: Request) -> Response:
return Response("Hello, POST!", media_type="text/plain")
-def int_convertor(request):
+def int_convertor(request: Request) -> JSONResponse:
number = request.path_params["param"]
return JSONResponse({"int": number})
-def float_convertor(request):
+def float_convertor(request: Request) -> JSONResponse:
num = request.path_params["param"]
return JSONResponse({"float": num})
-def path_convertor(request):
+def path_convertor(request: Request) -> JSONResponse:
path = request.path_params["param"]
return JSONResponse({"path": path})
-def uuid_converter(request):
+def uuid_converter(request: Request) -> JSONResponse:
uuid_param = request.path_params["param"]
return JSONResponse({"uuid": str(uuid_param)})
-def path_with_parentheses(request):
+def path_with_parentheses(request: Request) -> JSONResponse:
number = request.path_params["param"]
return JSONResponse({"int": number})
-async def websocket_endpoint(session: WebSocket):
+async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
-async def websocket_params(session: WebSocket):
+async def websocket_params(session: WebSocket) -> None:
await session.accept()
await session.send_text(f"Hello, {session.path_params['room']}!")
await session.close()
@pytest.fixture
-def client(test_client_factory: typing.Callable[..., TestClient]):
+def client(
+ test_client_factory: TestClientFactory,
+) -> typing.Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client
r":UserWarning"
r":charset_normalizer.api"
)
-def test_router(client: TestClient):
+def test_router(client: TestClient) -> None:
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world"
assert response.text == "xxxxx"
-def test_route_converters(client):
+def test_route_converters(client: TestClient) -> None:
# Test integer conversion
response = client.get("/int/5")
assert response.status_code == 200
)
-def test_url_path_for():
+def test_url_path_for() -> None:
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("websocket_endpoint") == "/ws"
app.url_path_for("user", username="")
-def test_url_for():
+def test_url_for() -> None:
assert (
app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
== "https://example.org/"
)
-def test_router_add_route(client):
+def test_router_add_route(client: TestClient) -> None:
response = client.get("/func")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_router_duplicate_path(client):
+def test_router_duplicate_path(client: TestClient) -> None:
response = client.post("/func")
assert response.status_code == 200
assert response.text == "Hello, POST!"
-def test_router_add_websocket_route(client):
+def test_router_add_websocket_route(client: TestClient) -> None:
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
assert text == "Hello, test!"
-def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]):
+def test_router_middleware(test_client_factory: TestClientFactory) -> None:
class CustomMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
- async def __call__(self, scope: Scope, receive: Receive, send: Send):
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("OK")
await response(scope, receive, send)
assert response.text == "OK"
-def http_endpoint(request):
+def http_endpoint(request: Request) -> Response:
url = request.url_for("http_endpoint")
return Response(f"URL: {url}", media_type="text/plain")
class WebSocketEndpoint:
- async def __call__(self, scope, receive, send):
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope=scope, receive=receive, send=send)
await websocket.accept()
await websocket.send_json({"URL": str(websocket.url_for("websocket_endpoint"))})
)
-def test_protocol_switch(test_client_factory):
+def test_protocol_switch(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(mixed_protocol_app)
response = client.get("/")
ok = PlainTextResponse("OK")
-def test_mount_urls(test_client_factory):
+def test_mount_urls(test_client_factory: TestClientFactory) -> None:
mounted = Router([Mount("/users", ok, name="users")])
client = test_client_factory(mounted)
assert client.get("/users").status_code == 200
assert client.get("/usersa").status_code == 404
-def test_reverse_mount_urls():
+def test_reverse_mount_urls() -> None:
mounted = Router([Mount("/users", ok, name="users")])
assert mounted.url_path_for("users", path="/a") == "/users/a"
)
-def test_mount_at_root(test_client_factory):
+def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
mounted = Router([Mount("/", ok, name="users")])
client = test_client_factory(mounted)
assert client.get("/").status_code == 200
-def users_api(request):
+def users_api(request: Request) -> JSONResponse:
return JSONResponse({"users": [{"username": "tom"}]})
)
-def test_host_routing(test_client_factory):
+def test_host_routing(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/")
response = client.get("/users")
assert response.status_code == 200
-def test_host_reverse_urls():
+def test_host_reverse_urls() -> None:
assert (
mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever")
== "https://www.example.org/"
)
-async def subdomain_app(scope, receive, send):
+async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None:
response = JSONResponse({"subdomain": scope["path_params"]["subdomain"]})
await response(scope, receive, send)
)
-def test_subdomain_routing(test_client_factory):
+def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(subdomain_router, base_url="https://foo.example.org/")
response = client.get("/")
assert response.json() == {"subdomain": "foo"}
-def test_subdomain_reverse_urls():
+def test_subdomain_reverse_urls() -> None:
assert (
subdomain_router.url_path_for(
"subdomains", subdomain="foo", path="/homepage"
)
-async def echo_urls(request):
+async def echo_urls(request: Request) -> JSONResponse:
return JSONResponse(
{
"index": str(request.url_for("index")),
]
-def test_url_for_with_root_path(test_client_factory):
+def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_url_routes)
client = test_client_factory(
app, base_url="https://www.example.org/", root_path="/sub_path"
}
-async def stub_app(scope, receive, send):
+async def stub_app(scope: Scope, receive: Receive, send: Send) -> None:
pass # pragma: no cover
]
-def test_url_for_with_double_mount():
+def test_url_for_with_double_mount() -> None:
app = Starlette(routes=double_mount_routes)
url = app.url_path_for("mount:static", path="123")
assert url == "/mount/static/123"
-def test_standalone_route_matches(test_client_factory):
+def test_standalone_route_matches(
+ test_client_factory: TestClientFactory,
+) -> None:
app = Route("/", PlainTextResponse("Hello, World!"))
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, World!"
-def test_standalone_route_does_not_match(test_client_factory):
+def test_standalone_route_does_not_match(
+ test_client_factory: typing.Callable[..., TestClient],
+) -> None:
app = Route("/", PlainTextResponse("Hello, World!"))
client = test_client_factory(app)
response = client.get("/invalid")
assert response.text == "Not Found"
-async def ws_helloworld(websocket):
+async def ws_helloworld(websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_text("Hello, world!")
await websocket.close()
-def test_standalone_ws_route_matches(test_client_factory):
+def test_standalone_ws_route_matches(
+ test_client_factory: TestClientFactory,
+) -> None:
app = WebSocketRoute("/", ws_helloworld)
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
assert text == "Hello, world!"
-def test_standalone_ws_route_does_not_match(test_client_factory):
+def test_standalone_ws_route_does_not_match(
+ test_client_factory: TestClientFactory,
+) -> None:
app = WebSocketRoute("/", ws_helloworld)
client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect):
pass # pragma: nocover
-def test_lifespan_async(test_client_factory):
+def test_lifespan_async(test_client_factory: TestClientFactory) -> None:
startup_complete = False
shutdown_complete = False
- async def hello_world(request):
+ async def hello_world(request: Request) -> PlainTextResponse:
return PlainTextResponse("hello, world")
- async def run_startup():
+ async def run_startup() -> None:
nonlocal startup_complete
startup_complete = True
- async def run_shutdown():
+ async def run_shutdown() -> None:
nonlocal shutdown_complete
shutdown_complete = True
assert shutdown_complete
-def test_lifespan_with_on_events(test_client_factory: typing.Callable[..., TestClient]):
+def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None:
lifespan_called = False
startup_called = False
shutdown_called = False
@contextlib.asynccontextmanager
- async def lifespan(app: Starlette):
+ async def lifespan(app: Starlette) -> typing.AsyncGenerator[None, None]:
nonlocal lifespan_called
lifespan_called = True
yield
# We do not expected, neither of run_startup nor run_shutdown to be called
# we thus mark them as #pragma: no cover, to fulfill test coverage
- def run_startup(): # pragma: no cover
+ def run_startup() -> None: # pragma: no cover
nonlocal startup_called
startup_called = True
- def run_shutdown(): # pragma: no cover
+ def run_shutdown() -> None: # pragma: no cover
nonlocal shutdown_called
shutdown_called = True
assert not shutdown_called
-def test_lifespan_sync(test_client_factory):
+def test_lifespan_sync(test_client_factory: TestClientFactory) -> None:
startup_complete = False
shutdown_complete = False
- def hello_world(request):
+ def hello_world(request: Request) -> PlainTextResponse:
return PlainTextResponse("hello, world")
- def run_startup():
+ def run_startup() -> None:
nonlocal startup_complete
startup_complete = True
- def run_shutdown():
+ def run_shutdown() -> None:
nonlocal shutdown_complete
shutdown_complete = True
assert shutdown_complete
-def test_lifespan_state_unsupported(test_client_factory):
+def test_lifespan_state_unsupported(
+ test_client_factory: TestClientFactory,
+) -> None:
@contextlib.asynccontextmanager
- async def lifespan(app):
+ async def lifespan(
+ app: ASGIApp,
+ ) -> typing.AsyncGenerator[typing.Dict[str, str], None]:
yield {"foo": "bar"}
app = Router(
routes=[Mount("/", PlainTextResponse("hello, world"))],
)
- async def no_state_wrapper(scope, receive, send):
+ async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
del scope["state"]
await app(scope, receive, send)
raise AssertionError("Should not be called") # pragma: no cover
-def test_lifespan_state_async_cm(test_client_factory):
+def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None:
startup_complete = False
shutdown_complete = False
assert shutdown_complete
-def test_raise_on_startup(test_client_factory):
- def run_startup():
+def test_raise_on_startup(test_client_factory: TestClientFactory) -> None:
+ def run_startup() -> None:
raise RuntimeError()
with pytest.deprecated_call(
router = Router(on_startup=[run_startup])
startup_failed = False
- async def app(scope, receive, send):
- async def _send(message):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ async def _send(message: Message) -> None:
nonlocal startup_failed
if message["type"] == "lifespan.startup.failed":
startup_failed = True
assert startup_failed
-def test_raise_on_shutdown(test_client_factory):
- def run_shutdown():
+def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None:
+ def run_shutdown() -> None:
raise RuntimeError()
with pytest.deprecated_call(
pass # pragma: nocover
-def test_partial_async_endpoint(test_client_factory):
+def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None:
test_client = test_client_factory(app)
response = test_client.get("/partial")
assert response.status_code == 200
assert cls_method_response.json() == {"arg": "foo"}
-def test_partial_async_ws_endpoint(test_client_factory):
+def test_partial_async_ws_endpoint(
+ test_client_factory: TestClientFactory,
+) -> None:
test_client = test_client_factory(app)
with test_client.websocket_connect("/partial/ws") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws/cls"}
-def test_duplicated_param_names():
+def test_duplicated_param_names() -> None:
with pytest.raises(
ValueError,
match="Duplicated param name id at path /{id}/{id}",
class Endpoint:
- async def my_method(self, request):
+ async def my_method(self, request: Request) -> None:
... # pragma: no cover
@classmethod
- async def my_classmethod(cls, request):
+ async def my_classmethod(cls, request: Request) -> None:
... # pragma: no cover
@staticmethod
- async def my_staticmethod(request):
+ async def my_staticmethod(request: Request) -> None:
... # pragma: no cover
- def __call__(self, request):
+ def __call__(self, request: Request) -> None:
... # pragma: no cover
pytest.param(lambda request: ..., "<lambda>", id="lambda"),
],
)
-def test_route_name(endpoint: typing.Callable[..., typing.Any], expected_name: str):
+def test_route_name(
+ endpoint: typing.Callable[..., Response], expected_name: str
+) -> None:
assert Route(path="/", endpoint=endpoint).name == expected_name
],
)
def test_base_route_middleware(
- test_client_factory: typing.Callable[..., TestClient],
+ test_client_factory: TestClientFactory,
app: Starlette,
) -> None:
test_client = test_client_factory(app)
assert response.status_code == 200
-def test_exception_on_mounted_apps(test_client_factory):
- def exc(request):
+def test_exception_on_mounted_apps(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def exc(request: Request) -> None:
raise Exception("Exc")
sub_app = Starlette(routes=[Route("/", exc)])
def test_websocket_route_middleware(
- test_client_factory: typing.Callable[..., TestClient],
-):
- async def websocket_endpoint(session: WebSocket):
+ test_client_factory: TestClientFactory,
+) -> None:
+ async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
router.on_event("startup")(startup)
-async def echo_paths(request: Request, name: str):
+async def echo_paths(request: Request, name: str) -> JSONResponse:
return JSONResponse(
{
"name": name,
)
-async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str):
+async def pure_asgi_echo_paths(
+ scope: Scope, receive: Receive, send: Send, name: str
+) -> None:
data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
content = json.dumps(data).encode("utf-8")
await send(
]
-def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClient]):
+def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_paths_routes)
client = test_client_factory(
app, base_url="https://www.example.org/", root_path="/root"