]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add automatic header handling for HTTP Basic Auth (#175)
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 21 Apr 2019 17:44:25 +0000 (21:44 +0400)
committerGitHub <noreply@github.com>
Sun, 21 Apr 2019 17:44:25 +0000 (21:44 +0400)
* :sparkles: Add automatic header handling for HTTP Basic Auth

* :art: Remove obsolete comment

fastapi/security/http.py
tests/test_security_http_basic.py
tests/test_security_http_basic_optional.py
tests/test_security_http_basic_realm.py [new file with mode: 0644]

index b2da3fcb5460a115317805818ae5922fd880762f..f41d8d9447347889ed1970756c52250a6505d5b0 100644 (file)
@@ -2,6 +2,7 @@ import binascii
 from base64 import b64decode
 from typing import Optional
 
+from fastapi.exceptions import HTTPException
 from fastapi.openapi.models import (
     HTTPBase as HTTPBaseModel,
     HTTPBearer as HTTPBearerModel,
@@ -9,9 +10,8 @@ from fastapi.openapi.models import (
 from fastapi.security.base import SecurityBase
 from fastapi.security.utils import get_authorization_scheme_param
 from pydantic import BaseModel
-from starlette.exceptions import HTTPException
 from starlette.requests import Request
-from starlette.status import HTTP_403_FORBIDDEN
+from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
 
 
 class HTTPBasicCredentials(BaseModel):
@@ -59,15 +59,21 @@ class HTTPBasic(HTTPBase):
     async def __call__(self, request: Request) -> Optional[HTTPBasicCredentials]:
         authorization: str = request.headers.get("Authorization")
         scheme, param = get_authorization_scheme_param(authorization)
-        # before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295
-        # unauthorized_headers = {"WWW-Authenticate": "Basic"}
+        if self.realm:
+            unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
+        else:
+            unauthorized_headers = {"WWW-Authenticate": "Basic"}
         invalid_user_credentials_exc = HTTPException(
-            status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials"
+            status_code=HTTP_401_UNAUTHORIZED,
+            detail="Invalid authentication credentials",
+            headers=unauthorized_headers,
         )
         if not authorization or scheme.lower() != "basic":
             if self.auto_error:
                 raise HTTPException(
-                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+                    status_code=HTTP_401_UNAUTHORIZED,
+                    detail="Not authenticated",
+                    headers=unauthorized_headers,
                 )
             else:
                 return None
@@ -87,7 +93,7 @@ class HTTPBearer(HTTPBase):
         *,
         bearerFormat: str = None,
         scheme_name: str = None,
-        auto_error: bool = True
+        auto_error: bool = True,
     ):
         self.model = HTTPBearerModel(bearerFormat=bearerFormat)
         self.scheme_name = scheme_name or self.__class__.__name__
index dd289301d4fd3dbfbcb8a617ec4a6072974c1d65..7d380fef0f0b3ffeff6c31c2e589bc9369d7c433 100644 (file)
@@ -56,15 +56,17 @@ def test_security_http_basic():
 
 def test_security_http_basic_no_credentials():
     response = client.get("/users/me")
-    assert response.status_code == 403
     assert response.json() == {"detail": "Not authenticated"}
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == "Basic"
 
 
 def test_security_http_basic_invalid_credentials():
     response = client.get(
         "/users/me", headers={"Authorization": "Basic notabase64token"}
     )
-    assert response.status_code == 403
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == "Basic"
     assert response.json() == {"detail": "Invalid authentication credentials"}
 
 
@@ -72,5 +74,6 @@ def test_security_http_basic_non_basic_credentials():
     payload = b64encode(b"johnsecret").decode("ascii")
     auth_header = f"Basic {payload}"
     response = client.get("/users/me", headers={"Authorization": auth_header})
-    assert response.status_code == 403
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == "Basic"
     assert response.json() == {"detail": "Invalid authentication credentials"}
index 40d64d4124528a3f782dfe7a27c1e822cde76fc5..2a4686bb362819f96465ff2480f518c6b367fe9a 100644 (file)
@@ -67,7 +67,8 @@ def test_security_http_basic_invalid_credentials():
     response = client.get(
         "/users/me", headers={"Authorization": "Basic notabase64token"}
     )
-    assert response.status_code == 403
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == "Basic"
     assert response.json() == {"detail": "Invalid authentication credentials"}
 
 
@@ -75,5 +76,6 @@ def test_security_http_basic_non_basic_credentials():
     payload = b64encode(b"johnsecret").decode("ascii")
     auth_header = f"Basic {payload}"
     response = client.get("/users/me", headers={"Authorization": auth_header})
-    assert response.status_code == 403
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == "Basic"
     assert response.json() == {"detail": "Invalid authentication credentials"}
diff --git a/tests/test_security_http_basic_realm.py b/tests/test_security_http_basic_realm.py
new file mode 100644 (file)
index 0000000..6b5b4ae
--- /dev/null
@@ -0,0 +1,79 @@
+from base64 import b64encode
+
+from fastapi import FastAPI, Security
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from requests.auth import HTTPBasicAuth
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+security = HTTPBasic(realm="simple")
+
+
+@app.get("/users/me")
+def read_current_user(credentials: HTTPBasicCredentials = Security(security)):
+    return {"username": credentials.username, "password": credentials.password}
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/users/me": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Current User",
+                "operationId": "read_current_user_users_me_get",
+                "security": [{"HTTPBasic": []}],
+            }
+        }
+    },
+    "components": {
+        "securitySchemes": {"HTTPBasic": {"type": "http", "scheme": "basic"}}
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_security_http_basic():
+    auth = HTTPBasicAuth(username="john", password="secret")
+    response = client.get("/users/me", auth=auth)
+    assert response.status_code == 200
+    assert response.json() == {"username": "john", "password": "secret"}
+
+
+def test_security_http_basic_no_credentials():
+    response = client.get("/users/me")
+    assert response.json() == {"detail": "Not authenticated"}
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == 'Basic realm="simple"'
+
+
+def test_security_http_basic_invalid_credentials():
+    response = client.get(
+        "/users/me", headers={"Authorization": "Basic notabase64token"}
+    )
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == 'Basic realm="simple"'
+    assert response.json() == {"detail": "Invalid authentication credentials"}
+
+
+def test_security_http_basic_non_basic_credentials():
+    payload = b64encode(b"johnsecret").decode("ascii")
+    auth_header = f"Basic {payload}"
+    response = client.get("/users/me", headers={"Authorization": auth_header})
+    assert response.status_code == 401
+    assert response.headers["WWW-Authenticate"] == 'Basic realm="simple"'
+    assert response.json() == {"detail": "Invalid authentication credentials"}