import sys
from dataclasses import dataclass, field
from functools import cached_property, partial
-from typing import Any, Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
return func
-@dataclass
-class SecurityRequirement:
- security_scheme: SecurityBase
- scopes: Optional[Sequence[str]] = None
-
-
@dataclass
class Dependant:
path_params: List[ModelField] = field(default_factory=list)
cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
- security_requirements: List[SecurityRequirement] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None
return True
if self.security_scopes_param_name is not None:
return True
+ if self._is_security_scheme:
+ return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
+ @cached_property
+ def _is_security_scheme(self) -> bool:
+ if self.call is None:
+ return False # pragma: no cover
+ unwrapped = _unwrapped_call(self.call)
+ return isinstance(unwrapped, SecurityBase)
+
+ # Mainly to get the type of SecurityBase, but it's the same self.call
+ @cached_property
+ def _security_scheme(self) -> SecurityBase:
+ unwrapped = _unwrapped_call(self.call)
+ assert isinstance(unwrapped, SecurityBase)
+ return unwrapped
+
+ @cached_property
+ def _security_dependencies(self) -> List["Dependant"]:
+ security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
+ return security_deps
+
@cached_property
def is_gen_callable(self) -> bool:
if self.call is None:
asynccontextmanager,
contextmanager_in_threadpool,
)
-from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
-from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
*,
skip_repeats: bool = False,
visited: Optional[List[DependencyCacheKey]] = None,
+ parent_oauth_scopes: Optional[List[str]] = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
+ use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
+ dependant.oauth_scopes or []
+ )
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
- security_requirements=dependant.security_requirements.copy(),
+ name=dependant.name,
+ call=dependant.call,
+ request_param_name=dependant.request_param_name,
+ websocket_param_name=dependant.websocket_param_name,
+ http_connection_param_name=dependant.http_connection_param_name,
+ response_param_name=dependant.response_param_name,
+ background_tasks_param_name=dependant.background_tasks_param_name,
+ security_scopes_param_name=dependant.security_scopes_param_name,
+ own_oauth_scopes=dependant.own_oauth_scopes,
+ parent_oauth_scopes=use_parent_oauth_scopes,
use_cache=dependant.use_cache,
path=dependant.path,
+ scope=dependant.scope,
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
- sub_dependant, skip_repeats=skip_repeats, visited=visited
+ sub_dependant,
+ skip_repeats=skip_repeats,
+ visited=visited,
+ parent_oauth_scopes=flat_dependant.oauth_scopes,
)
+ flat_dependant.dependencies.append(flat_sub)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
- flat_dependant.security_requirements.extend(flat_sub.security_requirements)
+ flat_dependant.dependencies.extend(flat_sub.dependencies)
+
return flat_dependant
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
- if isinstance(call, SecurityBase):
- security_requirement = SecurityRequirement(
- security_scheme=call, scopes=current_scopes
- )
- dependant.security_requirements.append(security_requirement)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
param_details = analyze_param(
security_definitions = {}
# Use a dict to merge scopes for same security scheme
operation_security_dict: Dict[str, List[str]] = {}
- for security_requirement in flat_dependant.security_requirements:
+ for security_dependency in flat_dependant._security_dependencies:
security_definition = jsonable_encoder(
- security_requirement.security_scheme.model,
+ security_dependency._security_scheme.model,
by_alias=True,
exclude_none=True,
)
- security_name = security_requirement.security_scheme.scheme_name
+ security_name = security_dependency._security_scheme.scheme_name
security_definitions[security_name] = security_definition
# Merge scopes for the same security scheme
if security_name not in operation_security_dict:
operation_security_dict[security_name] = []
- for scope in security_requirement.scopes or []:
+ for scope in security_dependency.oauth_scopes or []:
if scope not in operation_security_dict[security_name]:
operation_security_dict[security_name].append(scope)
operation_security = [
from typing import Optional
-from fastapi import APIRouter, FastAPI, Security
+from fastapi import APIRouter, Depends, FastAPI, Security
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
+from typing_extensions import Annotated
oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl="authorize",
scopes={"read": "Read access", "write": "Write access"},
)
-app = FastAPI(dependencies=[Security(oauth2_scheme)])
+
+async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
+ return token
+
+
+app = FastAPI(dependencies=[Depends(get_token)])
@app.get("/")
return {"message": "Hello World"}
+@app.get(
+ "/with-oauth2-scheme",
+ dependencies=[Security(oauth2_scheme, scopes=["read", "write"])],
+)
+async def read_with_oauth2_scheme():
+ return {"message": "Admin Access"}
+
+
+@app.get(
+ "/with-get-token", dependencies=[Security(get_token, scopes=["read", "write"])]
+)
+async def read_with_get_token():
+ return {"message": "Admin Access"}
+
+
router = APIRouter(dependencies=[Security(oauth2_scheme, scopes=["read"])])
@router.get("/items/")
-async def read_items(token: Optional[str] = Security(oauth2_scheme)):
+async def read_items(token: Optional[str] = Depends(oauth2_scheme)):
return {"token": token}
assert response.json() == {"message": "Hello World"}
+def test_read_with_oauth2_scheme():
+ response = client.get(
+ "/with-oauth2-scheme", headers={"Authorization": "Bearer testtoken"}
+ )
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "Admin Access"}
+
+
+def test_read_with_get_token():
+ response = client.get(
+ "/with-get-token", headers={"Authorization": "Bearer testtoken"}
+ )
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "Admin Access"}
+
+
def test_read_token():
response = client.get("/items/", headers={"Authorization": "Bearer testtoken"})
assert response.status_code == 200, response.text
"security": [{"OAuth2AuthorizationCodeBearer": []}],
}
},
+ "/with-oauth2-scheme": {
+ "get": {
+ "summary": "Read With Oauth2 Scheme",
+ "operationId": "read_with_oauth2_scheme_with_oauth2_scheme_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ }
+ },
+ "security": [
+ {"OAuth2AuthorizationCodeBearer": ["read", "write"]}
+ ],
+ }
+ },
+ "/with-get-token": {
+ "get": {
+ "summary": "Read With Get Token",
+ "operationId": "read_with_get_token_with_get_token_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ }
+ },
+ "security": [
+ {"OAuth2AuthorizationCodeBearer": ["read", "write"]}
+ ],
+ }
+ },
"/items/": {
"get": {
"summary": "Read Items",
--- /dev/null
+# Ref: https://github.com/fastapi/fastapi/issues/14454
+
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from fastapi.testclient import TestClient
+from inline_snapshot import snapshot
+from typing_extensions import Annotated
+
+oauth2_scheme = OAuth2AuthorizationCodeBearer(
+ authorizationUrl="api/oauth/authorize",
+ tokenUrl="/api/oauth/token",
+ scopes={"read": "Read access", "write": "Write access"},
+)
+
+
+async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
+ return token
+
+
+app = FastAPI(dependencies=[Depends(get_token)])
+
+
+@app.get("/admin", dependencies=[Security(get_token, scopes=["read", "write"])])
+async def read_admin():
+ return {"message": "Admin Access"}
+
+
+client = TestClient(app)
+
+
+def test_read_admin():
+ response = client.get("/admin", headers={"Authorization": "Bearer faketoken"})
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "Admin Access"}
+
+
+def test_openapi_schema():
+ response = client.get("/openapi.json")
+ assert response.status_code == 200, response.text
+ assert response.json() == snapshot(
+ {
+ "openapi": "3.1.0",
+ "info": {"title": "FastAPI", "version": "0.1.0"},
+ "paths": {
+ "/admin": {
+ "get": {
+ "summary": "Read Admin",
+ "operationId": "read_admin_admin_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ }
+ },
+ "security": [
+ {"OAuth2AuthorizationCodeBearer": ["read", "write"]}
+ ],
+ }
+ }
+ },
+ "components": {
+ "securitySchemes": {
+ "OAuth2AuthorizationCodeBearer": {
+ "type": "oauth2",
+ "flows": {
+ "authorizationCode": {
+ "scopes": {
+ "read": "Read access",
+ "write": "Write access",
+ },
+ "authorizationUrl": "api/oauth/authorize",
+ "tokenUrl": "/api/oauth/token",
+ }
+ },
+ }
+ }
+ },
+ }
+ )