]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Cache dependencies that don't use scopes and don't have sub-dependencies with scope...
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 30 Nov 2025 14:45:49 +0000 (06:45 -0800)
committerGitHub <noreply@github.com>
Sun, 30 Nov 2025 14:45:49 +0000 (15:45 +0100)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
tests/test_security_scopes.py [new file with mode: 0644]
tests/test_security_scopes_sub_dependency.py [new file with mode: 0644]

index d6359c0f51aa2955f677d14a3c1625572e1d38ba..fbb666a7daea25ab7689d26dc9ae3c6234a02855 100644 (file)
@@ -38,19 +38,43 @@ class Dependant:
     response_param_name: Optional[str] = None
     background_tasks_param_name: Optional[str] = None
     security_scopes_param_name: Optional[str] = None
-    security_scopes: Optional[List[str]] = None
+    own_oauth_scopes: Optional[List[str]] = None
+    parent_oauth_scopes: Optional[List[str]] = None
     use_cache: bool = True
     path: Optional[str] = None
     scope: Union[Literal["function", "request"], None] = None
 
+    @cached_property
+    def oauth_scopes(self) -> List[str]:
+        scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
+        # This doesn't use a set to preserve order, just in case
+        for scope in self.own_oauth_scopes or []:
+            if scope not in scopes:
+                scopes.append(scope)
+        return scopes
+
     @cached_property
     def cache_key(self) -> DependencyCacheKey:
+        scopes_for_cache = (
+            tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
+        )
         return (
             self.call,
-            tuple(sorted(set(self.security_scopes or []))),
+            scopes_for_cache,
             self.computed_scope or "",
         )
 
+    @cached_property
+    def _uses_scopes(self) -> bool:
+        if self.own_oauth_scopes:
+            return True
+        if self.security_scopes_param_name is not None:
+            return True
+        for sub_dep in self.dependencies:
+            if sub_dep._uses_scopes:
+                return True
+        return False
+
     @cached_property
     def is_gen_callable(self) -> bool:
         if inspect.isgeneratorfunction(self.call):
index 45353835b489b0f67e2d46cb19f16a476e5b8e1f..d43fa8a5163ec16b3c9beb4e8bca448cfb9a8674 100644 (file)
@@ -58,8 +58,7 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
 from fastapi.exceptions import DependencyScopeError
 from fastapi.logger import logger
 from fastapi.security.base import SecurityBase
-from fastapi.security.oauth2 import OAuth2, SecurityScopes
-from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.security.oauth2 import SecurityScopes
 from fastapi.types import DependencyCacheKey
 from fastapi.utils import create_model_field, get_path_param_names
 from pydantic import BaseModel
@@ -126,14 +125,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
     assert callable(depends.dependency), (
         "A parameter-less dependency must have a callable dependency"
     )
-    use_security_scopes: List[str] = []
+    own_oauth_scopes: List[str] = []
     if isinstance(depends, params.Security) and depends.scopes:
-        use_security_scopes.extend(depends.scopes)
+        own_oauth_scopes.extend(depends.scopes)
     return get_dependant(
         path=path,
         call=depends.dependency,
         scope=depends.scope,
-        security_scopes=use_security_scopes,
+        own_oauth_scopes=own_oauth_scopes,
     )
 
 
@@ -232,7 +231,8 @@ def get_dependant(
     path: str,
     call: Callable[..., Any],
     name: Optional[str] = None,
-    security_scopes: Optional[List[str]] = None,
+    own_oauth_scopes: Optional[List[str]] = None,
+    parent_oauth_scopes: Optional[List[str]] = None,
     use_cache: bool = True,
     scope: Union[Literal["function", "request"], None] = None,
 ) -> Dependant:
@@ -240,19 +240,18 @@ def get_dependant(
         call=call,
         name=name,
         path=path,
-        security_scopes=security_scopes,
         use_cache=use_cache,
         scope=scope,
+        own_oauth_scopes=own_oauth_scopes,
+        parent_oauth_scopes=parent_oauth_scopes,
     )
+    current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
     path_param_names = get_path_param_names(path)
     endpoint_signature = get_typed_signature(call)
     signature_params = endpoint_signature.parameters
     if isinstance(call, SecurityBase):
-        use_scopes: List[str] = []
-        if isinstance(call, (OAuth2, OpenIdConnect)):
-            use_scopes = security_scopes or use_scopes
         security_requirement = SecurityRequirement(
-            security_scheme=call, scopes=use_scopes
+            security_scheme=call, scopes=current_scopes
         )
         dependant.security_requirements.append(security_requirement)
     for param_name, param in signature_params.items():
@@ -275,17 +274,16 @@ def get_dependant(
                     f'The dependency "{dependant.call.__name__}" has a scope of '
                     '"request", it cannot depend on dependencies with scope "function".'
                 )
-            use_security_scopes = security_scopes or []
+            sub_own_oauth_scopes: List[str] = []
             if isinstance(param_details.depends, params.Security):
                 if param_details.depends.scopes:
-                    use_security_scopes = use_security_scopes + list(
-                        param_details.depends.scopes
-                    )
+                    sub_own_oauth_scopes = list(param_details.depends.scopes)
             sub_dependant = get_dependant(
                 path=path,
                 call=param_details.depends.dependency,
                 name=param_name,
-                security_scopes=use_security_scopes,
+                own_oauth_scopes=sub_own_oauth_scopes,
+                parent_oauth_scopes=current_scopes,
                 use_cache=param_details.depends.use_cache,
                 scope=param_details.depends.scope,
             )
@@ -611,7 +609,7 @@ async def solve_dependencies(
                 path=use_path,
                 call=call,
                 name=sub_dependant.name,
-                security_scopes=sub_dependant.security_scopes,
+                parent_oauth_scopes=sub_dependant.oauth_scopes,
                 scope=sub_dependant.scope,
             )
 
@@ -693,7 +691,7 @@ async def solve_dependencies(
         values[dependant.response_param_name] = response
     if dependant.security_scopes_param_name:
         values[dependant.security_scopes_param_name] = SecurityScopes(
-            scopes=dependant.security_scopes
+            scopes=dependant.oauth_scopes
         )
     return SolvedDependency(
         values=values,
diff --git a/tests/test_security_scopes.py b/tests/test_security_scopes.py
new file mode 100644 (file)
index 0000000..248fd2b
--- /dev/null
@@ -0,0 +1,46 @@
+from typing import Dict
+
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+@pytest.fixture(name="call_counter")
+def call_counter_fixture():
+    return {"count": 0}
+
+
+@pytest.fixture(name="app")
+def app_fixture(call_counter: Dict[str, int]):
+    def get_db():
+        call_counter["count"] += 1
+        return f"db_{call_counter['count']}"
+
+    def get_user(db: Annotated[str, Depends(get_db)]):
+        return "user"
+
+    app = FastAPI()
+
+    @app.get("/")
+    def endpoint(
+        db: Annotated[str, Depends(get_db)],
+        user: Annotated[str, Security(get_user, scopes=["read"])],
+    ):
+        return {"db": db}
+
+    return app
+
+
+@pytest.fixture(name="client")
+def client_fixture(app: FastAPI):
+    return TestClient(app)
+
+
+def test_security_scopes_dependency_called_once(
+    client: TestClient, call_counter: Dict[str, int]
+):
+    response = client.get("/")
+
+    assert response.status_code == 200
+    assert call_counter["count"] == 1
diff --git a/tests/test_security_scopes_sub_dependency.py b/tests/test_security_scopes_sub_dependency.py
new file mode 100644 (file)
index 0000000..9cc668d
--- /dev/null
@@ -0,0 +1,107 @@
+# Ref: https://github.com/fastapi/fastapi/discussions/6024#discussioncomment-8541913
+
+from typing import Dict
+
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import SecurityScopes
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+@pytest.fixture(name="call_counts")
+def call_counts_fixture():
+    return {
+        "get_db_session": 0,
+        "get_current_user": 0,
+        "get_user_me": 0,
+        "get_user_items": 0,
+    }
+
+
+@pytest.fixture(name="app")
+def app_fixture(call_counts: Dict[str, int]):
+    def get_db_session():
+        call_counts["get_db_session"] += 1
+        return f"db_session_{call_counts['get_db_session']}"
+
+    def get_current_user(
+        security_scopes: SecurityScopes,
+        db_session: Annotated[str, Depends(get_db_session)],
+    ):
+        call_counts["get_current_user"] += 1
+        return {
+            "user": f"user_{call_counts['get_current_user']}",
+            "scopes": security_scopes.scopes,
+            "db_session": db_session,
+        }
+
+    def get_user_me(
+        current_user: Annotated[dict, Security(get_current_user, scopes=["me"])],
+    ):
+        call_counts["get_user_me"] += 1
+        return {
+            "user_me": f"user_me_{call_counts['get_user_me']}",
+            "current_user": current_user,
+        }
+
+    def get_user_items(
+        user_me: Annotated[dict, Depends(get_user_me)],
+    ):
+        call_counts["get_user_items"] += 1
+        return {
+            "user_items": f"user_items_{call_counts['get_user_items']}",
+            "user_me": user_me,
+        }
+
+    app = FastAPI()
+
+    @app.get("/")
+    def path_operation(
+        user_me: Annotated[dict, Depends(get_user_me)],
+        user_items: Annotated[dict, Security(get_user_items, scopes=["items"])],
+    ):
+        return {
+            "user_me": user_me,
+            "user_items": user_items,
+        }
+
+    return app
+
+
+@pytest.fixture(name="client")
+def client_fixture(app: FastAPI):
+    return TestClient(app)
+
+
+def test_security_scopes_sub_dependency_caching(
+    client: TestClient, call_counts: Dict[str, int]
+):
+    response = client.get("/")
+
+    assert response.status_code == 200
+    assert call_counts["get_db_session"] == 1
+    assert call_counts["get_current_user"] == 2
+    assert call_counts["get_user_me"] == 2
+    assert call_counts["get_user_items"] == 1
+    assert response.json() == {
+        "user_me": {
+            "user_me": "user_me_1",
+            "current_user": {
+                "user": "user_1",
+                "scopes": ["me"],
+                "db_session": "db_session_1",
+            },
+        },
+        "user_items": {
+            "user_items": "user_items_1",
+            "user_me": {
+                "user_me": "user_me_2",
+                "current_user": {
+                    "user": "user_2",
+                    "scopes": ["items", "me"],
+                    "db_session": "db_session_1",
+                },
+            },
+        },
+    }