]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix OAuth2 scopes in OpenAPI in extra corner cases, parent dependency with scopes...
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 4 Dec 2025 22:22:01 +0000 (14:22 -0800)
committerGitHub <noreply@github.com>
Thu, 4 Dec 2025 22:22:01 +0000 (23:22 +0100)
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/openapi/utils.py
tests/test_security_oauth2_authorization_code_bearer_scopes_openapi.py
tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py [new file with mode: 0644]

index 9b545e4e5cb43f6d4645de57bd4930a0335b119b..af168a177a6f7fd4e7c5fd2b5c4e9614982daf2e 100644 (file)
@@ -2,7 +2,7 @@ import inspect
 import sys
 from dataclasses import dataclass, field
 from functools import cached_property, partial
-from typing import Any, Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Union
 
 from fastapi._compat import ModelField
 from fastapi.security.base import SecurityBase
@@ -28,12 +28,6 @@ def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
     return func
 
 
-@dataclass
-class SecurityRequirement:
-    security_scheme: SecurityBase
-    scopes: Optional[Sequence[str]] = None
-
-
 @dataclass
 class Dependant:
     path_params: List[ModelField] = field(default_factory=list)
@@ -42,7 +36,6 @@ class Dependant:
     cookie_params: List[ModelField] = field(default_factory=list)
     body_params: List[ModelField] = field(default_factory=list)
     dependencies: List["Dependant"] = field(default_factory=list)
-    security_requirements: List[SecurityRequirement] = field(default_factory=list)
     name: Optional[str] = None
     call: Optional[Callable[..., Any]] = None
     request_param_name: Optional[str] = None
@@ -83,11 +76,32 @@ class Dependant:
             return True
         if self.security_scopes_param_name is not None:
             return True
+        if self._is_security_scheme:
+            return True
         for sub_dep in self.dependencies:
             if sub_dep._uses_scopes:
                 return True
         return False
 
+    @cached_property
+    def _is_security_scheme(self) -> bool:
+        if self.call is None:
+            return False  # pragma: no cover
+        unwrapped = _unwrapped_call(self.call)
+        return isinstance(unwrapped, SecurityBase)
+
+    # Mainly to get the type of SecurityBase, but it's the same self.call
+    @cached_property
+    def _security_scheme(self) -> SecurityBase:
+        unwrapped = _unwrapped_call(self.call)
+        assert isinstance(unwrapped, SecurityBase)
+        return unwrapped
+
+    @cached_property
+    def _security_dependencies(self) -> List["Dependant"]:
+        security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
+        return security_deps
+
     @cached_property
     def is_gen_callable(self) -> bool:
         if self.call is None:
index 1ff35f64831dd7585ddf197fcb32d8107fd3f6ad..23bca6f2a1a5e6698ade064cc2567ed2b06e976d 100644 (file)
@@ -55,10 +55,9 @@ from fastapi.concurrency import (
     asynccontextmanager,
     contextmanager_in_threadpool,
 )
-from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.dependencies.models import Dependant
 from fastapi.exceptions import DependencyScopeError
 from fastapi.logger import logger
-from fastapi.security.base import SecurityBase
 from fastapi.security.oauth2 import SecurityScopes
 from fastapi.types import DependencyCacheKey
 from fastapi.utils import create_model_field, get_path_param_names
@@ -142,10 +141,14 @@ def get_flat_dependant(
     *,
     skip_repeats: bool = False,
     visited: Optional[List[DependencyCacheKey]] = None,
+    parent_oauth_scopes: Optional[List[str]] = None,
 ) -> Dependant:
     if visited is None:
         visited = []
     visited.append(dependant.cache_key)
+    use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
+        dependant.oauth_scopes or []
+    )
 
     flat_dependant = Dependant(
         path_params=dependant.path_params.copy(),
@@ -153,22 +156,37 @@ def get_flat_dependant(
         header_params=dependant.header_params.copy(),
         cookie_params=dependant.cookie_params.copy(),
         body_params=dependant.body_params.copy(),
-        security_requirements=dependant.security_requirements.copy(),
+        name=dependant.name,
+        call=dependant.call,
+        request_param_name=dependant.request_param_name,
+        websocket_param_name=dependant.websocket_param_name,
+        http_connection_param_name=dependant.http_connection_param_name,
+        response_param_name=dependant.response_param_name,
+        background_tasks_param_name=dependant.background_tasks_param_name,
+        security_scopes_param_name=dependant.security_scopes_param_name,
+        own_oauth_scopes=dependant.own_oauth_scopes,
+        parent_oauth_scopes=use_parent_oauth_scopes,
         use_cache=dependant.use_cache,
         path=dependant.path,
+        scope=dependant.scope,
     )
     for sub_dependant in dependant.dependencies:
         if skip_repeats and sub_dependant.cache_key in visited:
             continue
         flat_sub = get_flat_dependant(
-            sub_dependant, skip_repeats=skip_repeats, visited=visited
+            sub_dependant,
+            skip_repeats=skip_repeats,
+            visited=visited,
+            parent_oauth_scopes=flat_dependant.oauth_scopes,
         )
+        flat_dependant.dependencies.append(flat_sub)
         flat_dependant.path_params.extend(flat_sub.path_params)
         flat_dependant.query_params.extend(flat_sub.query_params)
         flat_dependant.header_params.extend(flat_sub.header_params)
         flat_dependant.cookie_params.extend(flat_sub.cookie_params)
         flat_dependant.body_params.extend(flat_sub.body_params)
-        flat_dependant.security_requirements.extend(flat_sub.security_requirements)
+        flat_dependant.dependencies.extend(flat_sub.dependencies)
+
     return flat_dependant
 
 
@@ -258,11 +276,6 @@ def get_dependant(
     path_param_names = get_path_param_names(path)
     endpoint_signature = get_typed_signature(call)
     signature_params = endpoint_signature.parameters
-    if isinstance(call, SecurityBase):
-        security_requirement = SecurityRequirement(
-            security_scheme=call, scopes=current_scopes
-        )
-        dependant.security_requirements.append(security_requirement)
     for param_name, param in signature_params.items():
         is_path_param = param_name in path_param_names
         param_details = analyze_param(
index e7e6da2f764572e5d9e8e3b605ba03b18cafb7b5..06c14861a338cfef3a67e292e9ccd628ab03efaf 100644 (file)
@@ -81,18 +81,18 @@ def get_openapi_security_definitions(
     security_definitions = {}
     # Use a dict to merge scopes for same security scheme
     operation_security_dict: Dict[str, List[str]] = {}
-    for security_requirement in flat_dependant.security_requirements:
+    for security_dependency in flat_dependant._security_dependencies:
         security_definition = jsonable_encoder(
-            security_requirement.security_scheme.model,
+            security_dependency._security_scheme.model,
             by_alias=True,
             exclude_none=True,
         )
-        security_name = security_requirement.security_scheme.scheme_name
+        security_name = security_dependency._security_scheme.scheme_name
         security_definitions[security_name] = security_definition
         # Merge scopes for the same security scheme
         if security_name not in operation_security_dict:
             operation_security_dict[security_name] = []
-        for scope in security_requirement.scopes or []:
+        for scope in security_dependency.oauth_scopes or []:
             if scope not in operation_security_dict[security_name]:
                 operation_security_dict[security_name].append(scope)
     operation_security = [
index 644df8de6cc8ef1f502460a9c0526072ddade16c..d41f1dc1f038cb10a1c98fca222decd076f45301 100644 (file)
@@ -2,10 +2,11 @@
 
 from typing import Optional
 
-from fastapi import APIRouter, FastAPI, Security
+from fastapi import APIRouter, Depends, FastAPI, Security
 from fastapi.security import OAuth2AuthorizationCodeBearer
 from fastapi.testclient import TestClient
 from inline_snapshot import snapshot
+from typing_extensions import Annotated
 
 oauth2_scheme = OAuth2AuthorizationCodeBearer(
     authorizationUrl="authorize",
@@ -14,7 +15,12 @@ oauth2_scheme = OAuth2AuthorizationCodeBearer(
     scopes={"read": "Read access", "write": "Write access"},
 )
 
-app = FastAPI(dependencies=[Security(oauth2_scheme)])
+
+async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
+    return token
+
+
+app = FastAPI(dependencies=[Depends(get_token)])
 
 
 @app.get("/")
@@ -22,11 +28,26 @@ async def root():
     return {"message": "Hello World"}
 
 
+@app.get(
+    "/with-oauth2-scheme",
+    dependencies=[Security(oauth2_scheme, scopes=["read", "write"])],
+)
+async def read_with_oauth2_scheme():
+    return {"message": "Admin Access"}
+
+
+@app.get(
+    "/with-get-token", dependencies=[Security(get_token, scopes=["read", "write"])]
+)
+async def read_with_get_token():
+    return {"message": "Admin Access"}
+
+
 router = APIRouter(dependencies=[Security(oauth2_scheme, scopes=["read"])])
 
 
 @router.get("/items/")
-async def read_items(token: Optional[str] = Security(oauth2_scheme)):
+async def read_items(token: Optional[str] = Depends(oauth2_scheme)):
     return {"token": token}
 
 
@@ -48,6 +69,22 @@ def test_root():
     assert response.json() == {"message": "Hello World"}
 
 
+def test_read_with_oauth2_scheme():
+    response = client.get(
+        "/with-oauth2-scheme", headers={"Authorization": "Bearer testtoken"}
+    )
+    assert response.status_code == 200, response.text
+    assert response.json() == {"message": "Admin Access"}
+
+
+def test_read_with_get_token():
+    response = client.get(
+        "/with-get-token", headers={"Authorization": "Bearer testtoken"}
+    )
+    assert response.status_code == 200, response.text
+    assert response.json() == {"message": "Admin Access"}
+
+
 def test_read_token():
     response = client.get("/items/", headers={"Authorization": "Bearer testtoken"})
     assert response.status_code == 200, response.text
@@ -81,6 +118,36 @@ def test_openapi_schema():
                         "security": [{"OAuth2AuthorizationCodeBearer": []}],
                     }
                 },
+                "/with-oauth2-scheme": {
+                    "get": {
+                        "summary": "Read With Oauth2 Scheme",
+                        "operationId": "read_with_oauth2_scheme_with_oauth2_scheme_get",
+                        "responses": {
+                            "200": {
+                                "description": "Successful Response",
+                                "content": {"application/json": {"schema": {}}},
+                            }
+                        },
+                        "security": [
+                            {"OAuth2AuthorizationCodeBearer": ["read", "write"]}
+                        ],
+                    }
+                },
+                "/with-get-token": {
+                    "get": {
+                        "summary": "Read With Get Token",
+                        "operationId": "read_with_get_token_with_get_token_get",
+                        "responses": {
+                            "200": {
+                                "description": "Successful Response",
+                                "content": {"application/json": {"schema": {}}},
+                            }
+                        },
+                        "security": [
+                            {"OAuth2AuthorizationCodeBearer": ["read", "write"]}
+                        ],
+                    }
+                },
                 "/items/": {
                     "get": {
                         "summary": "Read Items",
diff --git a/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py b/tests/test_security_oauth2_authorization_code_bearer_scopes_openapi_simple.py
new file mode 100644 (file)
index 0000000..ff866d4
--- /dev/null
@@ -0,0 +1,79 @@
+# Ref: https://github.com/fastapi/fastapi/issues/14454
+
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from fastapi.testclient import TestClient
+from inline_snapshot import snapshot
+from typing_extensions import Annotated
+
+oauth2_scheme = OAuth2AuthorizationCodeBearer(
+    authorizationUrl="api/oauth/authorize",
+    tokenUrl="/api/oauth/token",
+    scopes={"read": "Read access", "write": "Write access"},
+)
+
+
+async def get_token(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
+    return token
+
+
+app = FastAPI(dependencies=[Depends(get_token)])
+
+
+@app.get("/admin", dependencies=[Security(get_token, scopes=["read", "write"])])
+async def read_admin():
+    return {"message": "Admin Access"}
+
+
+client = TestClient(app)
+
+
+def test_read_admin():
+    response = client.get("/admin", headers={"Authorization": "Bearer faketoken"})
+    assert response.status_code == 200, response.text
+    assert response.json() == {"message": "Admin Access"}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200, response.text
+    assert response.json() == snapshot(
+        {
+            "openapi": "3.1.0",
+            "info": {"title": "FastAPI", "version": "0.1.0"},
+            "paths": {
+                "/admin": {
+                    "get": {
+                        "summary": "Read Admin",
+                        "operationId": "read_admin_admin_get",
+                        "responses": {
+                            "200": {
+                                "description": "Successful Response",
+                                "content": {"application/json": {"schema": {}}},
+                            }
+                        },
+                        "security": [
+                            {"OAuth2AuthorizationCodeBearer": ["read", "write"]}
+                        ],
+                    }
+                }
+            },
+            "components": {
+                "securitySchemes": {
+                    "OAuth2AuthorizationCodeBearer": {
+                        "type": "oauth2",
+                        "flows": {
+                            "authorizationCode": {
+                                "scopes": {
+                                    "read": "Read access",
+                                    "write": "Write access",
+                                },
+                                "authorizationUrl": "api/oauth/authorize",
+                                "tokenUrl": "/api/oauth/token",
+                            }
+                        },
+                    }
+                }
+            },
+        }
+    )