]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Add support for dependencies with scopes, support `scope="request"` for dependencie...
authorSebastián Ramírez <tiangolo@gmail.com>
Mon, 3 Nov 2025 10:12:49 +0000 (11:12 +0100)
committerGitHub <noreply@github.com>
Mon, 3 Nov 2025 10:12:49 +0000 (11:12 +0100)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
14 files changed:
docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
docs_src/dependencies/tutorial008e.py [new file with mode: 0644]
docs_src/dependencies/tutorial008e_an.py [new file with mode: 0644]
docs_src/dependencies/tutorial008e_an_py39.py [new file with mode: 0644]
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/exceptions.py
fastapi/param_functions.py
fastapi/params.py
fastapi/routing.py
fastapi/types.py
tests/test_dependency_yield_scope.py [new file with mode: 0644]
tests/test_dependency_yield_scope_websockets.py [new file with mode: 0644]
tests/test_tutorial/test_dependencies/test_tutorial008e.py [new file with mode: 0644]

index adc1afa8d9bbd89952e159fb0e53c549ef654472..494c40efab429194b433ca1ac9427bf4784f6748 100644 (file)
@@ -184,6 +184,51 @@ If you raise any exception in the code from the *path operation function*, it wi
 
 ///
 
+## Early exit and `scope` { #early-exit-and-scope }
+
+Normally the exit code of dependencies with `yield` is executed **after the response** is sent to the client.
+
+But if you know that you won't need to use the dependency after returning from the *path operation function*, you can use `Depends(scope="function")` to tell FastAPI that it should close the dependency after the *path operation function* returns, but **before** the **response is sent**.
+
+{* ../../docs_src/dependencies/tutorial008e_an_py39.py hl[12,16] *}
+
+`Depends()` receives a `scope` parameter that can be:
+
+* `"function"`: start the dependency before the *path operation function* that handles the request, end the dependency after the *path operation function* ends, but **before** the response is sent back to the client. So, the dependency function will be executed **around** the *path operation **function***.
+* `"request"`: start the dependency before the *path operation function* that handles the request (similar to when using `"function"`), but end **after** the response is sent back to the client. So, the dependency function will be executed **around** the **request** and response cycle.
+
+If not specified and the dependency has `yield`, it will have a `scope` of `"request"` by default.
+
+### `scope` for sub-dependencies { #scope-for-sub-dependencies }
+
+When you declare a dependency with a `scope="request"` (the default), any sub-dependency needs to also have a `scope` of `"request"`.
+
+But a dependency with `scope` of `"function"` can have dependencies with `scope` of `"function"` and `scope` of `"request"`.
+
+This is because any dependency needs to be able to run its exit code before the sub-dependencies, as it might need to still use them during its exit code.
+
+```mermaid
+sequenceDiagram
+
+participant client as Client
+participant dep_req as Dep scope="request"
+participant dep_func as Dep scope="function"
+participant operation as Path Operation
+
+    client ->> dep_req: Start request
+    Note over dep_req: Run code up to yield
+    dep_req ->> dep_func: Pass dependency
+    Note over dep_func: Run code up to yield
+    dep_func ->> operation: Run path operation with dependency
+    operation ->> dep_func: Return from path operation
+    Note over dep_func: Run code after yield
+    Note over dep_func: ✅ Dependency closed
+    dep_func ->> client: Send response to client
+    Note over client: Response sent
+    Note over dep_req: Run code after yield
+    Note over dep_req: ✅ Dependency closed
+```
+
 ## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks }
 
 Dependencies with `yield` have evolved over time to cover different use cases and fix some issues.
diff --git a/docs_src/dependencies/tutorial008e.py b/docs_src/dependencies/tutorial008e.py
new file mode 100644 (file)
index 0000000..1ed056e
--- /dev/null
@@ -0,0 +1,15 @@
+from fastapi import Depends, FastAPI
+
+app = FastAPI()
+
+
+def get_username():
+    try:
+        yield "Rick"
+    finally:
+        print("Cleanup up before response is sent")
+
+
+@app.get("/users/me")
+def get_user_me(username: str = Depends(get_username, scope="function")):
+    return username
diff --git a/docs_src/dependencies/tutorial008e_an.py b/docs_src/dependencies/tutorial008e_an.py
new file mode 100644 (file)
index 0000000..c8a0af2
--- /dev/null
@@ -0,0 +1,16 @@
+from fastapi import Depends, FastAPI
+from typing_extensions import Annotated
+
+app = FastAPI()
+
+
+def get_username():
+    try:
+        yield "Rick"
+    finally:
+        print("Cleanup up before response is sent")
+
+
+@app.get("/users/me")
+def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
+    return username
diff --git a/docs_src/dependencies/tutorial008e_an_py39.py b/docs_src/dependencies/tutorial008e_an_py39.py
new file mode 100644 (file)
index 0000000..80a44c7
--- /dev/null
@@ -0,0 +1,17 @@
+from typing import Annotated
+
+from fastapi import Depends, FastAPI
+
+app = FastAPI()
+
+
+def get_username():
+    try:
+        yield "Rick"
+    finally:
+        print("Cleanup up before response is sent")
+
+
+@app.get("/users/me")
+def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
+    return username
index 418c117259aa31de5c9d9bf6a4e5cfcbb5dab2d8..d6359c0f51aa2955f677d14a3c1625572e1d38ba 100644 (file)
@@ -1,8 +1,18 @@
+import inspect
+import sys
 from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Sequence, Tuple
+from functools import cached_property
+from typing import Any, Callable, List, Optional, Sequence, Union
 
 from fastapi._compat import ModelField
 from fastapi.security.base import SecurityBase
+from fastapi.types import DependencyCacheKey
+from typing_extensions import Literal
+
+if sys.version_info >= (3, 13):  # pragma: no cover
+    from inspect import iscoroutinefunction
+else:  # pragma: no cover
+    from asyncio import iscoroutinefunction
 
 
 @dataclass
@@ -31,7 +41,43 @@ class Dependant:
     security_scopes: Optional[List[str]] = None
     use_cache: bool = True
     path: Optional[str] = None
-    cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
+    scope: Union[Literal["function", "request"], None] = None
+
+    @cached_property
+    def cache_key(self) -> DependencyCacheKey:
+        return (
+            self.call,
+            tuple(sorted(set(self.security_scopes or []))),
+            self.computed_scope or "",
+        )
+
+    @cached_property
+    def is_gen_callable(self) -> bool:
+        if inspect.isgeneratorfunction(self.call):
+            return True
+        dunder_call = getattr(self.call, "__call__", None)  # noqa: B004
+        return inspect.isgeneratorfunction(dunder_call)
+
+    @cached_property
+    def is_async_gen_callable(self) -> bool:
+        if inspect.isasyncgenfunction(self.call):
+            return True
+        dunder_call = getattr(self.call, "__call__", None)  # noqa: B004
+        return inspect.isasyncgenfunction(dunder_call)
+
+    @cached_property
+    def is_coroutine_callable(self) -> bool:
+        if inspect.isroutine(self.call):
+            return iscoroutinefunction(self.call)
+        if inspect.isclass(self.call):
+            return False
+        dunder_call = getattr(self.call, "__call__", None)  # noqa: B004
+        return iscoroutinefunction(dunder_call)
 
-    def __post_init__(self) -> None:
-        self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
+    @cached_property
+    def computed_scope(self) -> Union[str, None]:
+        if self.scope:
+            return self.scope
+        if self.is_gen_callable or self.is_async_gen_callable:
+            return "request"
+        return None
index 6477a2cba68cdefed696bface342c69baf747c9a..c5c6b69bbcf644af8eb44869955e172e6e22ac78 100644 (file)
@@ -1,5 +1,4 @@
 import inspect
-import sys
 from contextlib import AsyncExitStack, contextmanager
 from copy import copy, deepcopy
 from dataclasses import dataclass
@@ -55,10 +54,12 @@ from fastapi.concurrency import (
     contextmanager_in_threadpool,
 )
 from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.exceptions import DependencyScopeError
 from fastapi.logger import logger
 from fastapi.security.base import SecurityBase
 from fastapi.security.oauth2 import OAuth2, SecurityScopes
 from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.types import DependencyCacheKey
 from fastapi.utils import create_model_field, get_path_param_names
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
@@ -74,15 +75,10 @@ from starlette.datastructures import (
 from starlette.requests import HTTPConnection, Request
 from starlette.responses import Response
 from starlette.websockets import WebSocket
-from typing_extensions import Annotated, get_args, get_origin
+from typing_extensions import Annotated, Literal, get_args, get_origin
 
 from .. import temp_pydantic_v1_params
 
-if sys.version_info >= (3, 13):  # pragma: no cover
-    from inspect import iscoroutinefunction
-else:  # pragma: no cover
-    from asyncio import iscoroutinefunction
-
 multipart_not_installed_error = (
     'Form data requires "python-multipart" to be installed. \n'
     'You can install "python-multipart" with: \n\n'
@@ -137,14 +133,11 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
     )
 
 
-CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
-
-
 def get_flat_dependant(
     dependant: Dependant,
     *,
     skip_repeats: bool = False,
-    visited: Optional[List[CacheKey]] = None,
+    visited: Optional[List[DependencyCacheKey]] = None,
 ) -> Dependant:
     if visited is None:
         visited = []
@@ -237,6 +230,7 @@ def get_dependant(
     name: Optional[str] = None,
     security_scopes: Optional[List[str]] = None,
     use_cache: bool = True,
+    scope: Union[Literal["function", "request"], None] = None,
 ) -> Dependant:
     dependant = Dependant(
         call=call,
@@ -244,6 +238,7 @@ def get_dependant(
         path=path,
         security_scopes=security_scopes,
         use_cache=use_cache,
+        scope=scope,
     )
     path_param_names = get_path_param_names(path)
     endpoint_signature = get_typed_signature(call)
@@ -251,7 +246,7 @@ def get_dependant(
     if isinstance(call, SecurityBase):
         use_scopes: List[str] = []
         if isinstance(call, (OAuth2, OpenIdConnect)):
-            use_scopes = security_scopes
+            use_scopes = security_scopes or use_scopes
         security_requirement = SecurityRequirement(
             security_scheme=call, scopes=use_scopes
         )
@@ -266,6 +261,16 @@ def get_dependant(
         )
         if param_details.depends is not None:
             assert param_details.depends.dependency
+            if (
+                (dependant.is_gen_callable or dependant.is_async_gen_callable)
+                and dependant.computed_scope == "request"
+                and param_details.depends.scope == "function"
+            ):
+                assert dependant.call
+                raise DependencyScopeError(
+                    f'The dependency "{dependant.call.__name__}" has a scope of '
+                    '"request", it cannot depend on dependencies with scope "function".'
+                )
             use_security_scopes = security_scopes or []
             if isinstance(param_details.depends, params.Security):
                 if param_details.depends.scopes:
@@ -276,6 +281,7 @@ def get_dependant(
                 name=param_name,
                 security_scopes=use_security_scopes,
                 use_cache=param_details.depends.use_cache,
+                scope=param_details.depends.scope,
             )
             dependant.dependencies.append(sub_dependant)
             continue
@@ -532,36 +538,14 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
         dependant.cookie_params.append(field)
 
 
-def is_coroutine_callable(call: Callable[..., Any]) -> bool:
-    if inspect.isroutine(call):
-        return iscoroutinefunction(call)
-    if inspect.isclass(call):
-        return False
-    dunder_call = getattr(call, "__call__", None)  # noqa: B004
-    return iscoroutinefunction(dunder_call)
-
-
-def is_async_gen_callable(call: Callable[..., Any]) -> bool:
-    if inspect.isasyncgenfunction(call):
-        return True
-    dunder_call = getattr(call, "__call__", None)  # noqa: B004
-    return inspect.isasyncgenfunction(dunder_call)
-
-
-def is_gen_callable(call: Callable[..., Any]) -> bool:
-    if inspect.isgeneratorfunction(call):
-        return True
-    dunder_call = getattr(call, "__call__", None)  # noqa: B004
-    return inspect.isgeneratorfunction(dunder_call)
-
-
-async def solve_generator(
-    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
+async def _solve_generator(
+    *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
 ) -> Any:
-    if is_gen_callable(call):
-        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
-    elif is_async_gen_callable(call):
-        cm = asynccontextmanager(call)(**sub_values)
+    assert dependant.call
+    if dependant.is_gen_callable:
+        cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
+    elif dependant.is_async_gen_callable:
+        cm = asynccontextmanager(dependant.call)(**sub_values)
     return await stack.enter_async_context(cm)
 
 
@@ -571,7 +555,7 @@ class SolvedDependency:
     errors: List[Any]
     background_tasks: Optional[StarletteBackgroundTasks]
     response: Response
-    dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
+    dependency_cache: Dict[DependencyCacheKey, Any]
 
 
 async def solve_dependencies(
@@ -582,10 +566,20 @@ async def solve_dependencies(
     background_tasks: Optional[StarletteBackgroundTasks] = None,
     response: Optional[Response] = None,
     dependency_overrides_provider: Optional[Any] = None,
-    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
+    dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
+    # TODO: remove this parameter later, no longer used, not removing it yet as some
+    # people might be monkey patching this function (although that's not supported)
     async_exit_stack: AsyncExitStack,
     embed_body_fields: bool,
 ) -> SolvedDependency:
+    request_astack = request.scope.get("fastapi_inner_astack")
+    assert isinstance(request_astack, AsyncExitStack), (
+        "fastapi_inner_astack not found in request scope"
+    )
+    function_astack = request.scope.get("fastapi_function_astack")
+    assert isinstance(function_astack, AsyncExitStack), (
+        "fastapi_function_astack not found in request scope"
+    )
     values: Dict[str, Any] = {}
     errors: List[Any] = []
     if response is None:
@@ -594,12 +588,8 @@ async def solve_dependencies(
         response.status_code = None  # type: ignore
     if dependency_cache is None:
         dependency_cache = {}
-    sub_dependant: Dependant
     for sub_dependant in dependant.dependencies:
         sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
-        sub_dependant.cache_key = cast(
-            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
-        )
         call = sub_dependant.call
         use_sub_dependant = sub_dependant
         if (
@@ -616,6 +606,7 @@ async def solve_dependencies(
                 call=call,
                 name=sub_dependant.name,
                 security_scopes=sub_dependant.security_scopes,
+                scope=sub_dependant.scope,
             )
 
         solved_result = await solve_dependencies(
@@ -635,11 +626,18 @@ async def solve_dependencies(
             continue
         if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
             solved = dependency_cache[sub_dependant.cache_key]
-        elif is_gen_callable(call) or is_async_gen_callable(call):
-            solved = await solve_generator(
-                call=call, stack=async_exit_stack, sub_values=solved_result.values
+        elif (
+            use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
+        ):
+            use_astack = request_astack
+            if sub_dependant.scope == "function":
+                use_astack = function_astack
+            solved = await _solve_generator(
+                dependant=use_sub_dependant,
+                stack=use_astack,
+                sub_values=solved_result.values,
             )
-        elif is_coroutine_callable(call):
+        elif use_sub_dependant.is_coroutine_callable:
             solved = await call(**solved_result.values)
         else:
             solved = await run_in_threadpool(call, **solved_result.values)
index bb775fcbfa593960bb856259f2986cc8da7ae5d5..0620428be749bf5f5be47ba950d68c3536735fb0 100644 (file)
@@ -147,6 +147,13 @@ class FastAPIError(RuntimeError):
     """
 
 
+class DependencyScopeError(FastAPIError):
+    """
+    A dependency declared that it depends on another dependency with an invalid
+    (narrower) scope.
+    """
+
+
 class ValidationException(Exception):
     def __init__(self, errors: Sequence[Any]) -> None:
         self._errors = errors
index f88937e240a8683aa905202614227ec1bf5abaef..e32f755933ba506530d9278c4bc2a58d615a7c23 100644 (file)
@@ -4,7 +4,7 @@ from annotated_doc import Doc
 from fastapi import params
 from fastapi._compat import Undefined
 from fastapi.openapi.models import Example
-from typing_extensions import Annotated, deprecated
+from typing_extensions import Annotated, Literal, deprecated
 
 _Unset: Any = Undefined
 
@@ -2245,6 +2245,26 @@ def Depends(  # noqa: N802
             """
         ),
     ] = True,
+    scope: Annotated[
+        Union[Literal["function", "request"], None],
+        Doc(
+            """
+            Mainly for dependencies with `yield`, define when the dependency function
+            should start (the code before `yield`) and when it should end (the code
+            after `yield`).
+
+            * `"function"`: start the dependency before the *path operation function*
+                that handles the request, end the dependency after the *path operation
+                function* ends, but **before** the response is sent back to the client.
+                So, the dependency function will be executed **around** the *path operation
+                **function***.
+            * `"request"`: start the dependency before the *path operation function*
+                that handles the request (similar to when using `"function"`), but end
+                **after** the response is sent back to the client. So, the dependency
+                function will be executed **around** the **request** and response cycle.
+            """
+        ),
+    ] = None,
 ) -> Any:
     """
     Declare a FastAPI dependency.
@@ -2275,7 +2295,7 @@ def Depends(  # noqa: N802
         return commons
     ```
     """
-    return params.Depends(dependency=dependency, use_cache=use_cache)
+    return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
 
 
 def Security(  # noqa: N802
index 2dc04be14ea879c3398c709c54f004d21740c510..6a58d5808e006c806eaafc61c6f0f7aec4730c40 100644 (file)
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 from fastapi.openapi.models import Example
 from pydantic.fields import FieldInfo
-from typing_extensions import Annotated, deprecated
+from typing_extensions import Annotated, Literal, deprecated
 
 from ._compat import (
     PYDANTIC_V2,
@@ -766,6 +766,7 @@ class File(Form):  # type: ignore[misc]
 class Depends:
     dependency: Optional[Callable[..., Any]] = None
     use_cache: bool = True
+    scope: Union[Literal["function", "request"], None] = None
 
 
 @dataclass
index 0b59d250a4283dd7e1f7d5cd43790c03448d76f9..a8e12eb6073180490be19c5356c623a572144f18 100644 (file)
@@ -104,10 +104,11 @@ def request_response(
         async def app(scope: Scope, receive: Receive, send: Send) -> None:
             # Starts customization
             response_awaited = False
-            async with AsyncExitStack() as stack:
-                scope["fastapi_inner_astack"] = stack
-                # Same as in Starlette
-                response = await f(request)
+            async with AsyncExitStack() as request_stack:
+                scope["fastapi_inner_astack"] = request_stack
+                async with AsyncExitStack() as function_stack:
+                    scope["fastapi_function_astack"] = function_stack
+                    response = await f(request)
                 await response(scope, receive, send)
                 # Continues customization
                 response_awaited = True
@@ -140,11 +141,11 @@ def websocket_session(
         session = WebSocket(scope, receive=receive, send=send)
 
         async def app(scope: Scope, receive: Receive, send: Send) -> None:
-            # Starts customization
-            async with AsyncExitStack() as stack:
-                scope["fastapi_inner_astack"] = stack
-                # Same as in Starlette
-                await func(session)
+            async with AsyncExitStack() as request_stack:
+                scope["fastapi_inner_astack"] = request_stack
+                async with AsyncExitStack() as function_stack:
+                    scope["fastapi_function_astack"] = function_stack
+                    await func(session)
 
         # Same as in Starlette
         await wrap_app_handling_exceptions(app, session)(scope, receive, send)
@@ -479,7 +480,9 @@ class APIWebSocketRoute(routing.WebSocketRoute):
         self.name = get_name(endpoint) if name is None else name
         self.dependencies = list(dependencies or [])
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
-        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+        self.dependant = get_dependant(
+            path=self.path_format, call=self.endpoint, scope="function"
+        )
         for depends in self.dependencies[::-1]:
             self.dependant.dependencies.insert(
                 0,
@@ -630,7 +633,9 @@ class APIRoute(routing.Route):
             self.response_fields = {}
 
         assert callable(endpoint), "An endpoint must be a callable"
-        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+        self.dependant = get_dependant(
+            path=self.path_format, call=self.endpoint, scope="function"
+        )
         for depends in self.dependencies[::-1]:
             self.dependant.dependencies.insert(
                 0,
index 3205654c73b6501f1c88e69f5cd6ceb821a889c7..3f4e81a7cca9487527f4ed0ebdc2e1f5a1cc0bdb 100644 (file)
@@ -1,6 +1,6 @@
 import types
 from enum import Enum
-from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
 
 from pydantic import BaseModel
 
@@ -8,3 +8,4 @@ DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
 UnionType = getattr(types, "UnionType", Union)
 ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
 IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
+DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]
diff --git a/tests/test_dependency_yield_scope.py b/tests/test_dependency_yield_scope.py
new file mode 100644 (file)
index 0000000..a5227dd
--- /dev/null
@@ -0,0 +1,184 @@
+import json
+from typing import Any, Tuple
+
+import pytest
+from fastapi import Depends, FastAPI
+from fastapi.exceptions import FastAPIError
+from fastapi.responses import StreamingResponse
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+class Session:
+    def __init__(self) -> None:
+        self.open = True
+
+
+def dep_session() -> Any:
+    s = Session()
+    yield s
+    s.open = False
+
+
+SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
+SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
+SessionDefaultDep = Annotated[Session, Depends(dep_session)]
+
+
+class NamedSession:
+    def __init__(self, name: str = "default") -> None:
+        self.name = name
+        self.open = True
+
+
+def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
+    assert session is session_b
+    named_session = NamedSession(name="named")
+    yield named_session, session_b
+    named_session.open = False
+
+
+NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
+
+
+def get_named_func_session(session: SessionFuncDep) -> Any:
+    named_session = NamedSession(name="named")
+    yield named_session, session
+    named_session.open = False
+
+
+def get_named_regular_func_session(session: SessionFuncDep) -> Any:
+    named_session = NamedSession(name="named")
+    return named_session, session
+
+
+BrokenSessionsDep = Annotated[
+    Tuple[NamedSession, Session], Depends(get_named_func_session)
+]
+NamedSessionsFuncDep = Annotated[
+    Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
+]
+
+RegularSessionsDep = Annotated[
+    Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
+]
+
+app = FastAPI()
+
+
+@app.get("/function-scope")
+def function_scope(session: SessionFuncDep) -> Any:
+    def iter_data():
+        yield json.dumps({"is_open": session.open})
+
+    return StreamingResponse(iter_data())
+
+
+@app.get("/request-scope")
+def request_scope(session: SessionRequestDep) -> Any:
+    def iter_data():
+        yield json.dumps({"is_open": session.open})
+
+    return StreamingResponse(iter_data())
+
+
+@app.get("/two-scopes")
+def get_stream_session(
+    function_session: SessionFuncDep, request_session: SessionRequestDep
+) -> Any:
+    def iter_data():
+        yield json.dumps(
+            {"func_is_open": function_session.open, "req_is_open": request_session.open}
+        )
+
+    return StreamingResponse(iter_data())
+
+
+@app.get("/sub")
+def get_sub(sessions: NamedSessionsDep) -> Any:
+    def iter_data():
+        yield json.dumps(
+            {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+        )
+
+    return StreamingResponse(iter_data())
+
+
+@app.get("/named-function-scope")
+def get_named_function_scope(sessions: NamedSessionsFuncDep) -> Any:
+    def iter_data():
+        yield json.dumps(
+            {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+        )
+
+    return StreamingResponse(iter_data())
+
+
+@app.get("/regular-function-scope")
+def get_regular_function_scope(sessions: RegularSessionsDep) -> Any:
+    def iter_data():
+        yield json.dumps(
+            {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+        )
+
+    return StreamingResponse(iter_data())
+
+
+client = TestClient(app)
+
+
+def test_function_scope() -> None:
+    response = client.get("/function-scope")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["is_open"] is False
+
+
+def test_request_scope() -> None:
+    response = client.get("/request-scope")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["is_open"] is True
+
+
+def test_two_scopes() -> None:
+    response = client.get("/two-scopes")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["func_is_open"] is False
+    assert data["req_is_open"] is True
+
+
+def test_sub() -> None:
+    response = client.get("/sub")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["named_session_open"] is True
+    assert data["session_open"] is True
+
+
+def test_broken_scope() -> None:
+    with pytest.raises(
+        FastAPIError,
+        match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
+    ):
+
+        @app.get("/broken-scope")
+        def get_broken(sessions: BrokenSessionsDep) -> Any:  # pragma: no cover
+            pass
+
+
+def test_named_function_scope() -> None:
+    response = client.get("/named-function-scope")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["named_session_open"] is False
+    assert data["session_open"] is False
+
+
+def test_regular_function_scope() -> None:
+    response = client.get("/regular-function-scope")
+    assert response.status_code == 200
+    data = response.json()
+    assert data["named_session_open"] is True
+    assert data["session_open"] is False
diff --git a/tests/test_dependency_yield_scope_websockets.py b/tests/test_dependency_yield_scope_websockets.py
new file mode 100644 (file)
index 0000000..52a30ae
--- /dev/null
@@ -0,0 +1,201 @@
+from contextvars import ContextVar
+from typing import Any, Dict, Tuple
+
+import pytest
+from fastapi import Depends, FastAPI, WebSocket
+from fastapi.exceptions import FastAPIError
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={})  # noqa: B039
+
+
+class Session:
+    def __init__(self) -> None:
+        self.open = True
+
+
+async def dep_session() -> Any:
+    s = Session()
+    yield s
+    s.open = False
+    global_state = global_context.get()
+    global_state["session_closed"] = True
+
+
+SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
+SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
+SessionDefaultDep = Annotated[Session, Depends(dep_session)]
+
+
+class NamedSession:
+    def __init__(self, name: str = "default") -> None:
+        self.name = name
+        self.open = True
+
+
+def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
+    assert session is session_b
+    named_session = NamedSession(name="named")
+    yield named_session, session_b
+    named_session.open = False
+    global_state = global_context.get()
+    global_state["named_session_closed"] = True
+
+
+NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
+
+
+def get_named_func_session(session: SessionFuncDep) -> Any:
+    named_session = NamedSession(name="named")
+    yield named_session, session
+    named_session.open = False
+    global_state = global_context.get()
+    global_state["named_func_session_closed"] = True
+
+
+def get_named_regular_func_session(session: SessionFuncDep) -> Any:
+    named_session = NamedSession(name="named")
+    return named_session, session
+
+
+BrokenSessionsDep = Annotated[
+    Tuple[NamedSession, Session], Depends(get_named_func_session)
+]
+NamedSessionsFuncDep = Annotated[
+    Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
+]
+
+RegularSessionsDep = Annotated[
+    Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
+]
+
+app = FastAPI()
+
+
+@app.websocket("/function-scope")
+async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
+    await websocket.accept()
+    await websocket.send_json({"is_open": session.open})
+
+
+@app.websocket("/request-scope")
+async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
+    await websocket.accept()
+    await websocket.send_json({"is_open": session.open})
+
+
+@app.websocket("/two-scopes")
+async def get_stream_session(
+    websocket: WebSocket,
+    function_session: SessionFuncDep,
+    request_session: SessionRequestDep,
+) -> Any:
+    await websocket.accept()
+    await websocket.send_json(
+        {"func_is_open": function_session.open, "req_is_open": request_session.open}
+    )
+
+
+@app.websocket("/sub")
+async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
+    await websocket.accept()
+    await websocket.send_json(
+        {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+    )
+
+
+@app.websocket("/named-function-scope")
+async def get_named_function_scope(
+    websocket: WebSocket, sessions: NamedSessionsFuncDep
+) -> Any:
+    await websocket.accept()
+    await websocket.send_json(
+        {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+    )
+
+
+@app.websocket("/regular-function-scope")
+async def get_regular_function_scope(
+    websocket: WebSocket, sessions: RegularSessionsDep
+) -> Any:
+    await websocket.accept()
+    await websocket.send_json(
+        {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+    )
+
+
+client = TestClient(app)
+
+
+def test_function_scope() -> None:
+    global_context.set({})
+    global_state = global_context.get()
+    with client.websocket_connect("/function-scope") as websocket:
+        data = websocket.receive_json()
+    assert data["is_open"] is True
+    assert global_state["session_closed"] is True
+
+
+def test_request_scope() -> None:
+    global_context.set({})
+    global_state = global_context.get()
+    with client.websocket_connect("/request-scope") as websocket:
+        data = websocket.receive_json()
+    assert data["is_open"] is True
+    assert global_state["session_closed"] is True
+
+
+def test_two_scopes() -> None:
+    global_context.set({})
+    global_state = global_context.get()
+    with client.websocket_connect("/two-scopes") as websocket:
+        data = websocket.receive_json()
+    assert data["func_is_open"] is True
+    assert data["req_is_open"] is True
+    assert global_state["session_closed"] is True
+
+
+def test_sub() -> None:
+    global_context.set({})
+    global_state = global_context.get()
+    with client.websocket_connect("/sub") as websocket:
+        data = websocket.receive_json()
+    assert data["named_session_open"] is True
+    assert data["session_open"] is True
+    assert global_state["session_closed"] is True
+    assert global_state["named_session_closed"] is True
+
+
+def test_broken_scope() -> None:
+    with pytest.raises(
+        FastAPIError,
+        match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
+    ):
+
+        @app.websocket("/broken-scope")
+        async def get_broken(
+            websocket: WebSocket, sessions: BrokenSessionsDep
+        ) -> Any:  # pragma: no cover
+            pass
+
+
+def test_named_function_scope() -> None:
+    global_context.set({})
+    global_state = global_context.get()
+    with client.websocket_connect("/named-function-scope") as websocket:
+        data = websocket.receive_json()
+    assert data["named_session_open"] is True
+    assert data["session_open"] is True
+    assert global_state["session_closed"] is True
+    assert global_state["named_func_session_closed"] is True
+
+
+def test_regular_function_scope() -> None:
+    global_context.set({})
+    global_state = global_context.get()
+    with client.websocket_connect("/regular-function-scope") as websocket:
+        data = websocket.receive_json()
+    assert data["named_session_open"] is True
+    assert data["session_open"] is True
+    assert global_state["session_closed"] is True
diff --git a/tests/test_tutorial/test_dependencies/test_tutorial008e.py b/tests/test_tutorial/test_dependencies/test_tutorial008e.py
new file mode 100644 (file)
index 0000000..1ae9ab2
--- /dev/null
@@ -0,0 +1,27 @@
+import importlib
+
+import pytest
+from fastapi.testclient import TestClient
+
+from ...utils import needs_py39
+
+
+@pytest.fixture(
+    name="client",
+    params=[
+        "tutorial008e",
+        "tutorial008e_an",
+        pytest.param("tutorial008e_an_py39", marks=needs_py39),
+    ],
+)
+def get_client(request: pytest.FixtureRequest):
+    mod = importlib.import_module(f"docs_src.dependencies.{request.param}")
+
+    client = TestClient(mod.app)
+    return client
+
+
+def test_get_users_me(client: TestClient):
+    response = client.get("/users/me")
+    assert response.status_code == 200, response.text
+    assert response.json() == "Rick"