]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Improve security utilities and add tests
authorSebastián Ramírez <tiangolo@gmail.com>
Fri, 28 Dec 2018 16:35:48 +0000 (20:35 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 28 Dec 2018 16:35:48 +0000 (20:35 +0400)
17 files changed:
fastapi/dependencies/utils.py
fastapi/security/api_key.py
fastapi/security/http.py
fastapi/security/oauth2.py
fastapi/security/open_id_connect_url.py
fastapi/security/utils.py [new file with mode: 0644]
tests/main.py
tests/test_application.py
tests/test_extra_routes.py
tests/test_include_route.py [new file with mode: 0644]
tests/test_query.py
tests/test_security.py [deleted file]
tests/test_security_api_key_cookie.py [new file with mode: 0644]
tests/test_security_api_key_header.py [new file with mode: 0644]
tests/test_security_api_key_query.py [new file with mode: 0644]
tests/test_security_oauth2.py [new file with mode: 0644]
tests/test_security_openid_connect.py [new file with mode: 0644]

index 9cb9f3f14edfe566ff40b44831b6384b72544d86..e72ac8f010fa2dd52b37ca1c7c442c405a3a93a9 100644 (file)
@@ -97,7 +97,7 @@ def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
         elif (
             param.default == param.empty
             or param.default is None
-            or type(param.default) in param_supported_types
+            or isinstance(param.default, param_supported_types)
         ) and (
             param.annotation == param.empty
             or lenient_issubclass(param.annotation, param_supported_types)
@@ -214,7 +214,8 @@ async def solve_dependencies(
             request=request, dependant=sub_dependant, body=body
         )
         if sub_errors:
-            return {}, errors
+            errors.extend(sub_errors)
+            continue
         assert sub_dependant.call is not None, "sub_dependant.call must be a function"
         if is_coroutine_callable(sub_dependant.call):
             solved = await sub_dependant.call(**sub_values)
@@ -238,7 +239,7 @@ async def solve_dependencies(
     values.update(query_values)
     values.update(header_values)
     values.update(cookie_values)
-    errors = path_errors + query_errors + header_errors + cookie_errors
+    errors += path_errors + query_errors + header_errors + cookie_errors
     if dependant.body_params:
         body_values, body_errors = await request_body_to_args(  # type: ignore # body_params checked above
             dependant.body_params, body
@@ -295,7 +296,7 @@ async def request_body_to_args(
             received_body = {}
         for field in required_params:
             value = received_body.get(field.alias)
-            if value is None:
+            if value is None or (isinstance(field.schema, params.Form) and value == ""):
                 if field.required:
                     errors.append(
                         ErrorWrapper(
index 12eba37ee7d2462a1a001f764aba54172f02a133..018e4f99eb516ff61136d9857ba8d81fb50bfb8b 100644 (file)
@@ -1,6 +1,8 @@
 from fastapi.openapi.models import APIKey, APIKeyIn
 from fastapi.security.base import SecurityBase
+from starlette.exceptions import HTTPException
 from starlette.requests import Request
+from starlette.status import HTTP_403_FORBIDDEN
 
 
 class APIKeyBase(SecurityBase):
@@ -9,26 +11,41 @@ class APIKeyBase(SecurityBase):
 
 class APIKeyQuery(APIKeyBase):
     def __init__(self, *, name: str, scheme_name: str = None):
-        self.model = APIKey(in_=APIKeyIn.query, name=name)
+        self.model = APIKey(**{"in": APIKeyIn.query}, name=name)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, requests: Request) -> str:
-        return requests.query_params.get(self.model.name)
+    async def __call__(self, request: Request) -> str:
+        api_key: str = request.query_params.get(self.model.name)
+        if not api_key:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return api_key
 
 
 class APIKeyHeader(APIKeyBase):
     def __init__(self, *, name: str, scheme_name: str = None):
-        self.model = APIKey(in_=APIKeyIn.header, name=name)
+        self.model = APIKey(**{"in": APIKeyIn.header}, name=name)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, requests: Request) -> str:
-        return requests.headers.get(self.model.name)
+    async def __call__(self, request: Request) -> str:
+        api_key: str = request.headers.get(self.model.name)
+        if not api_key:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return api_key
 
 
 class APIKeyCookie(APIKeyBase):
     def __init__(self, *, name: str, scheme_name: str = None):
-        self.model = APIKey(in_=APIKeyIn.cookie, name=name)
+        self.model = APIKey(**{"in": APIKeyIn.cookie}, name=name)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, requests: Request) -> str:
-        return requests.cookies.get(self.model.name)
+    async def __call__(self, request: Request) -> str:
+        api_key: str = request.cookies.get(self.model.name)
+        if not api_key:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return api_key
index b1cba1921fd24b013ae815b107926ea096fea9d3..287beee58e9a25d699c839ceb8c5ee34570dcea1 100644 (file)
@@ -1,9 +1,26 @@
+import binascii
+from base64 import b64decode
+
 from fastapi.openapi.models import (
     HTTPBase as HTTPBaseModel,
     HTTPBearer as HTTPBearerModel,
 )
 from fastapi.security.base import SecurityBase
+from fastapi.security.utils import get_authorization_scheme_param
+from pydantic import BaseModel
+from starlette.exceptions import HTTPException
 from starlette.requests import Request
+from starlette.status import HTTP_403_FORBIDDEN
+
+
+class HTTPBasicCredentials(BaseModel):
+    username: str
+    password: str
+
+
+class HTTPAuthorizationCredentials(BaseModel):
+    scheme: str
+    credentials: str
 
 
 class HTTPBase(SecurityBase):
@@ -12,16 +29,41 @@ class HTTPBase(SecurityBase):
         self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, request: Request) -> str:
-        return request.headers.get("Authorization")
+        authorization: str = request.headers.get("Authorization")
+        scheme, credentials = get_authorization_scheme_param(authorization)
+        if not (authorization and scheme and credentials):
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
 
 
 class HTTPBasic(HTTPBase):
-    def __init__(self, *, scheme_name: str = None):
+    def __init__(self, *, scheme_name: str = None, realm: str = None):
         self.model = HTTPBaseModel(scheme="basic")
         self.scheme_name = scheme_name or self.__class__.__name__
+        self.realm = realm
 
     async def __call__(self, request: Request) -> str:
-        return request.headers.get("Authorization")
+        authorization: str = request.headers.get("Authorization")
+        scheme, param = get_authorization_scheme_param(authorization)
+        # before implementing headers with 401 errors, wait for: https://github.com/encode/starlette/issues/295
+        # unauthorized_headers = {"WWW-Authenticate": "Basic"}
+        invalid_user_credentials_exc = HTTPException(
+            status_code=HTTP_403_FORBIDDEN, detail="Invalid authentication credentials"
+        )
+        if not authorization or scheme.lower() != "basic":
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        try:
+            data = b64decode(param).decode("ascii")
+        except (ValueError, UnicodeDecodeError, binascii.Error):
+            raise invalid_user_credentials_exc
+        username, separator, password = data.partition(":")
+        if not (separator):
+            raise invalid_user_credentials_exc
+        return HTTPBasicCredentials(username=username, password=password)
 
 
 class HTTPBearer(HTTPBase):
@@ -30,7 +72,13 @@ class HTTPBearer(HTTPBase):
         self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, request: Request) -> str:
-        return request.headers.get("Authorization")
+        authorization: str = request.headers.get("Authorization")
+        scheme, credentials = get_authorization_scheme_param(authorization)
+        if not (authorization and scheme and credentials):
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
 
 
 class HTTPDigest(HTTPBase):
@@ -39,4 +87,10 @@ class HTTPDigest(HTTPBase):
         self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, request: Request) -> str:
-        return request.headers.get("Authorization")
+        authorization: str = request.headers.get("Authorization")
+        scheme, credentials = get_authorization_scheme_param(authorization)
+        if not (authorization and scheme and credentials):
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
index 4fd767ec69dd170edc58b3e5a37c7bdcbda187e3..b1132fef133a5f11c679fedf7593f08306253d96 100644 (file)
@@ -3,6 +3,7 @@ from typing import Optional
 from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
 from fastapi.params import Form
 from fastapi.security.base import SecurityBase
+from fastapi.security.utils import get_authorization_scheme_param
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 from starlette.status import HTTP_403_FORBIDDEN
@@ -118,7 +119,12 @@ class OAuth2(SecurityBase):
         self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, request: Request) -> str:
-        return request.headers.get("Authorization")
+        authorization: str = request.headers.get("Authorization")
+        if not authorization:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return authorization
 
 
 class OAuth2PasswordBearer(OAuth2):
@@ -130,9 +136,9 @@ class OAuth2PasswordBearer(OAuth2):
 
     async def __call__(self, request: Request) -> str:
         authorization: str = request.headers.get("Authorization")
-        if not authorization or "Bearer " not in authorization:
+        scheme, param = get_authorization_scheme_param(authorization)
+        if not authorization or scheme.lower() != "bearer":
             raise HTTPException(
                 status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
             )
-        token = authorization.replace("Bearer ", "")
-        return token
+        return param
index 7d73ed81f388ca620af2c768d7f6fc85281e299c..e10f4a5106d56c5ca3dfce999bb33fb82fd091d4 100644 (file)
@@ -1,6 +1,8 @@
 from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
 from fastapi.security.base import SecurityBase
+from starlette.exceptions import HTTPException
 from starlette.requests import Request
+from starlette.status import HTTP_403_FORBIDDEN
 
 
 class OpenIdConnect(SecurityBase):
@@ -9,4 +11,9 @@ class OpenIdConnect(SecurityBase):
         self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, request: Request) -> str:
-        return request.headers.get("Authorization")
+        authorization: str = request.headers.get("Authorization")
+        if not authorization:
+            raise HTTPException(
+                status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
+            )
+        return authorization
diff --git a/fastapi/security/utils.py b/fastapi/security/utils.py
new file mode 100644 (file)
index 0000000..3ddd83a
--- /dev/null
@@ -0,0 +1,5 @@
+def get_authorization_scheme_param(authorization_header_value: str):
+    if not authorization_header_value:
+        return "", ""
+    scheme, _, param = authorization_header_value.partition(" ")
+    return scheme, param
index c384bbb759d4b9b994ff663e1a643f3e22cf449c..ab0b186072933eb2bb4c770b4e78183c7b4d3b29 100644 (file)
@@ -1,6 +1,4 @@
-from fastapi import Depends, FastAPI, Path, Query, Security
-from fastapi.security import OAuth2PasswordBearer
-from pydantic import BaseModel
+from fastapi import FastAPI, Path, Query
 
 app = FastAPI()
 
@@ -144,8 +142,6 @@ def get_path_param_le_ge_int(item_id: int = Path(..., le=3, ge=1)):
 
 @app.get("/query")
 def get_query(query):
-    if query is None:
-        return "foo bar"
     return f"foo bar {query}"
 
 
@@ -158,8 +154,6 @@ def get_query_optional(query=None):
 
 @app.get("/query/int")
 def get_query_type(query: int):
-    if query is None:
-        return "foo bar"
     return f"foo bar {query}"
 
 
@@ -184,30 +178,9 @@ def get_query_param(query=Query(None)):
 
 @app.get("/query/param-required")
 def get_query_param_required(query=Query(...)):
-    if query is None:
-        return "foo bar"
     return f"foo bar {query}"
 
 
 @app.get("/query/param-required/int")
 def get_query_param_required_type(query: int = Query(...)):
-    if query is None:
-        return "foo bar"
     return f"foo bar {query}"
-
-
-reusable_oauth2b = OAuth2PasswordBearer(tokenUrl="/token")
-
-
-class User(BaseModel):
-    username: str
-
-
-def get_current_user(oauth_header: str = Security(reusable_oauth2b)):
-    user = User(username=oauth_header)
-    return user
-
-
-@app.get("/security/oauth2b")
-def read_current_user(current_user: User = Depends(get_current_user)):
-    return current_user
index 55394c19dd3dc2c7a4b9ba739d7ab7f339316482..fb0f539751f15361f16b33e83806f81b3c1fa9cf 100644 (file)
@@ -1078,19 +1078,6 @@ openapi_schema = {
                 ],
             }
         },
-        "/security/oauth2b": {
-            "get": {
-                "responses": {
-                    "200": {
-                        "description": "Successful Response",
-                        "content": {"application/json": {"schema": {}}},
-                    }
-                },
-                "summary": "Read Current User Get",
-                "operationId": "read_current_user_security_oauth2b_get",
-                "security": [{"OAuth2PasswordBearer": []}],
-            }
-        },
     },
     "components": {
         "schemas": {
@@ -1119,13 +1106,7 @@ openapi_schema = {
                     }
                 },
             },
-        },
-        "securitySchemes": {
-            "OAuth2PasswordBearer": {
-                "type": "oauth2",
-                "flows": {"password": {"scopes": {}, "tokenUrl": "/token"}},
-            }
-        },
+        }
     },
 }
 
@@ -1134,6 +1115,7 @@ openapi_schema = {
     "path,expected_status,expected_response",
     [
         ("/api_route", 200, {"message": "Hello World"}),
+        ("/non_decorated_route", 200, {"message": "Hello World"}),
         ("/nonexistent", 404, {"detail": "Not Found"}),
         ("/openapi.json", 200, openapi_schema),
     ],
index d07b90d3f23bae6dff0d7a8695693b903507ae3d..6147c3414b13efb4d5843bb2a5107f4d843197ac 100644 (file)
@@ -343,7 +343,7 @@ def test_head():
 
 
 def test_options():
-    response = client.head("/items/foo")
+    response = client.options("/items/foo")
     assert response.status_code == 200
     assert response.headers["x-fastapi-item-id"] == "foo"
 
diff --git a/tests/test_include_route.py b/tests/test_include_route.py
new file mode 100644 (file)
index 0000000..c194d20
--- /dev/null
@@ -0,0 +1,23 @@
+from fastapi import APIRouter, FastAPI
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+from starlette.testclient import TestClient
+
+app = FastAPI()
+router = APIRouter()
+
+
+@router.route("/items/")
+def read_items(request: Request):
+    return JSONResponse({"hello": "world"})
+
+
+app.include_router(router)
+
+client = TestClient(app)
+
+
+def test_sub_router():
+    response = client.get("/items/")
+    assert response.status_code == 200
+    assert response.json() == {"hello": "world"}
index 17d120287fbf5dd5de5f49455aa9facfba38e145..92cff2bb5a2487f648e61814ef5987ff8d6c8628 100644 (file)
@@ -40,9 +40,19 @@ response_not_valid_int = {
         ("/query/int?query=42.5", 422, response_not_valid_int),
         ("/query/int?query=baz", 422, response_not_valid_int),
         ("/query/int?not_declared=baz", 422, response_missing),
+        ("/query/int/optional", 200, "foo bar"),
+        ("/query/int/optional?query=50", 200, "foo bar 50"),
+        ("/query/int/optional?query=foo", 422, response_not_valid_int),
         ("/query/int/default", 200, "foo bar 10"),
         ("/query/int/default?query=50", 200, "foo bar 50"),
         ("/query/int/default?query=foo", 422, response_not_valid_int),
+        ("/query/param", 200, "foo bar"),
+        ("/query/param?query=50", 200, "foo bar 50"),
+        ("/query/param-required", 422, response_missing),
+        ("/query/param-required?query=50", 200, "foo bar 50"),
+        ("/query/param-required/int", 422, response_missing),
+        ("/query/param-required/int?query=50", 200, "foo bar 50"),
+        ("/query/param-required/int?query=foo", 422, response_not_valid_int),
     ],
 )
 def test_get_path(path, expected_status, expected_response):
diff --git a/tests/test_security.py b/tests/test_security.py
deleted file mode 100644 (file)
index 672a846..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-from starlette.testclient import TestClient
-
-from .main import app
-
-client = TestClient(app)
-
-
-def test_security_oauth2_password_bearer():
-    response = client.get(
-        "/security/oauth2b", headers={"Authorization": "Bearer footokenbar"}
-    )
-    assert response.status_code == 200
-    assert response.json() == {"username": "footokenbar"}
-
-
-def test_security_oauth2_password_bearer_wrong_header():
-    response = client.get("/security/oauth2b", headers={"Authorization": "footokenbar"})
-    assert response.status_code == 403
-    assert response.json() == {"detail": "Not authenticated"}
-
-
-def test_security_oauth2_password_bearer_no_header():
-    response = client.get("/security/oauth2b")
-    assert response.status_code == 403
-    assert response.json() == {"detail": "Not authenticated"}
diff --git a/tests/test_security_api_key_cookie.py b/tests/test_security_api_key_cookie.py
new file mode 100644 (file)
index 0000000..88b3eef
--- /dev/null
@@ -0,0 +1,68 @@
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyCookie
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+api_key = APIKeyCookie(name="key")
+
+
+class User(BaseModel):
+    username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+    user = User(username=oauth_header)
+    return user
+
+
+@app.get("/users/me")
+def read_current_user(current_user: User = Depends(get_current_user)):
+    return current_user
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/users/me": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Current User Get",
+                "operationId": "read_current_user_users_me_get",
+                "security": [{"APIKeyCookie": []}],
+            }
+        }
+    },
+    "components": {
+        "securitySchemes": {
+            "APIKeyCookie": {"type": "apiKey", "name": "key", "in": "cookie"}
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_security_api_key():
+    response = client.get("/users/me", cookies={"key": "secret"})
+    assert response.status_code == 200
+    assert response.json() == {"username": "secret"}
+
+
+def test_security_api_key_no_key():
+    response = client.get("/users/me")
+    assert response.status_code == 403
+    assert response.json() == {"detail": "Not authenticated"}
diff --git a/tests/test_security_api_key_header.py b/tests/test_security_api_key_header.py
new file mode 100644 (file)
index 0000000..2d6114d
--- /dev/null
@@ -0,0 +1,68 @@
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyHeader
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+api_key = APIKeyHeader(name="key")
+
+
+class User(BaseModel):
+    username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+    user = User(username=oauth_header)
+    return user
+
+
+@app.get("/users/me")
+def read_current_user(current_user: User = Depends(get_current_user)):
+    return current_user
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/users/me": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Current User Get",
+                "operationId": "read_current_user_users_me_get",
+                "security": [{"APIKeyHeader": []}],
+            }
+        }
+    },
+    "components": {
+        "securitySchemes": {
+            "APIKeyHeader": {"type": "apiKey", "name": "key", "in": "header"}
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_security_api_key():
+    response = client.get("/users/me", headers={"key": "secret"})
+    assert response.status_code == 200
+    assert response.json() == {"username": "secret"}
+
+
+def test_security_api_key_no_key():
+    response = client.get("/users/me")
+    assert response.status_code == 403
+    assert response.json() == {"detail": "Not authenticated"}
diff --git a/tests/test_security_api_key_query.py b/tests/test_security_api_key_query.py
new file mode 100644 (file)
index 0000000..599b254
--- /dev/null
@@ -0,0 +1,68 @@
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import APIKeyQuery
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+api_key = APIKeyQuery(name="key")
+
+
+class User(BaseModel):
+    username: str
+
+
+def get_current_user(oauth_header: str = Security(api_key)):
+    user = User(username=oauth_header)
+    return user
+
+
+@app.get("/users/me")
+def read_current_user(current_user: User = Depends(get_current_user)):
+    return current_user
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/users/me": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Current User Get",
+                "operationId": "read_current_user_users_me_get",
+                "security": [{"APIKeyQuery": []}],
+            }
+        }
+    },
+    "components": {
+        "securitySchemes": {
+            "APIKeyQuery": {"type": "apiKey", "name": "key", "in": "query"}
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_security_api_key():
+    response = client.get("/users/me?key=secret")
+    assert response.status_code == 200
+    assert response.json() == {"username": "secret"}
+
+
+def test_security_api_key_no_key():
+    response = client.get("/users/me")
+    assert response.status_code == 403
+    assert response.json() == {"detail": "Not authenticated"}
diff --git a/tests/test_security_oauth2.py b/tests/test_security_oauth2.py
new file mode 100644 (file)
index 0000000..050f17d
--- /dev/null
@@ -0,0 +1,247 @@
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import OAuth2
+from fastapi.security.oauth2 import OAuth2PasswordRequestFormStrict
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+reusable_oauth2 = OAuth2(
+    flows={
+        "password": {
+            "tokenUrl": "/token",
+            "scopes": {"read:users": "Read the users", "write:users": "Create users"},
+        }
+    }
+)
+
+
+class User(BaseModel):
+    username: str
+
+
+def get_current_user(oauth_header: str = Security(reusable_oauth2)):
+    user = User(username=oauth_header)
+    return user
+
+
+@app.post("/login")
+def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
+    return form_data
+
+
+@app.get("/users/me")
+def read_current_user(current_user: User = Depends(get_current_user)):
+    return current_user
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/login": {
+            "post": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "summary": "Read Current User Post",
+                "operationId": "read_current_user_login_post",
+                "requestBody": {
+                    "content": {
+                        "application/x-www-form-urlencoded": {
+                            "schema": {
+                                "$ref": "#/components/schemas/Body_read_current_user"
+                            }
+                        }
+                    },
+                    "required": True,
+                },
+            }
+        },
+        "/users/me": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Current User Get",
+                "operationId": "read_current_user_users_me_get",
+                "security": [{"OAuth2": []}],
+            }
+        },
+    },
+    "components": {
+        "schemas": {
+            "Body_read_current_user": {
+                "title": "Body_read_current_user",
+                "required": ["grant_type", "username", "password"],
+                "type": "object",
+                "properties": {
+                    "grant_type": {
+                        "title": "Grant_Type",
+                        "pattern": "password",
+                        "type": "string",
+                    },
+                    "username": {"title": "Username", "type": "string"},
+                    "password": {"title": "Password", "type": "string"},
+                    "scope": {"title": "Scope", "type": "string", "default": ""},
+                    "client_id": {"title": "Client_Id", "type": "string"},
+                    "client_secret": {"title": "Client_Secret", "type": "string"},
+                },
+            },
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+        },
+        "securitySchemes": {
+            "OAuth2": {
+                "type": "oauth2",
+                "flows": {
+                    "password": {
+                        "scopes": {
+                            "read:users": "Read the users",
+                            "write:users": "Create users",
+                        },
+                        "tokenUrl": "/token",
+                    }
+                },
+            }
+        },
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_security_oauth2():
+    response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"})
+    assert response.status_code == 200
+    assert response.json() == {"username": "Bearer footokenbar"}
+
+
+def test_security_oauth2_password_other_header():
+    response = client.get("/users/me", headers={"Authorization": "Other footokenbar"})
+    assert response.status_code == 200
+    assert response.json() == {"username": "Other footokenbar"}
+
+
+def test_security_oauth2_password_bearer_no_header():
+    response = client.get("/users/me")
+    assert response.status_code == 403
+    assert response.json() == {"detail": "Not authenticated"}
+
+
+required_params = {
+    "detail": [
+        {
+            "loc": ["body", "grant_type"],
+            "msg": "field required",
+            "type": "value_error.missing",
+        },
+        {
+            "loc": ["body", "username"],
+            "msg": "field required",
+            "type": "value_error.missing",
+        },
+        {
+            "loc": ["body", "password"],
+            "msg": "field required",
+            "type": "value_error.missing",
+        },
+    ]
+}
+
+grant_type_required = {
+    "detail": [
+        {
+            "loc": ["body", "grant_type"],
+            "msg": "field required",
+            "type": "value_error.missing",
+        }
+    ]
+}
+
+grant_type_incorrect = {
+    "detail": [
+        {
+            "loc": ["body", "grant_type"],
+            "msg": 'string does not match regex "password"',
+            "type": "value_error.str.regex",
+            "ctx": {"pattern": "password"},
+        }
+    ]
+}
+
+
+@pytest.mark.parametrize(
+    "data,expected_status,expected_response",
+    [
+        (None, 422, required_params),
+        ({"username": "johndoe", "password": "secret"}, 422, grant_type_required),
+        (
+            {"username": "johndoe", "password": "secret", "grant_type": "incorrect"},
+            422,
+            grant_type_incorrect,
+        ),
+        (
+            {"username": "johndoe", "password": "secret", "grant_type": "password"},
+            200,
+            {
+                "grant_type": "password",
+                "username": "johndoe",
+                "password": "secret",
+                "scopes": [],
+                "client_id": None,
+                "client_secret": None,
+            },
+        ),
+    ],
+)
+def test_strict_login(data, expected_status, expected_response):
+    response = client.post("/login", data=data)
+    assert response.status_code == expected_status
+    assert response.json() == expected_response
diff --git a/tests/test_security_openid_connect.py b/tests/test_security_openid_connect.py
new file mode 100644 (file)
index 0000000..ce19dd9
--- /dev/null
@@ -0,0 +1,74 @@
+from fastapi import Depends, FastAPI, Security
+from fastapi.security.open_id_connect_url import OpenIdConnect
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+oid = OpenIdConnect(openIdConnectUrl="/openid")
+
+
+class User(BaseModel):
+    username: str
+
+
+def get_current_user(oauth_header: str = Security(oid)):
+    user = User(username=oauth_header)
+    return user
+
+
+@app.get("/users/me")
+def read_current_user(current_user: User = Depends(get_current_user)):
+    return current_user
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/users/me": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Read Current User Get",
+                "operationId": "read_current_user_users_me_get",
+                "security": [{"OpenIdConnect": []}],
+            }
+        }
+    },
+    "components": {
+        "securitySchemes": {
+            "OpenIdConnect": {"type": "openIdConnect", "openIdConnectUrl": "/openid"}
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_security_oauth2():
+    response = client.get("/users/me", headers={"Authorization": "Bearer footokenbar"})
+    assert response.status_code == 200
+    assert response.json() == {"username": "Bearer footokenbar"}
+
+
+def test_security_oauth2_password_other_header():
+    response = client.get("/users/me", headers={"Authorization": "Other footokenbar"})
+    assert response.status_code == 200
+    assert response.json() == {"username": "Other footokenbar"}
+
+
+def test_security_oauth2_password_bearer_no_header():
+    response = client.get("/users/me")
+    assert response.status_code == 403
+    assert response.json() == {"detail": "Not authenticated"}