///
+## Early exit and `scope` { #early-exit-and-scope }
+
+Normally the exit code of dependencies with `yield` is executed **after the response** is sent to the client.
+
+But if you know that you won't need to use the dependency after returning from the *path operation function*, you can use `Depends(scope="function")` to tell FastAPI that it should close the dependency after the *path operation function* returns, but **before** the **response is sent**.
+
+{* ../../docs_src/dependencies/tutorial008e_an_py39.py hl[12,16] *}
+
+`Depends()` receives a `scope` parameter that can be:
+
+* `"function"`: start the dependency before the *path operation function* that handles the request, end the dependency after the *path operation function* ends, but **before** the response is sent back to the client. So, the dependency function will be executed **around** the *path operation **function***.
+* `"request"`: start the dependency before the *path operation function* that handles the request (similar to when using `"function"`), but end **after** the response is sent back to the client. So, the dependency function will be executed **around** the **request** and response cycle.
+
+If not specified and the dependency has `yield`, it will have a `scope` of `"request"` by default.
+
+### `scope` for sub-dependencies { #scope-for-sub-dependencies }
+
+When you declare a dependency with a `scope="request"` (the default), any sub-dependency needs to also have a `scope` of `"request"`.
+
+But a dependency with `scope` of `"function"` can have dependencies with `scope` of `"function"` and `scope` of `"request"`.
+
+This is because any dependency needs to be able to run its exit code before the sub-dependencies, as it might need to still use them during its exit code.
+
+```mermaid
+sequenceDiagram
+
+participant client as Client
+participant dep_req as Dep scope="request"
+participant dep_func as Dep scope="function"
+participant operation as Path Operation
+
+ client ->> dep_req: Start request
+ Note over dep_req: Run code up to yield
+ dep_req ->> dep_func: Pass dependency
+ Note over dep_func: Run code up to yield
+ dep_func ->> operation: Run path operation with dependency
+ operation ->> dep_func: Return from path operation
+ Note over dep_func: Run code after yield
+ Note over dep_func: ✅ Dependency closed
+ dep_func ->> client: Send response to client
+ Note over client: Response sent
+ Note over dep_req: Run code after yield
+ Note over dep_req: ✅ Dependency closed
+```
+
## Dependencies with `yield`, `HTTPException`, `except` and Background Tasks { #dependencies-with-yield-httpexception-except-and-background-tasks }
Dependencies with `yield` have evolved over time to cover different use cases and fix some issues.
--- /dev/null
+from fastapi import Depends, FastAPI
+
+app = FastAPI()
+
+
+def get_username():
+ try:
+ yield "Rick"
+ finally:
+ print("Cleanup up before response is sent")
+
+
+@app.get("/users/me")
+def get_user_me(username: str = Depends(get_username, scope="function")):
+ return username
--- /dev/null
+from fastapi import Depends, FastAPI
+from typing_extensions import Annotated
+
+app = FastAPI()
+
+
+def get_username():
+ try:
+ yield "Rick"
+ finally:
+ print("Cleanup up before response is sent")
+
+
+@app.get("/users/me")
+def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
+ return username
--- /dev/null
+from typing import Annotated
+
+from fastapi import Depends, FastAPI
+
+app = FastAPI()
+
+
+def get_username():
+ try:
+ yield "Rick"
+ finally:
+ print("Cleanup up before response is sent")
+
+
+@app.get("/users/me")
+def get_user_me(username: Annotated[str, Depends(get_username, scope="function")]):
+ return username
+import inspect
+import sys
from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Sequence, Tuple
+from functools import cached_property
+from typing import Any, Callable, List, Optional, Sequence, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
+from fastapi.types import DependencyCacheKey
+from typing_extensions import Literal
+
+if sys.version_info >= (3, 13): # pragma: no cover
+ from inspect import iscoroutinefunction
+else: # pragma: no cover
+ from asyncio import iscoroutinefunction
@dataclass
security_scopes: Optional[List[str]] = None
use_cache: bool = True
path: Optional[str] = None
- cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False)
+ scope: Union[Literal["function", "request"], None] = None
+
+ @cached_property
+ def cache_key(self) -> DependencyCacheKey:
+ return (
+ self.call,
+ tuple(sorted(set(self.security_scopes or []))),
+ self.computed_scope or "",
+ )
+
+ @cached_property
+ def is_gen_callable(self) -> bool:
+ if inspect.isgeneratorfunction(self.call):
+ return True
+ dunder_call = getattr(self.call, "__call__", None) # noqa: B004
+ return inspect.isgeneratorfunction(dunder_call)
+
+ @cached_property
+ def is_async_gen_callable(self) -> bool:
+ if inspect.isasyncgenfunction(self.call):
+ return True
+ dunder_call = getattr(self.call, "__call__", None) # noqa: B004
+ return inspect.isasyncgenfunction(dunder_call)
+
+ @cached_property
+ def is_coroutine_callable(self) -> bool:
+ if inspect.isroutine(self.call):
+ return iscoroutinefunction(self.call)
+ if inspect.isclass(self.call):
+ return False
+ dunder_call = getattr(self.call, "__call__", None) # noqa: B004
+ return iscoroutinefunction(dunder_call)
- def __post_init__(self) -> None:
- self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
+ @cached_property
+ def computed_scope(self) -> Union[str, None]:
+ if self.scope:
+ return self.scope
+ if self.is_gen_callable or self.is_async_gen_callable:
+ return "request"
+ return None
import inspect
-import sys
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
+from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
-from typing_extensions import Annotated, get_args, get_origin
+from typing_extensions import Annotated, Literal, get_args, get_origin
from .. import temp_pydantic_v1_params
-if sys.version_info >= (3, 13): # pragma: no cover
- from inspect import iscoroutinefunction
-else: # pragma: no cover
- from asyncio import iscoroutinefunction
-
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n'
)
-CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
-
-
def get_flat_dependant(
dependant: Dependant,
*,
skip_repeats: bool = False,
- visited: Optional[List[CacheKey]] = None,
+ visited: Optional[List[DependencyCacheKey]] = None,
) -> Dependant:
if visited is None:
visited = []
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
+ scope: Union[Literal["function", "request"], None] = None,
) -> Dependant:
dependant = Dependant(
call=call,
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
+ scope=scope,
)
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
if isinstance(call, SecurityBase):
use_scopes: List[str] = []
if isinstance(call, (OAuth2, OpenIdConnect)):
- use_scopes = security_scopes
+ use_scopes = security_scopes or use_scopes
security_requirement = SecurityRequirement(
security_scheme=call, scopes=use_scopes
)
)
if param_details.depends is not None:
assert param_details.depends.dependency
+ if (
+ (dependant.is_gen_callable or dependant.is_async_gen_callable)
+ and dependant.computed_scope == "request"
+ and param_details.depends.scope == "function"
+ ):
+ assert dependant.call
+ raise DependencyScopeError(
+ f'The dependency "{dependant.call.__name__}" has a scope of '
+ '"request", it cannot depend on dependencies with scope "function".'
+ )
use_security_scopes = security_scopes or []
if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes:
name=param_name,
security_scopes=use_security_scopes,
use_cache=param_details.depends.use_cache,
+ scope=param_details.depends.scope,
)
dependant.dependencies.append(sub_dependant)
continue
dependant.cookie_params.append(field)
-def is_coroutine_callable(call: Callable[..., Any]) -> bool:
- if inspect.isroutine(call):
- return iscoroutinefunction(call)
- if inspect.isclass(call):
- return False
- dunder_call = getattr(call, "__call__", None) # noqa: B004
- return iscoroutinefunction(dunder_call)
-
-
-def is_async_gen_callable(call: Callable[..., Any]) -> bool:
- if inspect.isasyncgenfunction(call):
- return True
- dunder_call = getattr(call, "__call__", None) # noqa: B004
- return inspect.isasyncgenfunction(dunder_call)
-
-
-def is_gen_callable(call: Callable[..., Any]) -> bool:
- if inspect.isgeneratorfunction(call):
- return True
- dunder_call = getattr(call, "__call__", None) # noqa: B004
- return inspect.isgeneratorfunction(dunder_call)
-
-
-async def solve_generator(
- *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
+async def _solve_generator(
+ *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
- if is_gen_callable(call):
- cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
- elif is_async_gen_callable(call):
- cm = asynccontextmanager(call)(**sub_values)
+ assert dependant.call
+ if dependant.is_gen_callable:
+ cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
+ elif dependant.is_async_gen_callable:
+ cm = asynccontextmanager(dependant.call)(**sub_values)
return await stack.enter_async_context(cm)
errors: List[Any]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
- dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
+ dependency_cache: Dict[DependencyCacheKey, Any]
async def solve_dependencies(
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
- dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
+ dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
+ # TODO: remove this parameter later, no longer used, not removing it yet as some
+ # people might be monkey patching this function (although that's not supported)
async_exit_stack: AsyncExitStack,
embed_body_fields: bool,
) -> SolvedDependency:
+ request_astack = request.scope.get("fastapi_inner_astack")
+ assert isinstance(request_astack, AsyncExitStack), (
+ "fastapi_inner_astack not found in request scope"
+ )
+ function_astack = request.scope.get("fastapi_function_astack")
+ assert isinstance(function_astack, AsyncExitStack), (
+ "fastapi_function_astack not found in request scope"
+ )
values: Dict[str, Any] = {}
errors: List[Any] = []
if response is None:
response.status_code = None # type: ignore
if dependency_cache is None:
dependency_cache = {}
- sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
- sub_dependant.cache_key = cast(
- Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
- )
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
+ scope=sub_dependant.scope,
)
solved_result = await solve_dependencies(
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
- elif is_gen_callable(call) or is_async_gen_callable(call):
- solved = await solve_generator(
- call=call, stack=async_exit_stack, sub_values=solved_result.values
+ elif (
+ use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
+ ):
+ use_astack = request_astack
+ if sub_dependant.scope == "function":
+ use_astack = function_astack
+ solved = await _solve_generator(
+ dependant=use_sub_dependant,
+ stack=use_astack,
+ sub_values=solved_result.values,
)
- elif is_coroutine_callable(call):
+ elif use_sub_dependant.is_coroutine_callable:
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
"""
+class DependencyScopeError(FastAPIError):
+ """
+ A dependency declared that it depends on another dependency with an invalid
+ (narrower) scope.
+ """
+
+
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
self._errors = errors
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
-from typing_extensions import Annotated, deprecated
+from typing_extensions import Annotated, Literal, deprecated
_Unset: Any = Undefined
"""
),
] = True,
+ scope: Annotated[
+ Union[Literal["function", "request"], None],
+ Doc(
+ """
+ Mainly for dependencies with `yield`, define when the dependency function
+ should start (the code before `yield`) and when it should end (the code
+ after `yield`).
+
+ * `"function"`: start the dependency before the *path operation function*
+ that handles the request, end the dependency after the *path operation
+ function* ends, but **before** the response is sent back to the client.
+ So, the dependency function will be executed **around** the *path operation
+ **function***.
+ * `"request"`: start the dependency before the *path operation function*
+ that handles the request (similar to when using `"function"`), but end
+ **after** the response is sent back to the client. So, the dependency
+ function will be executed **around** the **request** and response cycle.
+ """
+ ),
+ ] = None,
) -> Any:
"""
Declare a FastAPI dependency.
return commons
```
"""
- return params.Depends(dependency=dependency, use_cache=use_cache)
+ return params.Depends(dependency=dependency, use_cache=use_cache, scope=scope)
def Security( # noqa: N802
from fastapi.openapi.models import Example
from pydantic.fields import FieldInfo
-from typing_extensions import Annotated, deprecated
+from typing_extensions import Annotated, Literal, deprecated
from ._compat import (
PYDANTIC_V2,
class Depends:
dependency: Optional[Callable[..., Any]] = None
use_cache: bool = True
+ scope: Union[Literal["function", "request"], None] = None
@dataclass
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# Starts customization
response_awaited = False
- async with AsyncExitStack() as stack:
- scope["fastapi_inner_astack"] = stack
- # Same as in Starlette
- response = await f(request)
+ async with AsyncExitStack() as request_stack:
+ scope["fastapi_inner_astack"] = request_stack
+ async with AsyncExitStack() as function_stack:
+ scope["fastapi_function_astack"] = function_stack
+ response = await f(request)
await response(scope, receive, send)
# Continues customization
response_awaited = True
session = WebSocket(scope, receive=receive, send=send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
- # Starts customization
- async with AsyncExitStack() as stack:
- scope["fastapi_inner_astack"] = stack
- # Same as in Starlette
- await func(session)
+ async with AsyncExitStack() as request_stack:
+ scope["fastapi_inner_astack"] = request_stack
+ async with AsyncExitStack() as function_stack:
+ scope["fastapi_function_astack"] = function_stack
+ await func(session)
# Same as in Starlette
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
self.name = get_name(endpoint) if name is None else name
self.dependencies = list(dependencies or [])
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
- self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+ self.dependant = get_dependant(
+ path=self.path_format, call=self.endpoint, scope="function"
+ )
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable"
- self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+ self.dependant = get_dependant(
+ path=self.path_format, call=self.endpoint, scope="function"
+ )
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
import types
from enum import Enum
-from typing import Any, Callable, Dict, Set, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
from pydantic import BaseModel
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
+DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]
--- /dev/null
+import json
+from typing import Any, Tuple
+
+import pytest
+from fastapi import Depends, FastAPI
+from fastapi.exceptions import FastAPIError
+from fastapi.responses import StreamingResponse
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+class Session:
+ def __init__(self) -> None:
+ self.open = True
+
+
+def dep_session() -> Any:
+ s = Session()
+ yield s
+ s.open = False
+
+
+SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
+SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
+SessionDefaultDep = Annotated[Session, Depends(dep_session)]
+
+
+class NamedSession:
+ def __init__(self, name: str = "default") -> None:
+ self.name = name
+ self.open = True
+
+
+def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
+ assert session is session_b
+ named_session = NamedSession(name="named")
+ yield named_session, session_b
+ named_session.open = False
+
+
+NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
+
+
+def get_named_func_session(session: SessionFuncDep) -> Any:
+ named_session = NamedSession(name="named")
+ yield named_session, session
+ named_session.open = False
+
+
+def get_named_regular_func_session(session: SessionFuncDep) -> Any:
+ named_session = NamedSession(name="named")
+ return named_session, session
+
+
+BrokenSessionsDep = Annotated[
+ Tuple[NamedSession, Session], Depends(get_named_func_session)
+]
+NamedSessionsFuncDep = Annotated[
+ Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
+]
+
+RegularSessionsDep = Annotated[
+ Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
+]
+
+app = FastAPI()
+
+
+@app.get("/function-scope")
+def function_scope(session: SessionFuncDep) -> Any:
+ def iter_data():
+ yield json.dumps({"is_open": session.open})
+
+ return StreamingResponse(iter_data())
+
+
+@app.get("/request-scope")
+def request_scope(session: SessionRequestDep) -> Any:
+ def iter_data():
+ yield json.dumps({"is_open": session.open})
+
+ return StreamingResponse(iter_data())
+
+
+@app.get("/two-scopes")
+def get_stream_session(
+ function_session: SessionFuncDep, request_session: SessionRequestDep
+) -> Any:
+ def iter_data():
+ yield json.dumps(
+ {"func_is_open": function_session.open, "req_is_open": request_session.open}
+ )
+
+ return StreamingResponse(iter_data())
+
+
+@app.get("/sub")
+def get_sub(sessions: NamedSessionsDep) -> Any:
+ def iter_data():
+ yield json.dumps(
+ {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+ )
+
+ return StreamingResponse(iter_data())
+
+
+@app.get("/named-function-scope")
+def get_named_function_scope(sessions: NamedSessionsFuncDep) -> Any:
+ def iter_data():
+ yield json.dumps(
+ {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+ )
+
+ return StreamingResponse(iter_data())
+
+
+@app.get("/regular-function-scope")
+def get_regular_function_scope(sessions: RegularSessionsDep) -> Any:
+ def iter_data():
+ yield json.dumps(
+ {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+ )
+
+ return StreamingResponse(iter_data())
+
+
+client = TestClient(app)
+
+
+def test_function_scope() -> None:
+ response = client.get("/function-scope")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["is_open"] is False
+
+
+def test_request_scope() -> None:
+ response = client.get("/request-scope")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["is_open"] is True
+
+
+def test_two_scopes() -> None:
+ response = client.get("/two-scopes")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["func_is_open"] is False
+ assert data["req_is_open"] is True
+
+
+def test_sub() -> None:
+ response = client.get("/sub")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["named_session_open"] is True
+ assert data["session_open"] is True
+
+
+def test_broken_scope() -> None:
+ with pytest.raises(
+ FastAPIError,
+ match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
+ ):
+
+ @app.get("/broken-scope")
+ def get_broken(sessions: BrokenSessionsDep) -> Any: # pragma: no cover
+ pass
+
+
+def test_named_function_scope() -> None:
+ response = client.get("/named-function-scope")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["named_session_open"] is False
+ assert data["session_open"] is False
+
+
+def test_regular_function_scope() -> None:
+ response = client.get("/regular-function-scope")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["named_session_open"] is True
+ assert data["session_open"] is False
--- /dev/null
+from contextvars import ContextVar
+from typing import Any, Dict, Tuple
+
+import pytest
+from fastapi import Depends, FastAPI, WebSocket
+from fastapi.exceptions import FastAPIError
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+global_context: ContextVar[Dict[str, Any]] = ContextVar("global_context", default={}) # noqa: B039
+
+
+class Session:
+ def __init__(self) -> None:
+ self.open = True
+
+
+async def dep_session() -> Any:
+ s = Session()
+ yield s
+ s.open = False
+ global_state = global_context.get()
+ global_state["session_closed"] = True
+
+
+SessionFuncDep = Annotated[Session, Depends(dep_session, scope="function")]
+SessionRequestDep = Annotated[Session, Depends(dep_session, scope="request")]
+SessionDefaultDep = Annotated[Session, Depends(dep_session)]
+
+
+class NamedSession:
+ def __init__(self, name: str = "default") -> None:
+ self.name = name
+ self.open = True
+
+
+def get_named_session(session: SessionRequestDep, session_b: SessionDefaultDep) -> Any:
+ assert session is session_b
+ named_session = NamedSession(name="named")
+ yield named_session, session_b
+ named_session.open = False
+ global_state = global_context.get()
+ global_state["named_session_closed"] = True
+
+
+NamedSessionsDep = Annotated[Tuple[NamedSession, Session], Depends(get_named_session)]
+
+
+def get_named_func_session(session: SessionFuncDep) -> Any:
+ named_session = NamedSession(name="named")
+ yield named_session, session
+ named_session.open = False
+ global_state = global_context.get()
+ global_state["named_func_session_closed"] = True
+
+
+def get_named_regular_func_session(session: SessionFuncDep) -> Any:
+ named_session = NamedSession(name="named")
+ return named_session, session
+
+
+BrokenSessionsDep = Annotated[
+ Tuple[NamedSession, Session], Depends(get_named_func_session)
+]
+NamedSessionsFuncDep = Annotated[
+ Tuple[NamedSession, Session], Depends(get_named_func_session, scope="function")
+]
+
+RegularSessionsDep = Annotated[
+ Tuple[NamedSession, Session], Depends(get_named_regular_func_session)
+]
+
+app = FastAPI()
+
+
+@app.websocket("/function-scope")
+async def function_scope(websocket: WebSocket, session: SessionFuncDep) -> Any:
+ await websocket.accept()
+ await websocket.send_json({"is_open": session.open})
+
+
+@app.websocket("/request-scope")
+async def request_scope(websocket: WebSocket, session: SessionRequestDep) -> Any:
+ await websocket.accept()
+ await websocket.send_json({"is_open": session.open})
+
+
+@app.websocket("/two-scopes")
+async def get_stream_session(
+ websocket: WebSocket,
+ function_session: SessionFuncDep,
+ request_session: SessionRequestDep,
+) -> Any:
+ await websocket.accept()
+ await websocket.send_json(
+ {"func_is_open": function_session.open, "req_is_open": request_session.open}
+ )
+
+
+@app.websocket("/sub")
+async def get_sub(websocket: WebSocket, sessions: NamedSessionsDep) -> Any:
+ await websocket.accept()
+ await websocket.send_json(
+ {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+ )
+
+
+@app.websocket("/named-function-scope")
+async def get_named_function_scope(
+ websocket: WebSocket, sessions: NamedSessionsFuncDep
+) -> Any:
+ await websocket.accept()
+ await websocket.send_json(
+ {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+ )
+
+
+@app.websocket("/regular-function-scope")
+async def get_regular_function_scope(
+ websocket: WebSocket, sessions: RegularSessionsDep
+) -> Any:
+ await websocket.accept()
+ await websocket.send_json(
+ {"named_session_open": sessions[0].open, "session_open": sessions[1].open}
+ )
+
+
+client = TestClient(app)
+
+
+def test_function_scope() -> None:
+ global_context.set({})
+ global_state = global_context.get()
+ with client.websocket_connect("/function-scope") as websocket:
+ data = websocket.receive_json()
+ assert data["is_open"] is True
+ assert global_state["session_closed"] is True
+
+
+def test_request_scope() -> None:
+ global_context.set({})
+ global_state = global_context.get()
+ with client.websocket_connect("/request-scope") as websocket:
+ data = websocket.receive_json()
+ assert data["is_open"] is True
+ assert global_state["session_closed"] is True
+
+
+def test_two_scopes() -> None:
+ global_context.set({})
+ global_state = global_context.get()
+ with client.websocket_connect("/two-scopes") as websocket:
+ data = websocket.receive_json()
+ assert data["func_is_open"] is True
+ assert data["req_is_open"] is True
+ assert global_state["session_closed"] is True
+
+
+def test_sub() -> None:
+ global_context.set({})
+ global_state = global_context.get()
+ with client.websocket_connect("/sub") as websocket:
+ data = websocket.receive_json()
+ assert data["named_session_open"] is True
+ assert data["session_open"] is True
+ assert global_state["session_closed"] is True
+ assert global_state["named_session_closed"] is True
+
+
+def test_broken_scope() -> None:
+ with pytest.raises(
+ FastAPIError,
+ match='The dependency "get_named_func_session" has a scope of "request", it cannot depend on dependencies with scope "function"',
+ ):
+
+ @app.websocket("/broken-scope")
+ async def get_broken(
+ websocket: WebSocket, sessions: BrokenSessionsDep
+ ) -> Any: # pragma: no cover
+ pass
+
+
+def test_named_function_scope() -> None:
+ global_context.set({})
+ global_state = global_context.get()
+ with client.websocket_connect("/named-function-scope") as websocket:
+ data = websocket.receive_json()
+ assert data["named_session_open"] is True
+ assert data["session_open"] is True
+ assert global_state["session_closed"] is True
+ assert global_state["named_func_session_closed"] is True
+
+
+def test_regular_function_scope() -> None:
+ global_context.set({})
+ global_state = global_context.get()
+ with client.websocket_connect("/regular-function-scope") as websocket:
+ data = websocket.receive_json()
+ assert data["named_session_open"] is True
+ assert data["session_open"] is True
+ assert global_state["session_closed"] is True
--- /dev/null
+import importlib
+
+import pytest
+from fastapi.testclient import TestClient
+
+from ...utils import needs_py39
+
+
+@pytest.fixture(
+ name="client",
+ params=[
+ "tutorial008e",
+ "tutorial008e_an",
+ pytest.param("tutorial008e_an_py39", marks=needs_py39),
+ ],
+)
+def get_client(request: pytest.FixtureRequest):
+ mod = importlib.import_module(f"docs_src.dependencies.{request.param}")
+
+ client = TestClient(mod.app)
+ return client
+
+
+def test_get_users_me(client: TestClient):
+ response = client.get("/users/me")
+ assert response.status_code == 200, response.text
+ assert response.json() == "Rick"