import base64
import binascii
+from typing import Any, Awaitable, Callable, Optional, Tuple
from urllib.parse import urlencode
import pytest
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
-from starlette.requests import HTTPConnection
-from starlette.responses import JSONResponse
+from starlette.requests import HTTPConnection, Request
+from starlette.responses import JSONResponse, Response
from starlette.routing import Route, WebSocketRoute
-from starlette.websockets import WebSocketDisconnect
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocket, WebSocketDisconnect
+
+TestClientFactory = Callable[..., TestClient]
+AsyncEndpoint = Callable[..., Awaitable[Response]]
+SyncEndpoint = Callable[..., Response]
class BasicAuth(AuthenticationBackend):
- async def authenticate(self, request):
+ async def authenticate(
+ self,
+ request: HTTPConnection,
+ ) -> Optional[Tuple[AuthCredentials, SimpleUser]]:
if "Authorization" not in request.headers:
return None
return AuthCredentials(["authenticated"]), SimpleUser(username)
-def homepage(request):
+def homepage(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
@requires("authenticated")
-async def dashboard(request):
+async def dashboard(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
@requires("authenticated", redirect="homepage")
-async def admin(request):
+async def admin(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
@requires("authenticated")
-def dashboard_sync(request):
+def dashboard_sync(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
class Dashboard(HTTPEndpoint):
@requires("authenticated")
- def get(self, request):
+ def get(self, request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
@requires("authenticated", redirect="homepage")
-def admin_sync(request):
+def admin_sync(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
@requires("authenticated")
-async def websocket_endpoint(websocket):
+async def websocket_endpoint(websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_json(
{
)
-def async_inject_decorator(**kwargs):
- def wrapper(endpoint):
- async def app(request):
+def async_inject_decorator(
+ **kwargs: Any,
+) -> Callable[[AsyncEndpoint], Callable[..., Awaitable[Response]]]:
+ def wrapper(endpoint: AsyncEndpoint) -> Callable[..., Awaitable[Response]]:
+ async def app(request: Request) -> Response:
return await endpoint(request=request, **kwargs)
return app
@async_inject_decorator(additional="payload")
@requires("authenticated")
-async def decorated_async(request, additional):
+async def decorated_async(request: Request, additional: str) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
)
-def sync_inject_decorator(**kwargs):
- def wrapper(endpoint):
- def app(request):
+def sync_inject_decorator(
+ **kwargs: Any,
+) -> Callable[[SyncEndpoint], Callable[..., Response]]:
+ def wrapper(endpoint: SyncEndpoint) -> Callable[..., Response]:
+ def app(request: Request) -> Response:
return endpoint(request=request, **kwargs)
return app
@sync_inject_decorator(additional="payload")
@requires("authenticated")
-def decorated_sync(request, additional):
+def decorated_sync(request: Request, additional: str) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
)
-def ws_inject_decorator(**kwargs):
- def wrapper(endpoint):
- def app(websocket):
+def ws_inject_decorator(**kwargs: Any) -> Callable[..., AsyncEndpoint]:
+ def wrapper(endpoint: AsyncEndpoint) -> AsyncEndpoint:
+ def app(websocket: WebSocket) -> Awaitable[Response]:
return endpoint(websocket=websocket, **kwargs)
return app
@ws_inject_decorator(additional="payload")
@requires("authenticated")
-async def websocket_endpoint_decorated(websocket, additional):
+async def websocket_endpoint_decorated(websocket: WebSocket, additional: str) -> None:
await websocket.accept()
await websocket.send_json(
{
)
-def test_invalid_decorator_usage():
+def test_invalid_decorator_usage() -> None:
with pytest.raises(Exception):
@requires("authenticated")
- def foo():
+ def foo() -> None:
pass # pragma: nocover
-def test_user_interface(test_client_factory):
+def test_user_interface(test_client_factory: TestClientFactory) -> None:
with test_client_factory(app) as client:
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
-def test_authentication_required(test_client_factory):
+def test_authentication_required(test_client_factory: TestClientFactory) -> None:
with test_client_factory(app) as client:
response = client.get("/dashboard")
assert response.status_code == 403
assert response.text == "Invalid basic auth credentials"
-def test_websocket_authentication_required(test_client_factory):
+def test_websocket_authentication_required(
+ test_client_factory: TestClientFactory,
+) -> None:
with test_client_factory(app) as client:
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/ws"):
}
-def test_authentication_redirect(test_client_factory):
+def test_authentication_redirect(test_client_factory: TestClientFactory) -> None:
with test_client_factory(app) as client:
response = client.get("/admin")
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
-def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
+def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse:
return JSONResponse({"error": str(exc)}, status_code=401)
@requires("authenticated")
-def control_panel(request):
+def control_panel(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
)
-def test_custom_on_error(test_client_factory):
+def test_custom_on_error(test_client_factory: TestClientFactory) -> None:
with test_client_factory(other_app) as client:
response = client.get("/control-panel", auth=("tomchristie", "example"))
assert response.status_code == 200