From: Scirlat Danut Date: Wed, 7 Feb 2024 20:17:51 +0000 (+0200) Subject: Add type hints to `test_routing.py` (#2489) X-Git-Tag: 0.37.1~6 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=9576eac076337da5ca26d407067ff173c9047387;p=thirdparty%2Fstarlette.git Add type hints to `test_routing.py` (#2489) * Add type hints to test_routing.py * Apply suggestions from code review * Fix check errors * Apply suggestions from code review --------- Co-authored-by: Scirlat Danut Co-authored-by: Marcelo Trylesinski --- diff --git a/tests/test_routing.py b/tests/test_routing.py index dd7083e8..8c3f1663 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -16,40 +16,42 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect +TestClientFactory = typing.Callable[..., TestClient] -def homepage(request): + +def homepage(request: Request) -> Response: return Response("Hello, world", media_type="text/plain") -def users(request): +def users(request: Request) -> Response: return Response("All users", media_type="text/plain") -def user(request): +def user(request: Request) -> Response: content = "User " + request.path_params["username"] return Response(content, media_type="text/plain") -def user_me(request): +def user_me(request: Request) -> Response: content = "User fixed me" return Response(content, media_type="text/plain") -def disable_user(request): +def disable_user(request: Request) -> Response: content = "User " + request.path_params["username"] + " disabled" return Response(content, media_type="text/plain") -def user_no_match(request): # pragma: no cover +def user_no_match(request: Request) -> Response: # pragma: no cover content = "User fixed no match" return Response(content, media_type="text/plain") -async def partial_endpoint(arg, request): +async def partial_endpoint(arg: str, request: Request) -> JSONResponse: return JSONResponse({"arg": arg}) -async def partial_ws_endpoint(websocket: WebSocket): +async def partial_ws_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await websocket.close() @@ -57,56 +59,56 @@ async def partial_ws_endpoint(websocket: WebSocket): class PartialRoutes: @classmethod - async def async_endpoint(cls, arg, request): + async def async_endpoint(cls, arg: str, request: Request) -> JSONResponse: return JSONResponse({"arg": arg}) @classmethod - async def async_ws_endpoint(cls, websocket: WebSocket): + async def async_ws_endpoint(cls, websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await websocket.close() -def func_homepage(request): +def func_homepage(request: Request) -> Response: return Response("Hello, world!", media_type="text/plain") -def contact(request): +def contact(request: Request) -> Response: return Response("Hello, POST!", media_type="text/plain") -def int_convertor(request): +def int_convertor(request: Request) -> JSONResponse: number = request.path_params["param"] return JSONResponse({"int": number}) -def float_convertor(request): +def float_convertor(request: Request) -> JSONResponse: num = request.path_params["param"] return JSONResponse({"float": num}) -def path_convertor(request): +def path_convertor(request: Request) -> JSONResponse: path = request.path_params["param"] return JSONResponse({"path": path}) -def uuid_converter(request): +def uuid_converter(request: Request) -> JSONResponse: uuid_param = request.path_params["param"] return JSONResponse({"uuid": str(uuid_param)}) -def path_with_parentheses(request): +def path_with_parentheses(request: Request) -> JSONResponse: number = request.path_params["param"] return JSONResponse({"int": number}) -async def websocket_endpoint(session: WebSocket): +async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() -async def websocket_params(session: WebSocket): +async def websocket_params(session: WebSocket) -> None: await session.accept() await session.send_text(f"Hello, {session.path_params['room']}!") await session.close() @@ -160,7 +162,9 @@ app = Router( @pytest.fixture -def client(test_client_factory: typing.Callable[..., TestClient]): +def client( + test_client_factory: TestClientFactory, +) -> typing.Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client @@ -171,7 +175,7 @@ def client(test_client_factory: typing.Callable[..., TestClient]): r":UserWarning" r":charset_normalizer.api" ) -def test_router(client: TestClient): +def test_router(client: TestClient) -> None: response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world" @@ -216,7 +220,7 @@ def test_router(client: TestClient): assert response.text == "xxxxx" -def test_route_converters(client): +def test_route_converters(client: TestClient) -> None: # Test integer conversion response = client.get("/int/5") assert response.status_code == 200 @@ -258,7 +262,7 @@ def test_route_converters(client): ) -def test_url_path_for(): +def test_url_path_for() -> None: assert app.url_path_for("homepage") == "/" assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie" assert app.url_path_for("websocket_endpoint") == "/ws" @@ -276,7 +280,7 @@ def test_url_path_for(): app.url_path_for("user", username="") -def test_url_for(): +def test_url_for() -> None: assert ( app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/" @@ -307,19 +311,19 @@ def test_url_for(): ) -def test_router_add_route(client): +def test_router_add_route(client: TestClient) -> None: response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_router_duplicate_path(client): +def test_router_duplicate_path(client: TestClient) -> None: response = client.post("/func") assert response.status_code == 200 assert response.text == "Hello, POST!" -def test_router_add_websocket_route(client): +def test_router_add_websocket_route(client: TestClient) -> None: with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" @@ -329,12 +333,12 @@ def test_router_add_websocket_route(client): assert text == "Hello, test!" -def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]): +def test_router_middleware(test_client_factory: TestClientFactory) -> None: class CustomMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app - async def __call__(self, scope: Scope, receive: Receive, send: Send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = PlainTextResponse("OK") await response(scope, receive, send) @@ -349,13 +353,13 @@ def test_router_middleware(test_client_factory: typing.Callable[..., TestClient] assert response.text == "OK" -def http_endpoint(request): +def http_endpoint(request: Request) -> Response: url = request.url_for("http_endpoint") return Response(f"URL: {url}", media_type="text/plain") class WebSocketEndpoint: - async def __call__(self, scope, receive, send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope=scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({"URL": str(websocket.url_for("websocket_endpoint"))}) @@ -370,7 +374,7 @@ mixed_protocol_app = Router( ) -def test_protocol_switch(test_client_factory): +def test_protocol_switch(test_client_factory: TestClientFactory) -> None: client = test_client_factory(mixed_protocol_app) response = client.get("/") @@ -388,7 +392,7 @@ def test_protocol_switch(test_client_factory): ok = PlainTextResponse("OK") -def test_mount_urls(test_client_factory): +def test_mount_urls(test_client_factory: TestClientFactory) -> None: mounted = Router([Mount("/users", ok, name="users")]) client = test_client_factory(mounted) assert client.get("/users").status_code == 200 @@ -398,7 +402,7 @@ def test_mount_urls(test_client_factory): assert client.get("/usersa").status_code == 404 -def test_reverse_mount_urls(): +def test_reverse_mount_urls() -> None: mounted = Router([Mount("/users", ok, name="users")]) assert mounted.url_path_for("users", path="/a") == "/users/a" @@ -413,13 +417,13 @@ def test_reverse_mount_urls(): ) -def test_mount_at_root(test_client_factory): +def test_mount_at_root(test_client_factory: TestClientFactory) -> None: mounted = Router([Mount("/", ok, name="users")]) client = test_client_factory(mounted) assert client.get("/").status_code == 200 -def users_api(request): +def users_api(request: Request) -> JSONResponse: return JSONResponse({"users": [{"username": "tom"}]}) @@ -448,7 +452,7 @@ mixed_hosts_app = Router( ) -def test_host_routing(test_client_factory): +def test_host_routing(test_client_factory: TestClientFactory) -> None: client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/") response = client.get("/users") @@ -492,7 +496,7 @@ def test_host_routing(test_client_factory): assert response.status_code == 200 -def test_host_reverse_urls(): +def test_host_reverse_urls() -> None: assert ( mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/" @@ -513,7 +517,7 @@ def test_host_reverse_urls(): ) -async def subdomain_app(scope, receive, send): +async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"subdomain": scope["path_params"]["subdomain"]}) await response(scope, receive, send) @@ -523,7 +527,7 @@ subdomain_router = Router( ) -def test_subdomain_routing(test_client_factory): +def test_subdomain_routing(test_client_factory: TestClientFactory) -> None: client = test_client_factory(subdomain_router, base_url="https://foo.example.org/") response = client.get("/") @@ -531,7 +535,7 @@ def test_subdomain_routing(test_client_factory): assert response.json() == {"subdomain": "foo"} -def test_subdomain_reverse_urls(): +def test_subdomain_reverse_urls() -> None: assert ( subdomain_router.url_path_for( "subdomains", subdomain="foo", path="/homepage" @@ -540,7 +544,7 @@ def test_subdomain_reverse_urls(): ) -async def echo_urls(request): +async def echo_urls(request: Request) -> JSONResponse: return JSONResponse( { "index": str(request.url_for("index")), @@ -559,7 +563,7 @@ echo_url_routes = [ ] -def test_url_for_with_root_path(test_client_factory): +def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None: app = Starlette(routes=echo_url_routes) client = test_client_factory( app, base_url="https://www.example.org/", root_path="/sub_path" @@ -576,7 +580,7 @@ def test_url_for_with_root_path(test_client_factory): } -async def stub_app(scope, receive, send): +async def stub_app(scope: Scope, receive: Receive, send: Send) -> None: pass # pragma: no cover @@ -585,13 +589,15 @@ double_mount_routes = [ ] -def test_url_for_with_double_mount(): +def test_url_for_with_double_mount() -> None: app = Starlette(routes=double_mount_routes) url = app.url_path_for("mount:static", path="123") assert url == "/mount/static/123" -def test_standalone_route_matches(test_client_factory): +def test_standalone_route_matches( + test_client_factory: TestClientFactory, +) -> None: app = Route("/", PlainTextResponse("Hello, World!")) client = test_client_factory(app) response = client.get("/") @@ -599,7 +605,9 @@ def test_standalone_route_matches(test_client_factory): assert response.text == "Hello, World!" -def test_standalone_route_does_not_match(test_client_factory): +def test_standalone_route_does_not_match( + test_client_factory: typing.Callable[..., TestClient], +) -> None: app = Route("/", PlainTextResponse("Hello, World!")) client = test_client_factory(app) response = client.get("/invalid") @@ -607,13 +615,15 @@ def test_standalone_route_does_not_match(test_client_factory): assert response.text == "Not Found" -async def ws_helloworld(websocket): +async def ws_helloworld(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_text("Hello, world!") await websocket.close() -def test_standalone_ws_route_matches(test_client_factory): +def test_standalone_ws_route_matches( + test_client_factory: TestClientFactory, +) -> None: app = WebSocketRoute("/", ws_helloworld) client = test_client_factory(app) with client.websocket_connect("/") as websocket: @@ -621,7 +631,9 @@ def test_standalone_ws_route_matches(test_client_factory): assert text == "Hello, world!" -def test_standalone_ws_route_does_not_match(test_client_factory): +def test_standalone_ws_route_does_not_match( + test_client_factory: TestClientFactory, +) -> None: app = WebSocketRoute("/", ws_helloworld) client = test_client_factory(app) with pytest.raises(WebSocketDisconnect): @@ -629,18 +641,18 @@ def test_standalone_ws_route_does_not_match(test_client_factory): pass # pragma: nocover -def test_lifespan_async(test_client_factory): +def test_lifespan_async(test_client_factory: TestClientFactory) -> None: startup_complete = False shutdown_complete = False - async def hello_world(request): + async def hello_world(request: Request) -> PlainTextResponse: return PlainTextResponse("hello, world") - async def run_startup(): + async def run_startup() -> None: nonlocal startup_complete startup_complete = True - async def run_shutdown(): + async def run_shutdown() -> None: nonlocal shutdown_complete shutdown_complete = True @@ -663,24 +675,24 @@ def test_lifespan_async(test_client_factory): assert shutdown_complete -def test_lifespan_with_on_events(test_client_factory: typing.Callable[..., TestClient]): +def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None: lifespan_called = False startup_called = False shutdown_called = False @contextlib.asynccontextmanager - async def lifespan(app: Starlette): + async def lifespan(app: Starlette) -> typing.AsyncGenerator[None, None]: nonlocal lifespan_called lifespan_called = True yield # We do not expected, neither of run_startup nor run_shutdown to be called # we thus mark them as #pragma: no cover, to fulfill test coverage - def run_startup(): # pragma: no cover + def run_startup() -> None: # pragma: no cover nonlocal startup_called startup_called = True - def run_shutdown(): # pragma: no cover + def run_shutdown() -> None: # pragma: no cover nonlocal shutdown_called shutdown_called = True @@ -710,18 +722,18 @@ def test_lifespan_with_on_events(test_client_factory: typing.Callable[..., TestC assert not shutdown_called -def test_lifespan_sync(test_client_factory): +def test_lifespan_sync(test_client_factory: TestClientFactory) -> None: startup_complete = False shutdown_complete = False - def hello_world(request): + def hello_world(request: Request) -> PlainTextResponse: return PlainTextResponse("hello, world") - def run_startup(): + def run_startup() -> None: nonlocal startup_complete startup_complete = True - def run_shutdown(): + def run_shutdown() -> None: nonlocal shutdown_complete shutdown_complete = True @@ -744,9 +756,13 @@ def test_lifespan_sync(test_client_factory): assert shutdown_complete -def test_lifespan_state_unsupported(test_client_factory): +def test_lifespan_state_unsupported( + test_client_factory: TestClientFactory, +) -> None: @contextlib.asynccontextmanager - async def lifespan(app): + async def lifespan( + app: ASGIApp, + ) -> typing.AsyncGenerator[typing.Dict[str, str], None]: yield {"foo": "bar"} app = Router( @@ -754,7 +770,7 @@ def test_lifespan_state_unsupported(test_client_factory): routes=[Mount("/", PlainTextResponse("hello, world"))], ) - async def no_state_wrapper(scope, receive, send): + async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None: del scope["state"] await app(scope, receive, send) @@ -765,7 +781,7 @@ def test_lifespan_state_unsupported(test_client_factory): raise AssertionError("Should not be called") # pragma: no cover -def test_lifespan_state_async_cm(test_client_factory): +def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None: startup_complete = False shutdown_complete = False @@ -813,8 +829,8 @@ def test_lifespan_state_async_cm(test_client_factory): assert shutdown_complete -def test_raise_on_startup(test_client_factory): - def run_startup(): +def test_raise_on_startup(test_client_factory: TestClientFactory) -> None: + def run_startup() -> None: raise RuntimeError() with pytest.deprecated_call( @@ -823,8 +839,8 @@ def test_raise_on_startup(test_client_factory): router = Router(on_startup=[run_startup]) startup_failed = False - async def app(scope, receive, send): - async def _send(message): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + async def _send(message: Message) -> None: nonlocal startup_failed if message["type"] == "lifespan.startup.failed": startup_failed = True @@ -838,8 +854,8 @@ def test_raise_on_startup(test_client_factory): assert startup_failed -def test_raise_on_shutdown(test_client_factory): - def run_shutdown(): +def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None: + def run_shutdown() -> None: raise RuntimeError() with pytest.deprecated_call( @@ -852,7 +868,7 @@ def test_raise_on_shutdown(test_client_factory): pass # pragma: nocover -def test_partial_async_endpoint(test_client_factory): +def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None: test_client = test_client_factory(app) response = test_client.get("/partial") assert response.status_code == 200 @@ -863,7 +879,9 @@ def test_partial_async_endpoint(test_client_factory): assert cls_method_response.json() == {"arg": "foo"} -def test_partial_async_ws_endpoint(test_client_factory): +def test_partial_async_ws_endpoint( + test_client_factory: TestClientFactory, +) -> None: test_client = test_client_factory(app) with test_client.websocket_connect("/partial/ws") as websocket: data = websocket.receive_json() @@ -874,7 +892,7 @@ def test_partial_async_ws_endpoint(test_client_factory): assert data == {"url": "ws://testserver/partial/ws/cls"} -def test_duplicated_param_names(): +def test_duplicated_param_names() -> None: with pytest.raises( ValueError, match="Duplicated param name id at path /{id}/{id}", @@ -889,18 +907,18 @@ def test_duplicated_param_names(): class Endpoint: - async def my_method(self, request): + async def my_method(self, request: Request) -> None: ... # pragma: no cover @classmethod - async def my_classmethod(cls, request): + async def my_classmethod(cls, request: Request) -> None: ... # pragma: no cover @staticmethod - async def my_staticmethod(request): + async def my_staticmethod(request: Request) -> None: ... # pragma: no cover - def __call__(self, request): + def __call__(self, request: Request) -> None: ... # pragma: no cover @@ -919,7 +937,9 @@ class Endpoint: pytest.param(lambda request: ..., "", id="lambda"), ], ) -def test_route_name(endpoint: typing.Callable[..., typing.Any], expected_name: str): +def test_route_name( + endpoint: typing.Callable[..., Response], expected_name: str +) -> None: assert Route(path="/", endpoint=endpoint).name == expected_name @@ -1000,7 +1020,7 @@ mounted_app_with_middleware = Starlette( ], ) def test_base_route_middleware( - test_client_factory: typing.Callable[..., TestClient], + test_client_factory: TestClientFactory, app: Starlette, ) -> None: test_client = test_client_factory(app) @@ -1045,8 +1065,10 @@ def test_add_route_to_app_after_mount( assert response.status_code == 200 -def test_exception_on_mounted_apps(test_client_factory): - def exc(request): +def test_exception_on_mounted_apps( + test_client_factory: TestClientFactory, +) -> None: + def exc(request: Request) -> None: raise Exception("Exc") sub_app = Starlette(routes=[Route("/", exc)]) @@ -1114,9 +1136,9 @@ def test_mounted_middleware_does_not_catch_exception( def test_websocket_route_middleware( - test_client_factory: typing.Callable[..., TestClient], -): - async def websocket_endpoint(session: WebSocket): + test_client_factory: TestClientFactory, +) -> None: + async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() @@ -1236,7 +1258,7 @@ def test_decorator_deprecations() -> None: router.on_event("startup")(startup) -async def echo_paths(request: Request, name: str): +async def echo_paths(request: Request, name: str) -> JSONResponse: return JSONResponse( { "name": name, @@ -1246,7 +1268,9 @@ async def echo_paths(request: Request, name: str): ) -async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str): +async def pure_asgi_echo_paths( + scope: Scope, receive: Receive, send: Send, name: str +) -> None: data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]} content = json.dumps(data).encode("utf-8") await send( @@ -1282,7 +1306,7 @@ echo_paths_routes = [ ] -def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClient]): +def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None: app = Starlette(routes=echo_paths_routes) client = test_client_factory( app, base_url="https://www.example.org/", root_path="/root"