]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_authentication.py` (#2472)
authorScirlat Danut <danut.scirlat@gmail.com>
Tue, 6 Feb 2024 20:47:45 +0000 (22:47 +0200)
committerGitHub <noreply@github.com>
Tue, 6 Feb 2024 20:47:45 +0000 (20:47 +0000)
* added type annotations to test_authentication.py

* fixed types

* Apply suggestions from code review

* Fix linting

* Fix linting

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_authentication.py

index 150482a1b6a7d187b5892ec56252415483024397..27b0337620d594a4d7ca3345c93ece12bbad0fc1 100644 (file)
@@ -1,5 +1,6 @@
 import base64
 import binascii
+from typing import Any, Awaitable, Callable, Optional, Tuple
 from urllib.parse import urlencode
 
 import pytest
@@ -15,14 +16,22 @@ from starlette.authentication import (
 from starlette.endpoints import HTTPEndpoint
 from starlette.middleware import Middleware
 from starlette.middleware.authentication import AuthenticationMiddleware
-from starlette.requests import HTTPConnection
-from starlette.responses import JSONResponse
+from starlette.requests import HTTPConnection, Request
+from starlette.responses import JSONResponse, Response
 from starlette.routing import Route, WebSocketRoute
-from starlette.websockets import WebSocketDisconnect
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocket, WebSocketDisconnect
+
+TestClientFactory = Callable[..., TestClient]
+AsyncEndpoint = Callable[..., Awaitable[Response]]
+SyncEndpoint = Callable[..., Response]
 
 
 class BasicAuth(AuthenticationBackend):
-    async def authenticate(self, request):
+    async def authenticate(
+        self,
+        request: HTTPConnection,
+    ) -> Optional[Tuple[AuthCredentials, SimpleUser]]:
         if "Authorization" not in request.headers:
             return None
 
@@ -37,7 +46,7 @@ class BasicAuth(AuthenticationBackend):
         return AuthCredentials(["authenticated"]), SimpleUser(username)
 
 
-def homepage(request):
+def homepage(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -47,7 +56,7 @@ def homepage(request):
 
 
 @requires("authenticated")
-async def dashboard(request):
+async def dashboard(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -57,7 +66,7 @@ async def dashboard(request):
 
 
 @requires("authenticated", redirect="homepage")
-async def admin(request):
+async def admin(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -67,7 +76,7 @@ async def admin(request):
 
 
 @requires("authenticated")
-def dashboard_sync(request):
+def dashboard_sync(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -78,7 +87,7 @@ def dashboard_sync(request):
 
 class Dashboard(HTTPEndpoint):
     @requires("authenticated")
-    def get(self, request):
+    def get(self, request: Request) -> JSONResponse:
         return JSONResponse(
             {
                 "authenticated": request.user.is_authenticated,
@@ -88,7 +97,7 @@ class Dashboard(HTTPEndpoint):
 
 
 @requires("authenticated", redirect="homepage")
-def admin_sync(request):
+def admin_sync(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -98,7 +107,7 @@ def admin_sync(request):
 
 
 @requires("authenticated")
-async def websocket_endpoint(websocket):
+async def websocket_endpoint(websocket: WebSocket) -> None:
     await websocket.accept()
     await websocket.send_json(
         {
@@ -108,9 +117,11 @@ async def websocket_endpoint(websocket):
     )
 
 
-def async_inject_decorator(**kwargs):
-    def wrapper(endpoint):
-        async def app(request):
+def async_inject_decorator(
+    **kwargs: Any,
+) -> Callable[[AsyncEndpoint], Callable[..., Awaitable[Response]]]:
+    def wrapper(endpoint: AsyncEndpoint) -> Callable[..., Awaitable[Response]]:
+        async def app(request: Request) -> Response:
             return await endpoint(request=request, **kwargs)
 
         return app
@@ -120,7 +131,7 @@ def async_inject_decorator(**kwargs):
 
 @async_inject_decorator(additional="payload")
 @requires("authenticated")
-async def decorated_async(request, additional):
+async def decorated_async(request: Request, additional: str) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -130,9 +141,11 @@ async def decorated_async(request, additional):
     )
 
 
-def sync_inject_decorator(**kwargs):
-    def wrapper(endpoint):
-        def app(request):
+def sync_inject_decorator(
+    **kwargs: Any,
+) -> Callable[[SyncEndpoint], Callable[..., Response]]:
+    def wrapper(endpoint: SyncEndpoint) -> Callable[..., Response]:
+        def app(request: Request) -> Response:
             return endpoint(request=request, **kwargs)
 
         return app
@@ -142,7 +155,7 @@ def sync_inject_decorator(**kwargs):
 
 @sync_inject_decorator(additional="payload")
 @requires("authenticated")
-def decorated_sync(request, additional):
+def decorated_sync(request: Request, additional: str) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -152,9 +165,9 @@ def decorated_sync(request, additional):
     )
 
 
-def ws_inject_decorator(**kwargs):
-    def wrapper(endpoint):
-        def app(websocket):
+def ws_inject_decorator(**kwargs: Any) -> Callable[..., AsyncEndpoint]:
+    def wrapper(endpoint: AsyncEndpoint) -> AsyncEndpoint:
+        def app(websocket: WebSocket) -> Awaitable[Response]:
             return endpoint(websocket=websocket, **kwargs)
 
         return app
@@ -164,7 +177,7 @@ def ws_inject_decorator(**kwargs):
 
 @ws_inject_decorator(additional="payload")
 @requires("authenticated")
-async def websocket_endpoint_decorated(websocket, additional):
+async def websocket_endpoint_decorated(websocket: WebSocket, additional: str) -> None:
     await websocket.accept()
     await websocket.send_json(
         {
@@ -192,15 +205,15 @@ app = Starlette(
 )
 
 
-def test_invalid_decorator_usage():
+def test_invalid_decorator_usage() -> None:
     with pytest.raises(Exception):
 
         @requires("authenticated")
-        def foo():
+        def foo() -> None:
             pass  # pragma: nocover
 
 
-def test_user_interface(test_client_factory):
+def test_user_interface(test_client_factory: TestClientFactory) -> None:
     with test_client_factory(app) as client:
         response = client.get("/")
         assert response.status_code == 200
@@ -211,7 +224,7 @@ def test_user_interface(test_client_factory):
         assert response.json() == {"authenticated": True, "user": "tomchristie"}
 
 
-def test_authentication_required(test_client_factory):
+def test_authentication_required(test_client_factory: TestClientFactory) -> None:
     with test_client_factory(app) as client:
         response = client.get("/dashboard")
         assert response.status_code == 403
@@ -263,7 +276,9 @@ def test_authentication_required(test_client_factory):
         assert response.text == "Invalid basic auth credentials"
 
 
-def test_websocket_authentication_required(test_client_factory):
+def test_websocket_authentication_required(
+    test_client_factory: TestClientFactory,
+) -> None:
     with test_client_factory(app) as client:
         with pytest.raises(WebSocketDisconnect):
             with client.websocket_connect("/ws"):
@@ -302,7 +317,7 @@ def test_websocket_authentication_required(test_client_factory):
             }
 
 
-def test_authentication_redirect(test_client_factory):
+def test_authentication_redirect(test_client_factory: TestClientFactory) -> None:
     with test_client_factory(app) as client:
         response = client.get("/admin")
         assert response.status_code == 200
@@ -327,12 +342,12 @@ def test_authentication_redirect(test_client_factory):
         assert response.json() == {"authenticated": True, "user": "tomchristie"}
 
 
-def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
+def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse:
     return JSONResponse({"error": str(exc)}, status_code=401)
 
 
 @requires("authenticated")
-def control_panel(request):
+def control_panel(request: Request) -> JSONResponse:
     return JSONResponse(
         {
             "authenticated": request.user.is_authenticated,
@@ -351,7 +366,7 @@ other_app = Starlette(
 )
 
 
-def test_custom_on_error(test_client_factory):
+def test_custom_on_error(test_client_factory: TestClientFactory) -> None:
     with test_client_factory(other_app) as client:
         response = client.get("/control-panel", auth=("tomchristie", "example"))
         assert response.status_code == 200