From: Scirlat Danut Date: Tue, 6 Feb 2024 20:47:45 +0000 (+0200) Subject: Add type hints to `test_authentication.py` (#2472) X-Git-Tag: 0.37.1~11 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0e4da0aea238f4a75b189c0b7f8dd2f7d310db71;p=thirdparty%2Fstarlette.git Add type hints to `test_authentication.py` (#2472) * added type annotations to test_authentication.py * fixed types * Apply suggestions from code review * Fix linting * Fix linting * Apply suggestions from code review --------- Co-authored-by: Scirlat Danut Co-authored-by: Marcelo Trylesinski --- diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 150482a1..27b03376 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,6 @@ import base64 import binascii +from typing import Any, Awaitable, Callable, Optional, Tuple from urllib.parse import urlencode import pytest @@ -15,14 +16,22 @@ from starlette.authentication import ( from starlette.endpoints import HTTPEndpoint from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.requests import HTTPConnection -from starlette.responses import JSONResponse +from starlette.requests import HTTPConnection, Request +from starlette.responses import JSONResponse, Response from starlette.routing import Route, WebSocketRoute -from starlette.websockets import WebSocketDisconnect +from starlette.testclient import TestClient +from starlette.websockets import WebSocket, WebSocketDisconnect + +TestClientFactory = Callable[..., TestClient] +AsyncEndpoint = Callable[..., Awaitable[Response]] +SyncEndpoint = Callable[..., Response] class BasicAuth(AuthenticationBackend): - async def authenticate(self, request): + async def authenticate( + self, + request: HTTPConnection, + ) -> Optional[Tuple[AuthCredentials, SimpleUser]]: if "Authorization" not in request.headers: return None @@ -37,7 +46,7 @@ class BasicAuth(AuthenticationBackend): return AuthCredentials(["authenticated"]), SimpleUser(username) -def homepage(request): +def homepage(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -47,7 +56,7 @@ def homepage(request): @requires("authenticated") -async def dashboard(request): +async def dashboard(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -57,7 +66,7 @@ async def dashboard(request): @requires("authenticated", redirect="homepage") -async def admin(request): +async def admin(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -67,7 +76,7 @@ async def admin(request): @requires("authenticated") -def dashboard_sync(request): +def dashboard_sync(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -78,7 +87,7 @@ def dashboard_sync(request): class Dashboard(HTTPEndpoint): @requires("authenticated") - def get(self, request): + def get(self, request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -88,7 +97,7 @@ class Dashboard(HTTPEndpoint): @requires("authenticated", redirect="homepage") -def admin_sync(request): +def admin_sync(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -98,7 +107,7 @@ def admin_sync(request): @requires("authenticated") -async def websocket_endpoint(websocket): +async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json( { @@ -108,9 +117,11 @@ async def websocket_endpoint(websocket): ) -def async_inject_decorator(**kwargs): - def wrapper(endpoint): - async def app(request): +def async_inject_decorator( + **kwargs: Any, +) -> Callable[[AsyncEndpoint], Callable[..., Awaitable[Response]]]: + def wrapper(endpoint: AsyncEndpoint) -> Callable[..., Awaitable[Response]]: + async def app(request: Request) -> Response: return await endpoint(request=request, **kwargs) return app @@ -120,7 +131,7 @@ def async_inject_decorator(**kwargs): @async_inject_decorator(additional="payload") @requires("authenticated") -async def decorated_async(request, additional): +async def decorated_async(request: Request, additional: str) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -130,9 +141,11 @@ async def decorated_async(request, additional): ) -def sync_inject_decorator(**kwargs): - def wrapper(endpoint): - def app(request): +def sync_inject_decorator( + **kwargs: Any, +) -> Callable[[SyncEndpoint], Callable[..., Response]]: + def wrapper(endpoint: SyncEndpoint) -> Callable[..., Response]: + def app(request: Request) -> Response: return endpoint(request=request, **kwargs) return app @@ -142,7 +155,7 @@ def sync_inject_decorator(**kwargs): @sync_inject_decorator(additional="payload") @requires("authenticated") -def decorated_sync(request, additional): +def decorated_sync(request: Request, additional: str) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -152,9 +165,9 @@ def decorated_sync(request, additional): ) -def ws_inject_decorator(**kwargs): - def wrapper(endpoint): - def app(websocket): +def ws_inject_decorator(**kwargs: Any) -> Callable[..., AsyncEndpoint]: + def wrapper(endpoint: AsyncEndpoint) -> AsyncEndpoint: + def app(websocket: WebSocket) -> Awaitable[Response]: return endpoint(websocket=websocket, **kwargs) return app @@ -164,7 +177,7 @@ def ws_inject_decorator(**kwargs): @ws_inject_decorator(additional="payload") @requires("authenticated") -async def websocket_endpoint_decorated(websocket, additional): +async def websocket_endpoint_decorated(websocket: WebSocket, additional: str) -> None: await websocket.accept() await websocket.send_json( { @@ -192,15 +205,15 @@ app = Starlette( ) -def test_invalid_decorator_usage(): +def test_invalid_decorator_usage() -> None: with pytest.raises(Exception): @requires("authenticated") - def foo(): + def foo() -> None: pass # pragma: nocover -def test_user_interface(test_client_factory): +def test_user_interface(test_client_factory: TestClientFactory) -> None: with test_client_factory(app) as client: response = client.get("/") assert response.status_code == 200 @@ -211,7 +224,7 @@ def test_user_interface(test_client_factory): assert response.json() == {"authenticated": True, "user": "tomchristie"} -def test_authentication_required(test_client_factory): +def test_authentication_required(test_client_factory: TestClientFactory) -> None: with test_client_factory(app) as client: response = client.get("/dashboard") assert response.status_code == 403 @@ -263,7 +276,9 @@ def test_authentication_required(test_client_factory): assert response.text == "Invalid basic auth credentials" -def test_websocket_authentication_required(test_client_factory): +def test_websocket_authentication_required( + test_client_factory: TestClientFactory, +) -> None: with test_client_factory(app) as client: with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws"): @@ -302,7 +317,7 @@ def test_websocket_authentication_required(test_client_factory): } -def test_authentication_redirect(test_client_factory): +def test_authentication_redirect(test_client_factory: TestClientFactory) -> None: with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 @@ -327,12 +342,12 @@ def test_authentication_redirect(test_client_factory): assert response.json() == {"authenticated": True, "user": "tomchristie"} -def on_auth_error(request: HTTPConnection, exc: AuthenticationError): +def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse: return JSONResponse({"error": str(exc)}, status_code=401) @requires("authenticated") -def control_panel(request): +def control_panel(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, @@ -351,7 +366,7 @@ other_app = Starlette( ) -def test_custom_on_error(test_client_factory): +def test_custom_on_error(test_client_factory: TestClientFactory) -> None: with test_client_factory(other_app) as client: response = client.get("/control-panel", auth=("tomchristie", "example")) assert response.status_code == 200