From c566fc6c819f0d565f8cff432351fe009e83d866 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Hannes=20K=C3=BCttner?= Date: Sun, 16 Aug 2020 16:24:23 +0200 Subject: [PATCH] Be more lenient with route arguments in AuthencationMiddleware 'requires' decorator (#942) --- starlette/authentication.py | 6 +- tests/test_authentication.py | 112 +++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) diff --git a/starlette/authentication.py b/starlette/authentication.py index db0b74fe..1d8a38d8 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -41,7 +41,7 @@ def requires( async def websocket_wrapper( *args: typing.Any, **kwargs: typing.Any ) -> None: - websocket = kwargs.get("websocket", args[idx]) + websocket = kwargs.get("websocket", args[idx] if args else None) assert isinstance(websocket, WebSocket) if not has_required_scope(websocket, scopes_list): @@ -57,7 +57,7 @@ def requires( async def async_wrapper( *args: typing.Any, **kwargs: typing.Any ) -> Response: - request = kwargs.get("request", args[idx]) + request = kwargs.get("request", args[idx] if args else None) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): @@ -74,7 +74,7 @@ def requires( # Handle sync request/response functions. @functools.wraps(func) def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: - request = kwargs.get("request", args[idx]) + request = kwargs.get("request", args[idx] if args else None) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 372ea81d..4c0f57ea 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -117,6 +117,76 @@ async def websocket_endpoint(websocket): ) +def async_inject_decorator(**kwargs): + def wrapper(endpoint): + async def app(request): + return await endpoint(request=request, **kwargs) + + return app + + return wrapper + + +@app.route("/dashboard/decorated") +@async_inject_decorator(additional="payload") +@requires("authenticated") +async def decorated_sync(request, additional): + return JSONResponse( + { + "authenticated": request.user.is_authenticated, + "user": request.user.display_name, + "additional": additional, + } + ) + + +def sync_inject_decorator(**kwargs): + def wrapper(endpoint): + def app(request): + return endpoint(request=request, **kwargs) + + return app + + return wrapper + + +@app.route("/dashboard/decorated/sync") +@sync_inject_decorator(additional="payload") +@requires("authenticated") +def decorated_sync(request, additional): + return JSONResponse( + { + "authenticated": request.user.is_authenticated, + "user": request.user.display_name, + "additional": additional, + } + ) + + +def ws_inject_decorator(**kwargs): + def wrapper(endpoint): + def app(websocket): + return endpoint(websocket=websocket, **kwargs) + + return app + + return wrapper + + +@app.websocket_route("/ws/decorated") +@ws_inject_decorator(additional="payload") +@requires("authenticated") +async def websocket_endpoint(websocket, additional): + await websocket.accept() + await websocket.send_json( + { + "authenticated": websocket.user.is_authenticated, + "user": websocket.user.display_name, + "additional": additional, + } + ) + + def test_invalid_decorator_usage(): with pytest.raises(Exception): @@ -159,6 +229,30 @@ def test_authentication_required(): assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/dashboard/decorated", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == { + "authenticated": True, + "user": "tomchristie", + "additional": "payload", + } + + response = client.get("/dashboard/decorated") + assert response.status_code == 403 + + response = client.get( + "/dashboard/decorated/sync", auth=("tomchristie", "example") + ) + assert response.status_code == 200 + assert response.json() == { + "authenticated": True, + "user": "tomchristie", + "additional": "payload", + } + + response = client.get("/dashboard/decorated/sync") + assert response.status_code == 403 + response = client.get("/dashboard", headers={"Authorization": "basic foobar"}) assert response.status_code == 400 assert response.text == "Invalid basic auth credentials" @@ -178,6 +272,24 @@ def test_websocket_authentication_required(): data = websocket.receive_json() assert data == {"authenticated": True, "user": "tomchristie"} + with pytest.raises(WebSocketDisconnect): + client.websocket_connect("/ws/decorated") + + with pytest.raises(WebSocketDisconnect): + client.websocket_connect( + "/ws/decorated", headers={"Authorization": "basic foobar"} + ) + + with client.websocket_connect( + "/ws/decorated", auth=("tomchristie", "example") + ) as websocket: + data = websocket.receive_json() + assert data == { + "authenticated": True, + "user": "tomchristie", + "additional": "payload", + } + def test_authentication_redirect(): with TestClient(app) as client: -- 2.47.3