response_param_name: Optional[str] = None
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
- security_scopes: Optional[List[str]] = None
+ own_oauth_scopes: Optional[List[str]] = None
+ parent_oauth_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None
scope: Union[Literal["function", "request"], None] = None
+ @cached_property
+ def oauth_scopes(self) -> List[str]:
+ scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
+ # This doesn't use a set to preserve order, just in case
+ for scope in self.own_oauth_scopes or []:
+ if scope not in scopes:
+ scopes.append(scope)
+ return scopes
+
@cached_property
def cache_key(self) -> DependencyCacheKey:
+ scopes_for_cache = (
+ tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
+ )
return (
self.call,
- tuple(sorted(set(self.security_scopes or []))),
+ scopes_for_cache,
self.computed_scope or "",
)
+ @cached_property
+ def _uses_scopes(self) -> bool:
+ if self.own_oauth_scopes:
+ return True
+ if self.security_scopes_param_name is not None:
+ return True
+ for sub_dep in self.dependencies:
+ if sub_dep._uses_scopes:
+ return True
+ return False
+
@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self.call):
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
-from fastapi.security.oauth2 import OAuth2, SecurityScopes
-from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
- use_security_scopes: List[str] = []
+ own_oauth_scopes: List[str] = []
if isinstance(depends, params.Security) and depends.scopes:
- use_security_scopes.extend(depends.scopes)
+ own_oauth_scopes.extend(depends.scopes)
return get_dependant(
path=path,
call=depends.dependency,
scope=depends.scope,
- security_scopes=use_security_scopes,
+ own_oauth_scopes=own_oauth_scopes,
)
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
- security_scopes: Optional[List[str]] = None,
+ own_oauth_scopes: Optional[List[str]] = None,
+ parent_oauth_scopes: Optional[List[str]] = None,
use_cache: bool = True,
scope: Union[Literal["function", "request"], None] = None,
) -> Dependant:
call=call,
name=name,
path=path,
- security_scopes=security_scopes,
use_cache=use_cache,
scope=scope,
+ own_oauth_scopes=own_oauth_scopes,
+ parent_oauth_scopes=parent_oauth_scopes,
)
+ current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
if isinstance(call, SecurityBase):
- use_scopes: List[str] = []
- if isinstance(call, (OAuth2, OpenIdConnect)):
- use_scopes = security_scopes or use_scopes
security_requirement = SecurityRequirement(
- security_scheme=call, scopes=use_scopes
+ security_scheme=call, scopes=current_scopes
)
dependant.security_requirements.append(security_requirement)
for param_name, param in signature_params.items():
f'The dependency "{dependant.call.__name__}" has a scope of '
'"request", it cannot depend on dependencies with scope "function".'
)
- use_security_scopes = security_scopes or []
+ sub_own_oauth_scopes: List[str] = []
if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes:
- use_security_scopes = use_security_scopes + list(
- param_details.depends.scopes
- )
+ sub_own_oauth_scopes = list(param_details.depends.scopes)
sub_dependant = get_dependant(
path=path,
call=param_details.depends.dependency,
name=param_name,
- security_scopes=use_security_scopes,
+ own_oauth_scopes=sub_own_oauth_scopes,
+ parent_oauth_scopes=current_scopes,
use_cache=param_details.depends.use_cache,
scope=param_details.depends.scope,
)
path=use_path,
call=call,
name=sub_dependant.name,
- security_scopes=sub_dependant.security_scopes,
+ parent_oauth_scopes=sub_dependant.oauth_scopes,
scope=sub_dependant.scope,
)
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
- scopes=dependant.security_scopes
+ scopes=dependant.oauth_scopes
)
return SolvedDependency(
values=values,
--- /dev/null
+from typing import Dict
+
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+@pytest.fixture(name="call_counter")
+def call_counter_fixture():
+ return {"count": 0}
+
+
+@pytest.fixture(name="app")
+def app_fixture(call_counter: Dict[str, int]):
+ def get_db():
+ call_counter["count"] += 1
+ return f"db_{call_counter['count']}"
+
+ def get_user(db: Annotated[str, Depends(get_db)]):
+ return "user"
+
+ app = FastAPI()
+
+ @app.get("/")
+ def endpoint(
+ db: Annotated[str, Depends(get_db)],
+ user: Annotated[str, Security(get_user, scopes=["read"])],
+ ):
+ return {"db": db}
+
+ return app
+
+
+@pytest.fixture(name="client")
+def client_fixture(app: FastAPI):
+ return TestClient(app)
+
+
+def test_security_scopes_dependency_called_once(
+ client: TestClient, call_counter: Dict[str, int]
+):
+ response = client.get("/")
+
+ assert response.status_code == 200
+ assert call_counter["count"] == 1
--- /dev/null
+# Ref: https://github.com/fastapi/fastapi/discussions/6024#discussioncomment-8541913
+
+from typing import Dict
+
+import pytest
+from fastapi import Depends, FastAPI, Security
+from fastapi.security import SecurityScopes
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+@pytest.fixture(name="call_counts")
+def call_counts_fixture():
+ return {
+ "get_db_session": 0,
+ "get_current_user": 0,
+ "get_user_me": 0,
+ "get_user_items": 0,
+ }
+
+
+@pytest.fixture(name="app")
+def app_fixture(call_counts: Dict[str, int]):
+ def get_db_session():
+ call_counts["get_db_session"] += 1
+ return f"db_session_{call_counts['get_db_session']}"
+
+ def get_current_user(
+ security_scopes: SecurityScopes,
+ db_session: Annotated[str, Depends(get_db_session)],
+ ):
+ call_counts["get_current_user"] += 1
+ return {
+ "user": f"user_{call_counts['get_current_user']}",
+ "scopes": security_scopes.scopes,
+ "db_session": db_session,
+ }
+
+ def get_user_me(
+ current_user: Annotated[dict, Security(get_current_user, scopes=["me"])],
+ ):
+ call_counts["get_user_me"] += 1
+ return {
+ "user_me": f"user_me_{call_counts['get_user_me']}",
+ "current_user": current_user,
+ }
+
+ def get_user_items(
+ user_me: Annotated[dict, Depends(get_user_me)],
+ ):
+ call_counts["get_user_items"] += 1
+ return {
+ "user_items": f"user_items_{call_counts['get_user_items']}",
+ "user_me": user_me,
+ }
+
+ app = FastAPI()
+
+ @app.get("/")
+ def path_operation(
+ user_me: Annotated[dict, Depends(get_user_me)],
+ user_items: Annotated[dict, Security(get_user_items, scopes=["items"])],
+ ):
+ return {
+ "user_me": user_me,
+ "user_items": user_items,
+ }
+
+ return app
+
+
+@pytest.fixture(name="client")
+def client_fixture(app: FastAPI):
+ return TestClient(app)
+
+
+def test_security_scopes_sub_dependency_caching(
+ client: TestClient, call_counts: Dict[str, int]
+):
+ response = client.get("/")
+
+ assert response.status_code == 200
+ assert call_counts["get_db_session"] == 1
+ assert call_counts["get_current_user"] == 2
+ assert call_counts["get_user_me"] == 2
+ assert call_counts["get_user_items"] == 1
+ assert response.json() == {
+ "user_me": {
+ "user_me": "user_me_1",
+ "current_user": {
+ "user": "user_1",
+ "scopes": ["me"],
+ "db_session": "db_session_1",
+ },
+ },
+ "user_items": {
+ "user_items": "user_items_1",
+ "user_me": {
+ "user_me": "user_me_2",
+ "current_user": {
+ "user": "user_2",
+ "scopes": ["items", "me"],
+ "db_session": "db_session_1",
+ },
+ },
+ },
+ }