]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_routing.py` (#2489)
authorScirlat Danut <danut.scirlat@gmail.com>
Wed, 7 Feb 2024 20:17:51 +0000 (22:17 +0200)
committerGitHub <noreply@github.com>
Wed, 7 Feb 2024 20:17:51 +0000 (20:17 +0000)
* Add type hints to test_routing.py

* Apply suggestions from code review

* Fix check errors

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_routing.py

index dd7083e818e67a623b38f3d69a11dbb9d4dec4cb..8c3f16639a9e3ebd427cd752ad06881361a4acb1 100644 (file)
@@ -16,40 +16,42 @@ from starlette.testclient import TestClient
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
+TestClientFactory = typing.Callable[..., TestClient]
 
-def homepage(request):
+
+def homepage(request: Request) -> Response:
     return Response("Hello, world", media_type="text/plain")
 
 
-def users(request):
+def users(request: Request) -> Response:
     return Response("All users", media_type="text/plain")
 
 
-def user(request):
+def user(request: Request) -> Response:
     content = "User " + request.path_params["username"]
     return Response(content, media_type="text/plain")
 
 
-def user_me(request):
+def user_me(request: Request) -> Response:
     content = "User fixed me"
     return Response(content, media_type="text/plain")
 
 
-def disable_user(request):
+def disable_user(request: Request) -> Response:
     content = "User " + request.path_params["username"] + " disabled"
     return Response(content, media_type="text/plain")
 
 
-def user_no_match(request):  # pragma: no cover
+def user_no_match(request: Request) -> Response:  # pragma: no cover
     content = "User fixed no match"
     return Response(content, media_type="text/plain")
 
 
-async def partial_endpoint(arg, request):
+async def partial_endpoint(arg: str, request: Request) -> JSONResponse:
     return JSONResponse({"arg": arg})
 
 
-async def partial_ws_endpoint(websocket: WebSocket):
+async def partial_ws_endpoint(websocket: WebSocket) -> None:
     await websocket.accept()
     await websocket.send_json({"url": str(websocket.url)})
     await websocket.close()
@@ -57,56 +59,56 @@ async def partial_ws_endpoint(websocket: WebSocket):
 
 class PartialRoutes:
     @classmethod
-    async def async_endpoint(cls, arg, request):
+    async def async_endpoint(cls, arg: str, request: Request) -> JSONResponse:
         return JSONResponse({"arg": arg})
 
     @classmethod
-    async def async_ws_endpoint(cls, websocket: WebSocket):
+    async def async_ws_endpoint(cls, websocket: WebSocket) -> None:
         await websocket.accept()
         await websocket.send_json({"url": str(websocket.url)})
         await websocket.close()
 
 
-def func_homepage(request):
+def func_homepage(request: Request) -> Response:
     return Response("Hello, world!", media_type="text/plain")
 
 
-def contact(request):
+def contact(request: Request) -> Response:
     return Response("Hello, POST!", media_type="text/plain")
 
 
-def int_convertor(request):
+def int_convertor(request: Request) -> JSONResponse:
     number = request.path_params["param"]
     return JSONResponse({"int": number})
 
 
-def float_convertor(request):
+def float_convertor(request: Request) -> JSONResponse:
     num = request.path_params["param"]
     return JSONResponse({"float": num})
 
 
-def path_convertor(request):
+def path_convertor(request: Request) -> JSONResponse:
     path = request.path_params["param"]
     return JSONResponse({"path": path})
 
 
-def uuid_converter(request):
+def uuid_converter(request: Request) -> JSONResponse:
     uuid_param = request.path_params["param"]
     return JSONResponse({"uuid": str(uuid_param)})
 
 
-def path_with_parentheses(request):
+def path_with_parentheses(request: Request) -> JSONResponse:
     number = request.path_params["param"]
     return JSONResponse({"int": number})
 
 
-async def websocket_endpoint(session: WebSocket):
+async def websocket_endpoint(session: WebSocket) -> None:
     await session.accept()
     await session.send_text("Hello, world!")
     await session.close()
 
 
-async def websocket_params(session: WebSocket):
+async def websocket_params(session: WebSocket) -> None:
     await session.accept()
     await session.send_text(f"Hello, {session.path_params['room']}!")
     await session.close()
@@ -160,7 +162,9 @@ app = Router(
 
 
 @pytest.fixture
-def client(test_client_factory: typing.Callable[..., TestClient]):
+def client(
+    test_client_factory: TestClientFactory,
+) -> typing.Generator[TestClient, None, None]:
     with test_client_factory(app) as client:
         yield client
 
@@ -171,7 +175,7 @@ def client(test_client_factory: typing.Callable[..., TestClient]):
     r":UserWarning"
     r":charset_normalizer.api"
 )
-def test_router(client: TestClient):
+def test_router(client: TestClient) -> None:
     response = client.get("/")
     assert response.status_code == 200
     assert response.text == "Hello, world"
@@ -216,7 +220,7 @@ def test_router(client: TestClient):
     assert response.text == "xxxxx"
 
 
-def test_route_converters(client):
+def test_route_converters(client: TestClient) -> None:
     # Test integer conversion
     response = client.get("/int/5")
     assert response.status_code == 200
@@ -258,7 +262,7 @@ def test_route_converters(client):
     )
 
 
-def test_url_path_for():
+def test_url_path_for() -> None:
     assert app.url_path_for("homepage") == "/"
     assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
     assert app.url_path_for("websocket_endpoint") == "/ws"
@@ -276,7 +280,7 @@ def test_url_path_for():
         app.url_path_for("user", username="")
 
 
-def test_url_for():
+def test_url_for() -> None:
     assert (
         app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
         == "https://example.org/"
@@ -307,19 +311,19 @@ def test_url_for():
     )
 
 
-def test_router_add_route(client):
+def test_router_add_route(client: TestClient) -> None:
     response = client.get("/func")
     assert response.status_code == 200
     assert response.text == "Hello, world!"
 
 
-def test_router_duplicate_path(client):
+def test_router_duplicate_path(client: TestClient) -> None:
     response = client.post("/func")
     assert response.status_code == 200
     assert response.text == "Hello, POST!"
 
 
-def test_router_add_websocket_route(client):
+def test_router_add_websocket_route(client: TestClient) -> None:
     with client.websocket_connect("/ws") as session:
         text = session.receive_text()
         assert text == "Hello, world!"
@@ -329,12 +333,12 @@ def test_router_add_websocket_route(client):
         assert text == "Hello, test!"
 
 
-def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]):
+def test_router_middleware(test_client_factory: TestClientFactory) -> None:
     class CustomMiddleware:
         def __init__(self, app: ASGIApp) -> None:
             self.app = app
 
-        async def __call__(self, scope: Scope, receive: Receive, send: Send):
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
             response = PlainTextResponse("OK")
             await response(scope, receive, send)
 
@@ -349,13 +353,13 @@ def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]
     assert response.text == "OK"
 
 
-def http_endpoint(request):
+def http_endpoint(request: Request) -> Response:
     url = request.url_for("http_endpoint")
     return Response(f"URL: {url}", media_type="text/plain")
 
 
 class WebSocketEndpoint:
-    async def __call__(self, scope, receive, send):
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope=scope, receive=receive, send=send)
         await websocket.accept()
         await websocket.send_json({"URL": str(websocket.url_for("websocket_endpoint"))})
@@ -370,7 +374,7 @@ mixed_protocol_app = Router(
 )
 
 
-def test_protocol_switch(test_client_factory):
+def test_protocol_switch(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(mixed_protocol_app)
 
     response = client.get("/")
@@ -388,7 +392,7 @@ def test_protocol_switch(test_client_factory):
 ok = PlainTextResponse("OK")
 
 
-def test_mount_urls(test_client_factory):
+def test_mount_urls(test_client_factory: TestClientFactory) -> None:
     mounted = Router([Mount("/users", ok, name="users")])
     client = test_client_factory(mounted)
     assert client.get("/users").status_code == 200
@@ -398,7 +402,7 @@ def test_mount_urls(test_client_factory):
     assert client.get("/usersa").status_code == 404
 
 
-def test_reverse_mount_urls():
+def test_reverse_mount_urls() -> None:
     mounted = Router([Mount("/users", ok, name="users")])
     assert mounted.url_path_for("users", path="/a") == "/users/a"
 
@@ -413,13 +417,13 @@ def test_reverse_mount_urls():
     )
 
 
-def test_mount_at_root(test_client_factory):
+def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
     mounted = Router([Mount("/", ok, name="users")])
     client = test_client_factory(mounted)
     assert client.get("/").status_code == 200
 
 
-def users_api(request):
+def users_api(request: Request) -> JSONResponse:
     return JSONResponse({"users": [{"username": "tom"}]})
 
 
@@ -448,7 +452,7 @@ mixed_hosts_app = Router(
 )
 
 
-def test_host_routing(test_client_factory):
+def test_host_routing(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/")
 
     response = client.get("/users")
@@ -492,7 +496,7 @@ def test_host_routing(test_client_factory):
     assert response.status_code == 200
 
 
-def test_host_reverse_urls():
+def test_host_reverse_urls() -> None:
     assert (
         mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever")
         == "https://www.example.org/"
@@ -513,7 +517,7 @@ def test_host_reverse_urls():
     )
 
 
-async def subdomain_app(scope, receive, send):
+async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None:
     response = JSONResponse({"subdomain": scope["path_params"]["subdomain"]})
     await response(scope, receive, send)
 
@@ -523,7 +527,7 @@ subdomain_router = Router(
 )
 
 
-def test_subdomain_routing(test_client_factory):
+def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(subdomain_router, base_url="https://foo.example.org/")
 
     response = client.get("/")
@@ -531,7 +535,7 @@ def test_subdomain_routing(test_client_factory):
     assert response.json() == {"subdomain": "foo"}
 
 
-def test_subdomain_reverse_urls():
+def test_subdomain_reverse_urls() -> None:
     assert (
         subdomain_router.url_path_for(
             "subdomains", subdomain="foo", path="/homepage"
@@ -540,7 +544,7 @@ def test_subdomain_reverse_urls():
     )
 
 
-async def echo_urls(request):
+async def echo_urls(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "index": str(request.url_for("index")),
@@ -559,7 +563,7 @@ echo_url_routes = [
 ]
 
 
-def test_url_for_with_root_path(test_client_factory):
+def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
     app = Starlette(routes=echo_url_routes)
     client = test_client_factory(
         app, base_url="https://www.example.org/", root_path="/sub_path"
@@ -576,7 +580,7 @@ def test_url_for_with_root_path(test_client_factory):
     }
 
 
-async def stub_app(scope, receive, send):
+async def stub_app(scope: Scope, receive: Receive, send: Send) -> None:
     pass  # pragma: no cover
 
 
@@ -585,13 +589,15 @@ double_mount_routes = [
 ]
 
 
-def test_url_for_with_double_mount():
+def test_url_for_with_double_mount() -> None:
     app = Starlette(routes=double_mount_routes)
     url = app.url_path_for("mount:static", path="123")
     assert url == "/mount/static/123"
 
 
-def test_standalone_route_matches(test_client_factory):
+def test_standalone_route_matches(
+    test_client_factory: TestClientFactory,
+) -> None:
     app = Route("/", PlainTextResponse("Hello, World!"))
     client = test_client_factory(app)
     response = client.get("/")
@@ -599,7 +605,9 @@ def test_standalone_route_matches(test_client_factory):
     assert response.text == "Hello, World!"
 
 
-def test_standalone_route_does_not_match(test_client_factory):
+def test_standalone_route_does_not_match(
+    test_client_factory: typing.Callable[..., TestClient],
+) -> None:
     app = Route("/", PlainTextResponse("Hello, World!"))
     client = test_client_factory(app)
     response = client.get("/invalid")
@@ -607,13 +615,15 @@ def test_standalone_route_does_not_match(test_client_factory):
     assert response.text == "Not Found"
 
 
-async def ws_helloworld(websocket):
+async def ws_helloworld(websocket: WebSocket) -> None:
     await websocket.accept()
     await websocket.send_text("Hello, world!")
     await websocket.close()
 
 
-def test_standalone_ws_route_matches(test_client_factory):
+def test_standalone_ws_route_matches(
+    test_client_factory: TestClientFactory,
+) -> None:
     app = WebSocketRoute("/", ws_helloworld)
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -621,7 +631,9 @@ def test_standalone_ws_route_matches(test_client_factory):
         assert text == "Hello, world!"
 
 
-def test_standalone_ws_route_does_not_match(test_client_factory):
+def test_standalone_ws_route_does_not_match(
+    test_client_factory: TestClientFactory,
+) -> None:
     app = WebSocketRoute("/", ws_helloworld)
     client = test_client_factory(app)
     with pytest.raises(WebSocketDisconnect):
@@ -629,18 +641,18 @@ def test_standalone_ws_route_does_not_match(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_lifespan_async(test_client_factory):
+def test_lifespan_async(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     shutdown_complete = False
 
-    async def hello_world(request):
+    async def hello_world(request: Request) -> PlainTextResponse:
         return PlainTextResponse("hello, world")
 
-    async def run_startup():
+    async def run_startup() -> None:
         nonlocal startup_complete
         startup_complete = True
 
-    async def run_shutdown():
+    async def run_shutdown() -> None:
         nonlocal shutdown_complete
         shutdown_complete = True
 
@@ -663,24 +675,24 @@ def test_lifespan_async(test_client_factory):
     assert shutdown_complete
 
 
-def test_lifespan_with_on_events(test_client_factory: typing.Callable[..., TestClient]):
+def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None:
     lifespan_called = False
     startup_called = False
     shutdown_called = False
 
     @contextlib.asynccontextmanager
-    async def lifespan(app: Starlette):
+    async def lifespan(app: Starlette) -> typing.AsyncGenerator[None, None]:
         nonlocal lifespan_called
         lifespan_called = True
         yield
 
     # We do not expected, neither of run_startup nor run_shutdown to be called
     # we thus mark them as #pragma: no cover, to fulfill test coverage
-    def run_startup():  # pragma: no cover
+    def run_startup() -> None:  # pragma: no cover
         nonlocal startup_called
         startup_called = True
 
-    def run_shutdown():  # pragma: no cover
+    def run_shutdown() -> None:  # pragma: no cover
         nonlocal shutdown_called
         shutdown_called = True
 
@@ -710,18 +722,18 @@ def test_lifespan_with_on_events(test_client_factory: typing.Callable[..., TestC
             assert not shutdown_called
 
 
-def test_lifespan_sync(test_client_factory):
+def test_lifespan_sync(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     shutdown_complete = False
 
-    def hello_world(request):
+    def hello_world(request: Request) -> PlainTextResponse:
         return PlainTextResponse("hello, world")
 
-    def run_startup():
+    def run_startup() -> None:
         nonlocal startup_complete
         startup_complete = True
 
-    def run_shutdown():
+    def run_shutdown() -> None:
         nonlocal shutdown_complete
         shutdown_complete = True
 
@@ -744,9 +756,13 @@ def test_lifespan_sync(test_client_factory):
     assert shutdown_complete
 
 
-def test_lifespan_state_unsupported(test_client_factory):
+def test_lifespan_state_unsupported(
+    test_client_factory: TestClientFactory,
+) -> None:
     @contextlib.asynccontextmanager
-    async def lifespan(app):
+    async def lifespan(
+        app: ASGIApp,
+    ) -> typing.AsyncGenerator[typing.Dict[str, str], None]:
         yield {"foo": "bar"}
 
     app = Router(
@@ -754,7 +770,7 @@ def test_lifespan_state_unsupported(test_client_factory):
         routes=[Mount("/", PlainTextResponse("hello, world"))],
     )
 
-    async def no_state_wrapper(scope, receive, send):
+    async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
         del scope["state"]
         await app(scope, receive, send)
 
@@ -765,7 +781,7 @@ def test_lifespan_state_unsupported(test_client_factory):
             raise AssertionError("Should not be called")  # pragma: no cover
 
 
-def test_lifespan_state_async_cm(test_client_factory):
+def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None:
     startup_complete = False
     shutdown_complete = False
 
@@ -813,8 +829,8 @@ def test_lifespan_state_async_cm(test_client_factory):
     assert shutdown_complete
 
 
-def test_raise_on_startup(test_client_factory):
-    def run_startup():
+def test_raise_on_startup(test_client_factory: TestClientFactory) -> None:
+    def run_startup() -> None:
         raise RuntimeError()
 
     with pytest.deprecated_call(
@@ -823,8 +839,8 @@ def test_raise_on_startup(test_client_factory):
         router = Router(on_startup=[run_startup])
     startup_failed = False
 
-    async def app(scope, receive, send):
-        async def _send(message):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        async def _send(message: Message) -> None:
             nonlocal startup_failed
             if message["type"] == "lifespan.startup.failed":
                 startup_failed = True
@@ -838,8 +854,8 @@ def test_raise_on_startup(test_client_factory):
     assert startup_failed
 
 
-def test_raise_on_shutdown(test_client_factory):
-    def run_shutdown():
+def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None:
+    def run_shutdown() -> None:
         raise RuntimeError()
 
     with pytest.deprecated_call(
@@ -852,7 +868,7 @@ def test_raise_on_shutdown(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_partial_async_endpoint(test_client_factory):
+def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None:
     test_client = test_client_factory(app)
     response = test_client.get("/partial")
     assert response.status_code == 200
@@ -863,7 +879,9 @@ def test_partial_async_endpoint(test_client_factory):
     assert cls_method_response.json() == {"arg": "foo"}
 
 
-def test_partial_async_ws_endpoint(test_client_factory):
+def test_partial_async_ws_endpoint(
+    test_client_factory: TestClientFactory,
+) -> None:
     test_client = test_client_factory(app)
     with test_client.websocket_connect("/partial/ws") as websocket:
         data = websocket.receive_json()
@@ -874,7 +892,7 @@ def test_partial_async_ws_endpoint(test_client_factory):
         assert data == {"url": "ws://testserver/partial/ws/cls"}
 
 
-def test_duplicated_param_names():
+def test_duplicated_param_names() -> None:
     with pytest.raises(
         ValueError,
         match="Duplicated param name id at path /{id}/{id}",
@@ -889,18 +907,18 @@ def test_duplicated_param_names():
 
 
 class Endpoint:
-    async def my_method(self, request):
+    async def my_method(self, request: Request) -> None:
         ...  # pragma: no cover
 
     @classmethod
-    async def my_classmethod(cls, request):
+    async def my_classmethod(cls, request: Request) -> None:
         ...  # pragma: no cover
 
     @staticmethod
-    async def my_staticmethod(request):
+    async def my_staticmethod(request: Request) -> None:
         ...  # pragma: no cover
 
-    def __call__(self, request):
+    def __call__(self, request: Request) -> None:
         ...  # pragma: no cover
 
 
@@ -919,7 +937,9 @@ class Endpoint:
         pytest.param(lambda request: ..., "<lambda>", id="lambda"),
     ],
 )
-def test_route_name(endpoint: typing.Callable[..., typing.Any], expected_name: str):
+def test_route_name(
+    endpoint: typing.Callable[..., Response], expected_name: str
+) -> None:
     assert Route(path="/", endpoint=endpoint).name == expected_name
 
 
@@ -1000,7 +1020,7 @@ mounted_app_with_middleware = Starlette(
     ],
 )
 def test_base_route_middleware(
-    test_client_factory: typing.Callable[..., TestClient],
+    test_client_factory: TestClientFactory,
     app: Starlette,
 ) -> None:
     test_client = test_client_factory(app)
@@ -1045,8 +1065,10 @@ def test_add_route_to_app_after_mount(
     assert response.status_code == 200
 
 
-def test_exception_on_mounted_apps(test_client_factory):
-    def exc(request):
+def test_exception_on_mounted_apps(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def exc(request: Request) -> None:
         raise Exception("Exc")
 
     sub_app = Starlette(routes=[Route("/", exc)])
@@ -1114,9 +1136,9 @@ def test_mounted_middleware_does_not_catch_exception(
 
 
 def test_websocket_route_middleware(
-    test_client_factory: typing.Callable[..., TestClient],
-):
-    async def websocket_endpoint(session: WebSocket):
+    test_client_factory: TestClientFactory,
+) -> None:
+    async def websocket_endpoint(session: WebSocket) -> None:
         await session.accept()
         await session.send_text("Hello, world!")
         await session.close()
@@ -1236,7 +1258,7 @@ def test_decorator_deprecations() -> None:
         router.on_event("startup")(startup)
 
 
-async def echo_paths(request: Request, name: str):
+async def echo_paths(request: Request, name: str) -> JSONResponse:
     return JSONResponse(
         {
             "name": name,
@@ -1246,7 +1268,9 @@ async def echo_paths(request: Request, name: str):
     )
 
 
-async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str):
+async def pure_asgi_echo_paths(
+    scope: Scope, receive: Receive, send: Send, name: str
+) -> None:
     data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
     content = json.dumps(data).encode("utf-8")
     await send(
@@ -1282,7 +1306,7 @@ echo_paths_routes = [
 ]
 
 
-def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClient]):
+def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
     app = Starlette(routes=echo_paths_routes)
     client = test_client_factory(
         app, base_url="https://www.example.org/", root_path="/root"