]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Implement OAuth2 authorization_code integration (#797)
authorJesse P. Johnson <kuwv@users.noreply.github.com>
Wed, 8 Jan 2020 21:47:19 +0000 (16:47 -0500)
committerSebastián Ramírez <tiangolo@gmail.com>
Wed, 8 Jan 2020 21:47:19 +0000 (22:47 +0100)
fastapi/security/__init__.py
fastapi/security/oauth2.py
tests/test_security_oauth2_authorization_code_bearer.py [new file with mode: 0644]

index de88d8f151447279ed9725c3e74b4c4bf6ef20de..37bf213c161347af7ff4b94dbcadbcd58cff25e3 100644 (file)
@@ -8,6 +8,7 @@ from .http import (
 )
 from .oauth2 import (
     OAuth2,
+    OAuth2AuthorizationCodeBearer,
     OAuth2PasswordBearer,
     OAuth2PasswordRequestForm,
     SecurityScopes,
index c7451cfafcd66dab04cececb19e6bc4d544cf6f8..781293bb9570de05170a0ba516142ba4daf4ec0f 100644 (file)
@@ -163,6 +163,43 @@ class OAuth2PasswordBearer(OAuth2):
         return param
 
 
+class OAuth2AuthorizationCodeBearer(OAuth2):
+    def __init__(
+        self,
+        authorizationUrl: str,
+        tokenUrl: str,
+        refreshUrl: str = None,
+        scheme_name: str = None,
+        scopes: dict = None,
+        auto_error: bool = True,
+    ):
+        if not scopes:
+            scopes = {}
+        flows = OAuthFlowsModel(
+            authorizationCode={
+                "authorizationUrl": authorizationUrl,
+                "tokenUrl": tokenUrl,
+                "refreshUrl": refreshUrl,
+                "scopes": scopes,
+            }
+        )
+        super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
+
+    async def __call__(self, request: Request) -> Optional[str]:
+        authorization: str = request.headers.get("Authorization")
+        scheme, param = get_authorization_scheme_param(authorization)
+        if not authorization or scheme.lower() != "bearer":
+            if self.auto_error:
+                raise HTTPException(
+                    status_code=HTTP_401_UNAUTHORIZED,
+                    detail="Not authenticated",
+                    headers={"WWW-Authenticate": "Bearer"},
+                )
+            else:
+                return None  # pragma: nocover
+        return param
+
+
 class SecurityScopes:
     def __init__(self, scopes: List[str] = None):
         self.scopes = scopes or []
diff --git a/tests/test_security_oauth2_authorization_code_bearer.py b/tests/test_security_oauth2_authorization_code_bearer.py
new file mode 100644 (file)
index 0000000..f39fcd0
--- /dev/null
@@ -0,0 +1,77 @@
+from typing import Optional
+
+from fastapi import FastAPI, Security
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+oauth2_scheme = OAuth2AuthorizationCodeBearer(
+    authorizationUrl="/authorize", tokenUrl="/token", auto_error=True
+)
+
+
+@app.get("/items/")
+async def read_items(token: Optional[str] = Security(oauth2_scheme)):
+    return {"token": token}
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/items/": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Items",
+                "operationId": "read_items_items__get",
+                "security": [{"OAuth2AuthorizationCodeBearer": []}],
+            }
+        }
+    },
+    "components": {
+        "securitySchemes": {
+            "OAuth2AuthorizationCodeBearer": {
+                "type": "oauth2",
+                "flows": {
+                    "authorizationCode": {
+                        "authorizationUrl": "/authorize",
+                        "tokenUrl": "/token",
+                        "scopes": {},
+                    }
+                },
+            }
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_no_token():
+    response = client.get("/items")
+    assert response.status_code == 401
+    assert response.json() == {"detail": "Not authenticated"}
+
+
+def test_incorrect_token():
+    response = client.get("/items", headers={"Authorization": "Non-existent testtoken"})
+    assert response.status_code == 401
+    assert response.json() == {"detail": "Not authenticated"}
+
+
+def test_token():
+    response = client.get("/items", headers={"Authorization": "Bearer testtoken"})
+    assert response.status_code == 200
+    assert response.json() == {"token": "testtoken"}