From ac438b99342c859ae0e10f7064021125bd247bf5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 3 Nov 2025 11:12:49 +0100 Subject: [PATCH] =?utf8?q?=E2=9C=A8=20Add=20support=20for=20dependencies?= =?utf8?q?=20with=20scopes,=20support=20`scope=3D"request"`=20for=20depend?= =?utf8?q?encies=20with=20`yield`=20that=20exit=20before=20the=20response?= =?utf8?q?=20is=20sent=20(#14262)?= MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../dependencies/dependencies-with-yield.md | 45 ++++ docs_src/dependencies/tutorial008e.py | 15 ++ docs_src/dependencies/tutorial008e_an.py | 16 ++ docs_src/dependencies/tutorial008e_an_py39.py | 17 ++ fastapi/dependencies/models.py | 54 ++++- fastapi/dependencies/utils.py | 100 +++++---- fastapi/exceptions.py | 7 + fastapi/param_functions.py | 24 ++- fastapi/params.py | 3 +- fastapi/routing.py | 27 ++- fastapi/types.py | 3 +- tests/test_dependency_yield_scope.py | 184 ++++++++++++++++ .../test_dependency_yield_scope_websockets.py | 201 ++++++++++++++++++ .../test_dependencies/test_tutorial008e.py | 27 +++ 14 files changed, 653 insertions(+), 70 deletions(-) create mode 100644 docs_src/dependencies/tutorial008e.py create mode 100644 docs_src/dependencies/tutorial008e_an.py create mode 100644 docs_src/dependencies/tutorial008e_an_py39.py create mode 100644 tests/test_dependency_yield_scope.py create mode 100644 tests/test_dependency_yield_scope_websockets.py create mode 100644 tests/test_tutorial/test_dependencies/test_tutorial008e.py diff --git a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md index adc1afa8d..494c40efa 100644 --- a/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md +++ b/docs/en/docs/tutorial/dependencies/dependencies-with-yield.md @@ -184,6 +184,51 @@ If you raise any exception in the code from the *path operation function*, it wi /// +## Early exit and `scope` { #early-exit-and-scope } + +Normally the exit code of dependencies with `yield` is executed **after the response** is sent to the client. + +But if you know that you won't need to use the dependency after returning from the *path operation function*, you can use `Depends(scope="function")` to tell FastAPI that it should close the dependency after the *path operation function* returns, but **before** the **response is sent**. + +{* ../../docs_src/dependencies/tutorial008e_an_py39.py hl[12,16] *} + +`Depends()` receives a `scope` parameter that can be: + +* `"function"`: start the dependency before the *path operation function* that handles the request, end the dependency after the *path operation function* ends, but **before** the response is sent back to the client. So, the dependency function will be executed **around** the *path operation **function***. +* `"request"`: start the dependency before the *path operation function* that handles the request (similar to when using `"function"`), but end **after** the response is sent back to the client. So, the dependency function will be executed **around** the **request** and response cycle. + +If not specified and the dependency has `yield`, it will have a `scope` of `"request"` by default. + +### `scope` for sub-dependencies { #scope-for-sub-dependencies } + +When you declare a dependency with a `scope="request"` (the default), any sub-dependency needs to also have a `scope` of `"request"`. + +But a dependency with `scope` of `"function"` can have dependencies with `scope` of `"function"` and `scope` of `"request"`. + +This is because any dependency needs to be able to run its exit code before the sub-dependencies, as it might need to still use them during its exit code. + +```mermaid +sequenceDiagram + +participant client as Client +participant dep_req as Dep scope="request" +participant dep_func as Dep scope="function" +participant operation as Path Operation + + client ->> dep_req: Start request + Note over dep_req: Run code up to yield + dep_req ->> dep_func: Pass dependency + Note over dep_func: Run code up to yield + dep_func ->> operation: Run path operation with dependency + operation ->> dep_func: Return from path operation + Note over dep_func: Run code after yield + Note over dep_func: ✅ Dependency closed + dep_func ->> client: Send response to client + Note over client: Response sent + Note over dep_req: Run code after yield + Note over dep_req: ✅ Dependency closed +``` + ## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks } Dependencies with `yield` have evolved over time to cover different use cases and fix some issues. diff --git a/docs_src/dependencies/tutorial008e.py b/docs_src/dependencies/tutorial008e.py new file mode 100644 index 000000000..1ed056e91 --- /dev/null +++ b/docs_src/dependencies/tutorial008e.py @@ -0,0 +1,15 @@ +from fastapi import Depends, FastAPI + +app = FastAPI() + + +def get_username(): + try: + yield "Rick" + finally: + print("Cleanup up before response is sent") + + +@app.get("/users/me") +def get_user_me(username: str = Depends(get_username, scope="function")): + return username diff --git a/docs_src/dependencies/tutorial008e_an.py b/docs_src/dependencies/tutorial008e_an.py new file mode 100644 index 000000000..c8a0af2b3 --- /dev/null +++ b/docs_src/dependencies/tutorial008e_an.py @@ -0,0 +1,16 @@ +from fastapi import Depends, FastAPI +from typing_extensions import Annotated + +app = FastAPI() + + +def get_username(): + try: + yield "Rick" + finally: + print("Cleanup up before response is sent") + + +@app.get("/users/me") +def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]): + return username diff --git a/docs_src/dependencies/tutorial008e_an_py39.py b/docs_src/dependencies/tutorial008e_an_py39.py new file mode 100644 index 000000000..80a44c7e2 --- /dev/null +++ b/docs_src/dependencies/tutorial008e_an_py39.py @@ -0,0 +1,17 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI + +app = FastAPI() + + +def get_username(): + try: + yield "Rick" + finally: + print("Cleanup up before response is sent") + + +@app.get("/users/me") +def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]): + return username diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py index 418c11725..d6359c0f5 100644 --- a/fastapi/dependencies/models.py +++ b/fastapi/dependencies/models.py @@ -1,8 +1,18 @@ +import inspect +import sys from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple +from functools import cached_property +from typing import Any, Callable, List, Optional, Sequence, Union from fastapi._compat import ModelField from fastapi.security.base import SecurityBase +from fastapi.types import DependencyCacheKey +from typing_extensions import Literal + +if sys.version_info >= (3, 13): # pragma: no cover + from inspect import iscoroutinefunction +else: # pragma: no cover + from asyncio import iscoroutinefunction @dataclass @@ -31,7 +41,43 @@ class Dependant: security_scopes: Optional[List[str]] = None use_cache: bool = True path: Optional[str] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) + scope: Union[Literal["function", "request"], None] = None + + @cached_property + def cache_key(self) -> DependencyCacheKey: + return ( + self.call, + tuple(sorted(set(self.security_scopes or []))), + self.computed_scope or "", + ) + + @cached_property + def is_gen_callable(self) -> bool: + if inspect.isgeneratorfunction(self.call): + return True + dunder_call = getattr(self.call, "__call__", None) # noqa: B004 + return inspect.isgeneratorfunction(dunder_call) + + @cached_property + def is_async_gen_callable(self) -> bool: + if inspect.isasyncgenfunction(self.call): + return True + dunder_call = getattr(self.call, "__call__", None) # noqa: B004 + return inspect.isasyncgenfunction(dunder_call) + + @cached_property + def is_coroutine_callable(self) -> bool: + if inspect.isroutine(self.call): + return iscoroutinefunction(self.call) + if inspect.isclass(self.call): + return False + dunder_call = getattr(self.call, "__call__", None) # noqa: B004 + return iscoroutinefunction(dunder_call) - def __post_init__(self) -> None: - self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) + @cached_property + def computed_scope(self) -> Union[str, None]: + if self.scope: + return self.scope + if self.is_gen_callable or self.is_async_gen_callable: + return "request" + return None diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 6477a2cba..c5c6b69bb 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,5 +1,4 @@ import inspect -import sys from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -55,10 +54,12 @@ from fastapi.concurrency import ( contextmanager_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.exceptions import DependencyScopeError from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect +from fastapi.types import DependencyCacheKey from fastapi.utils import create_model_field, get_path_param_names from pydantic import BaseModel from pydantic.fields import FieldInfo @@ -74,15 +75,10 @@ from starlette.datastructures import ( from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, Literal, get_args, get_origin from .. import temp_pydantic_v1_params -if sys.version_info >= (3, 13): # pragma: no cover - from inspect import iscoroutinefunction -else: # pragma: no cover - from asyncio import iscoroutinefunction - multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' 'You can install "python-multipart" with: \n\n' @@ -137,14 +133,11 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De ) -CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] - - def get_flat_dependant( dependant: Dependant, *, skip_repeats: bool = False, - visited: Optional[List[CacheKey]] = None, + visited: Optional[List[DependencyCacheKey]] = None, ) -> Dependant: if visited is None: visited = [] @@ -237,6 +230,7 @@ def get_dependant( name: Optional[str] = None, security_scopes: Optional[List[str]] = None, use_cache: bool = True, + scope: Union[Literal["function", "request"], None] = None, ) -> Dependant: dependant = Dependant( call=call, @@ -244,6 +238,7 @@ def get_dependant( path=path, security_scopes=security_scopes, use_cache=use_cache, + scope=scope, ) path_param_names = get_path_param_names(path) endpoint_signature = get_typed_signature(call) @@ -251,7 +246,7 @@ def get_dependant( if isinstance(call, SecurityBase): use_scopes: List[str] = [] if isinstance(call, (OAuth2, OpenIdConnect)): - use_scopes = security_scopes + use_scopes = security_scopes or use_scopes security_requirement = SecurityRequirement( security_scheme=call, scopes=use_scopes ) @@ -266,6 +261,16 @@ def get_dependant( ) if param_details.depends is not None: assert param_details.depends.dependency + if ( + (dependant.is_gen_callable or dependant.is_async_gen_callable) + and dependant.computed_scope == "request" + and param_details.depends.scope == "function" + ): + assert dependant.call + raise DependencyScopeError( + f'The dependency "{dependant.call.__name__}" has a scope of ' + '"request", it cannot depend on dependencies with scope "function".' + ) use_security_scopes = security_scopes or [] if isinstance(param_details.depends, params.Security): if param_details.depends.scopes: @@ -276,6 +281,7 @@ def get_dependant( name=param_name, security_scopes=use_security_scopes, use_cache=param_details.depends.use_cache, + scope=param_details.depends.scope, ) dependant.dependencies.append(sub_dependant) continue @@ -532,36 +538,14 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: dependant.cookie_params.append(field) -def is_coroutine_callable(call: Callable[..., Any]) -> bool: - if inspect.isroutine(call): - return iscoroutinefunction(call) - if inspect.isclass(call): - return False - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return iscoroutinefunction(dunder_call) - - -def is_async_gen_callable(call: Callable[..., Any]) -> bool: - if inspect.isasyncgenfunction(call): - return True - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.isasyncgenfunction(dunder_call) - - -def is_gen_callable(call: Callable[..., Any]) -> bool: - if inspect.isgeneratorfunction(call): - return True - dunder_call = getattr(call, "__call__", None) # noqa: B004 - return inspect.isgeneratorfunction(dunder_call) - - -async def solve_generator( - *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any] +async def _solve_generator( + *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any] ) -> Any: - if is_gen_callable(call): - cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values)) - elif is_async_gen_callable(call): - cm = asynccontextmanager(call)(**sub_values) + assert dependant.call + if dependant.is_gen_callable: + cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values)) + elif dependant.is_async_gen_callable: + cm = asynccontextmanager(dependant.call)(**sub_values) return await stack.enter_async_context(cm) @@ -571,7 +555,7 @@ class SolvedDependency: errors: List[Any] background_tasks: Optional[StarletteBackgroundTasks] response: Response - dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] + dependency_cache: Dict[DependencyCacheKey, Any] async def solve_dependencies( @@ -582,10 +566,20 @@ async def solve_dependencies( background_tasks: Optional[StarletteBackgroundTasks] = None, response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, + dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None, + # TODO: remove this parameter later, no longer used, not removing it yet as some + # people might be monkey patching this function (although that's not supported) async_exit_stack: AsyncExitStack, embed_body_fields: bool, ) -> SolvedDependency: + request_astack = request.scope.get("fastapi_inner_astack") + assert isinstance(request_astack, AsyncExitStack), ( + "fastapi_inner_astack not found in request scope" + ) + function_astack = request.scope.get("fastapi_function_astack") + assert isinstance(function_astack, AsyncExitStack), ( + "fastapi_function_astack not found in request scope" + ) values: Dict[str, Any] = {} errors: List[Any] = [] if response is None: @@ -594,12 +588,8 @@ async def solve_dependencies( response.status_code = None # type: ignore if dependency_cache is None: dependency_cache = {} - sub_dependant: Dependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) - sub_dependant.cache_key = cast( - Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key - ) call = sub_dependant.call use_sub_dependant = sub_dependant if ( @@ -616,6 +606,7 @@ async def solve_dependencies( call=call, name=sub_dependant.name, security_scopes=sub_dependant.security_scopes, + scope=sub_dependant.scope, ) solved_result = await solve_dependencies( @@ -635,11 +626,18 @@ async def solve_dependencies( continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): - solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + elif ( + use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable + ): + use_astack = request_astack + if sub_dependant.scope == "function": + use_astack = function_astack + solved = await _solve_generator( + dependant=use_sub_dependant, + stack=use_astack, + sub_values=solved_result.values, ) - elif is_coroutine_callable(call): + elif use_sub_dependant.is_coroutine_callable: solved = await call(**solved_result.values) else: solved = await run_in_threadpool(call, **solved_result.values) diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index bb775fcbf..0620428be 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -147,6 +147,13 @@ class FastAPIError(RuntimeError): """ +class DependencyScopeError(FastAPIError): + """ + A dependency declared that it depends on another dependency with an invalid + (narrower) scope. + """ + + class ValidationException(Exception): def __init__(self, errors: Sequence[Any]) -> None: self._errors = errors diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index f88937e24..e32f75593 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -4,7 +4,7 @@ from annotated_doc import Doc from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example -from typing_extensions import Annotated, deprecated +from typing_extensions import Annotated, Literal, deprecated _Unset: Any = Undefined @@ -2245,6 +2245,26 @@ def Depends( # noqa: N802 """ ), ] = True, + scope: Annotated[ + Union[Literal["function", "request"], None], + Doc( + """ + Mainly for dependencies with `yield`, define when the dependency function + should start (the code before `yield`) and when it should end (the code + after `yield`). + + * `"function"`: start the dependency before the *path operation function* + that handles the request, end the dependency after the *path operation + function* ends, but **before** the response is sent back to the client. + So, the dependency function will be executed **around** the *path operation + **function***. + * `"request"`: start the dependency before the *path operation function* + that handles the request (similar to when using `"function"`), but end + **after** the response is sent back to the client. So, the dependency + function will be executed **around** the **request** and response cycle. + """ + ), + ] = None, ) -> Any: """ Declare a FastAPI dependency. @@ -2275,7 +2295,7 @@ def Depends( # noqa: N802 return commons ``` """ - return params.Depends(dependency=dependency, use_cache=use_cache) + return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope) def Security( # noqa: N802 diff --git a/fastapi/params.py b/fastapi/params.py index 2dc04be14..6a58d5808 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi.openapi.models import Example from pydantic.fields import FieldInfo -from typing_extensions import Annotated, deprecated +from typing_extensions import Annotated, Literal, deprecated from ._compat import ( PYDANTIC_V2, @@ -766,6 +766,7 @@ class File(Form): # type: ignore[misc] class Depends: dependency: Optional[Callable[..., Any]] = None use_cache: bool = True + scope: Union[Literal["function", "request"], None] = None @dataclass diff --git a/fastapi/routing.py b/fastapi/routing.py index 0b59d250a..a8e12eb60 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -104,10 +104,11 @@ def request_response( async def app(scope: Scope, receive: Receive, send: Send) -> None: # Starts customization response_awaited = False - async with AsyncExitStack() as stack: - scope["fastapi_inner_astack"] = stack - # Same as in Starlette - response = await f(request) + async with AsyncExitStack() as request_stack: + scope["fastapi_inner_astack"] = request_stack + async with AsyncExitStack() as function_stack: + scope["fastapi_function_astack"] = function_stack + response = await f(request) await response(scope, receive, send) # Continues customization response_awaited = True @@ -140,11 +141,11 @@ def websocket_session( session = WebSocket(scope, receive=receive, send=send) async def app(scope: Scope, receive: Receive, send: Send) -> None: - # Starts customization - async with AsyncExitStack() as stack: - scope["fastapi_inner_astack"] = stack - # Same as in Starlette - await func(session) + async with AsyncExitStack() as request_stack: + scope["fastapi_inner_astack"] = request_stack + async with AsyncExitStack() as function_stack: + scope["fastapi_function_astack"] = function_stack + await func(session) # Same as in Starlette await wrap_app_handling_exceptions(app, session)(scope, receive, send) @@ -479,7 +480,9 @@ class APIWebSocketRoute(routing.WebSocketRoute): self.name = get_name(endpoint) if name is None else name self.dependencies = list(dependencies or []) self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_dependant( + path=self.path_format, call=self.endpoint, scope="function" + ) for depends in self.dependencies[::-1]: self.dependant.dependencies.insert( 0, @@ -630,7 +633,9 @@ class APIRoute(routing.Route): self.response_fields = {} assert callable(endpoint), "An endpoint must be a callable" - self.dependant = get_dependant(path=self.path_format, call=self.endpoint) + self.dependant = get_dependant( + path=self.path_format, call=self.endpoint, scope="function" + ) for depends in self.dependencies[::-1]: self.dependant.dependencies.insert( 0, diff --git a/fastapi/types.py b/fastapi/types.py index 3205654c7..3f4e81a7c 100644 --- a/fastapi/types.py +++ b/fastapi/types.py @@ -1,6 +1,6 @@ import types from enum import Enum -from typing import Any, Callable, Dict, Set, Type, TypeVar, Union +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union from pydantic import BaseModel @@ -8,3 +8,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) UnionType = getattr(types, "UnionType", Union) ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str] IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] +DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str] diff --git a/tests/test_dependency_yield_scope.py b/tests/test_dependency_yield_scope.py new file mode 100644 index 000000000..a5227dd7a --- /dev/null +++ b/tests/test_dependency_yield_scope.py @@ -0,0 +1,184 @@ +import json +from typing import Any, Tuple + +import pytest +from fastapi import Depends, FastAPI +from fastapi.exceptions import FastAPIError +from fastapi.responses import StreamingResponse +from fastapi.testclient import TestClient +from typing_extensions import Annotated + + +class Session: + def __init__(self) -> None: + self.open = True + + +def dep_session() -> Any: + s = Session() + yield s + s.open = False + + +SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")] +SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")] +SessionDefaultDep = Annotated[Session, Depends(dep_session)] + + +class NamedSession: + def __init__(self, name: str = "default") -> None: + self.name = name + self.open = True + + +def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any: + assert session is session_b + named_session = NamedSession(name="named") + yield named_session, session_b + named_session.open = False + + +NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)] + + +def get_named_func_session(session: SessionFuncDep) -> Any: + named_session = NamedSession(name="named") + yield named_session, session + named_session.open = False + + +def get_named_regular_func_session(session: SessionFuncDep) -> Any: + named_session = NamedSession(name="named") + return named_session, session + + +BrokenSessionsDep = Annotated[ + Tuple[NamedSession, Session], Depends(get_named_func_session) +] +NamedSessionsFuncDep = Annotated[ + Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function") +] + +RegularSessionsDep = Annotated[ + Tuple[NamedSession, Session], Depends(get_named_regular_func_session) +] + +app = FastAPI() + + +@app.get("/function-scope") +def function_scope(session: SessionFuncDep) -> Any: + def iter_data(): + yield json.dumps({"is_open": session.open}) + + return StreamingResponse(iter_data()) + + +@app.get("/request-scope") +def request_scope(session: SessionRequestDep) -> Any: + def iter_data(): + yield json.dumps({"is_open": session.open}) + + return StreamingResponse(iter_data()) + + +@app.get("/two-scopes") +def get_stream_session( + function_session: SessionFuncDep, request_session: SessionRequestDep +) -> Any: + def iter_data(): + yield json.dumps( + {"func_is_open": function_session.open, "req_is_open": request_session.open} + ) + + return StreamingResponse(iter_data()) + + +@app.get("/sub") +def get_sub(sessions: NamedSessionsDep) -> Any: + def iter_data(): + yield json.dumps( + {"named_session_open": sessions[0].open, "session_open": sessions[1].open} + ) + + return StreamingResponse(iter_data()) + + +@app.get("/named-function-scope") +def get_named_function_scope(sessions: NamedSessionsFuncDep) -> Any: + def iter_data(): + yield json.dumps( + {"named_session_open": sessions[0].open, "session_open": sessions[1].open} + ) + + return StreamingResponse(iter_data()) + + +@app.get("/regular-function-scope") +def get_regular_function_scope(sessions: RegularSessionsDep) -> Any: + def iter_data(): + yield json.dumps( + {"named_session_open": sessions[0].open, "session_open": sessions[1].open} + ) + + return StreamingResponse(iter_data()) + + +client = TestClient(app) + + +def test_function_scope() -> None: + response = client.get("/function-scope") + assert response.status_code == 200 + data = response.json() + assert data["is_open"] is False + + +def test_request_scope() -> None: + response = client.get("/request-scope") + assert response.status_code == 200 + data = response.json() + assert data["is_open"] is True + + +def test_two_scopes() -> None: + response = client.get("/two-scopes") + assert response.status_code == 200 + data = response.json() + assert data["func_is_open"] is False + assert data["req_is_open"] is True + + +def test_sub() -> None: + response = client.get("/sub") + assert response.status_code == 200 + data = response.json() + assert data["named_session_open"] is True + assert data["session_open"] is True + + +def test_broken_scope() -> None: + with pytest.raises( + FastAPIError, + match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"', + ): + + @app.get("/broken-scope") + def get_broken(sessions: BrokenSessionsDep) -> Any: # pragma: no cover + pass + + +def test_named_function_scope() -> None: + response = client.get("/named-function-scope") + assert response.status_code == 200 + data = response.json() + assert data["named_session_open"] is False + assert data["session_open"] is False + + +def test_regular_function_scope() -> None: + response = client.get("/regular-function-scope") + assert response.status_code == 200 + data = response.json() + assert data["named_session_open"] is True + assert data["session_open"] is False diff --git a/tests/test_dependency_yield_scope_websockets.py b/tests/test_dependency_yield_scope_websockets.py new file mode 100644 index 000000000..52a30ae7a --- /dev/null +++ b/tests/test_dependency_yield_scope_websockets.py @@ -0,0 +1,201 @@ +from contextvars import ContextVar +from typing import Any, Dict, Tuple + +import pytest +from fastapi import Depends, FastAPI, WebSocket +from fastapi.exceptions import FastAPIError +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={}) # noqa: B039 + + +class Session: + def __init__(self) -> None: + self.open = True + + +async def dep_session() -> Any: + s = Session() + yield s + s.open = False + global_state = global_context.get() + global_state["session_closed"] = True + + +SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")] +SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")] +SessionDefaultDep = Annotated[Session, Depends(dep_session)] + + +class NamedSession: + def __init__(self, name: str = "default") -> None: + self.name = name + self.open = True + + +def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any: + assert session is session_b + named_session = NamedSession(name="named") + yield named_session, session_b + named_session.open = False + global_state = global_context.get() + global_state["named_session_closed"] = True + + +NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)] + + +def get_named_func_session(session: SessionFuncDep) -> Any: + named_session = NamedSession(name="named") + yield named_session, session + named_session.open = False + global_state = global_context.get() + global_state["named_func_session_closed"] = True + + +def get_named_regular_func_session(session: SessionFuncDep) -> Any: + named_session = NamedSession(name="named") + return named_session, session + + +BrokenSessionsDep = Annotated[ + Tuple[NamedSession, Session], Depends(get_named_func_session) +] +NamedSessionsFuncDep = Annotated[ + Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function") +] + +RegularSessionsDep = Annotated[ + Tuple[NamedSession, Session], Depends(get_named_regular_func_session) +] + +app = FastAPI() + + +@app.websocket("/function-scope") +async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any: + await websocket.accept() + await websocket.send_json({"is_open": session.open}) + + +@app.websocket("/request-scope") +async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any: + await websocket.accept() + await websocket.send_json({"is_open": session.open}) + + +@app.websocket("/two-scopes") +async def get_stream_session( + websocket: WebSocket, + function_session: SessionFuncDep, + request_session: SessionRequestDep, +) -> Any: + await websocket.accept() + await websocket.send_json( + {"func_is_open": function_session.open, "req_is_open": request_session.open} + ) + + +@app.websocket("/sub") +async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any: + await websocket.accept() + await websocket.send_json( + {"named_session_open": sessions[0].open, "session_open": sessions[1].open} + ) + + +@app.websocket("/named-function-scope") +async def get_named_function_scope( + websocket: WebSocket, sessions: NamedSessionsFuncDep +) -> Any: + await websocket.accept() + await websocket.send_json( + {"named_session_open": sessions[0].open, "session_open": sessions[1].open} + ) + + +@app.websocket("/regular-function-scope") +async def get_regular_function_scope( + websocket: WebSocket, sessions: RegularSessionsDep +) -> Any: + await websocket.accept() + await websocket.send_json( + {"named_session_open": sessions[0].open, "session_open": sessions[1].open} + ) + + +client = TestClient(app) + + +def test_function_scope() -> None: + global_context.set({}) + global_state = global_context.get() + with client.websocket_connect("/function-scope") as websocket: + data = websocket.receive_json() + assert data["is_open"] is True + assert global_state["session_closed"] is True + + +def test_request_scope() -> None: + global_context.set({}) + global_state = global_context.get() + with client.websocket_connect("/request-scope") as websocket: + data = websocket.receive_json() + assert data["is_open"] is True + assert global_state["session_closed"] is True + + +def test_two_scopes() -> None: + global_context.set({}) + global_state = global_context.get() + with client.websocket_connect("/two-scopes") as websocket: + data = websocket.receive_json() + assert data["func_is_open"] is True + assert data["req_is_open"] is True + assert global_state["session_closed"] is True + + +def test_sub() -> None: + global_context.set({}) + global_state = global_context.get() + with client.websocket_connect("/sub") as websocket: + data = websocket.receive_json() + assert data["named_session_open"] is True + assert data["session_open"] is True + assert global_state["session_closed"] is True + assert global_state["named_session_closed"] is True + + +def test_broken_scope() -> None: + with pytest.raises( + FastAPIError, + match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"', + ): + + @app.websocket("/broken-scope") + async def get_broken( + websocket: WebSocket, sessions: BrokenSessionsDep + ) -> Any: # pragma: no cover + pass + + +def test_named_function_scope() -> None: + global_context.set({}) + global_state = global_context.get() + with client.websocket_connect("/named-function-scope") as websocket: + data = websocket.receive_json() + assert data["named_session_open"] is True + assert data["session_open"] is True + assert global_state["session_closed"] is True + assert global_state["named_func_session_closed"] is True + + +def test_regular_function_scope() -> None: + global_context.set({}) + global_state = global_context.get() + with client.websocket_connect("/regular-function-scope") as websocket: + data = websocket.receive_json() + assert data["named_session_open"] is True + assert data["session_open"] is True + assert global_state["session_closed"] is True diff --git a/tests/test_tutorial/test_dependencies/test_tutorial008e.py b/tests/test_tutorial/test_dependencies/test_tutorial008e.py new file mode 100644 index 000000000..1ae9ab2cd --- /dev/null +++ b/tests/test_tutorial/test_dependencies/test_tutorial008e.py @@ -0,0 +1,27 @@ +import importlib + +import pytest +from fastapi.testclient import TestClient + +from ...utils import needs_py39 + + +@pytest.fixture( + name="client", + params=[ + "tutorial008e", + "tutorial008e_an", + pytest.param("tutorial008e_an_py39", marks=needs_py39), + ], +) +def get_client(request: pytest.FixtureRequest): + mod = importlib.import_module(f"docs_src.dependencies.{request.param}") + + client = TestClient(mod.app) + return client + + +def test_get_users_me(client: TestClient): + response = client.get("/users/me") + assert response.status_code == 200, response.text + assert response.json() == "Rick" -- 2.47.3