return await stack.enter_async_context(cm)
+@dataclass
+class SolvedDependency:
+ values: Dict[str, Any]
+ errors: List[Any]
+ background_tasks: Optional[StarletteBackgroundTasks]
+ response: Response
+ dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
+
+
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
async_exit_stack: AsyncExitStack,
-) -> Tuple[
- Dict[str, Any],
- List[Any],
- Optional[StarletteBackgroundTasks],
- Response,
- Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
-]:
+) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
if response is None:
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
- (
- sub_values,
- sub_errors,
- background_tasks,
- _, # the subdependency returns the same response we have
- sub_dependency_cache,
- ) = solved_result
- dependency_cache.update(sub_dependency_cache)
- if sub_errors:
- errors.extend(sub_errors)
+ background_tasks = solved_result.background_tasks
+ dependency_cache.update(solved_result.dependency_cache)
+ if solved_result.errors:
+ errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
- call=call, stack=async_exit_stack, sub_values=sub_values
+ call=call, stack=async_exit_stack, sub_values=solved_result.values
)
elif is_coroutine_callable(call):
- solved = await call(**sub_values)
+ solved = await call(**solved_result.values)
else:
- solved = await run_in_threadpool(call, **sub_values)
+ solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
- return values, errors, background_tasks, response, dependency_cache
+ return SolvedDependency(
+ values=values,
+ errors=errors,
+ background_tasks=background_tasks,
+ response=response,
+ dependency_cache=dependency_cache,
+ )
def request_params_to_args(
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
)
- values, errors, background_tasks, sub_response, _ = solved_result
+ errors = solved_result.errors
if not errors:
raw_response = await run_endpoint_function(
- dependant=dependant, values=values, is_coroutine=is_coroutine
+ dependant=dependant,
+ values=solved_result.values,
+ is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
- raw_response.background = background_tasks
+ raw_response.background = solved_result.background_tasks
response = raw_response
else:
- response_args: Dict[str, Any] = {"background": background_tasks}
+ response_args: Dict[str, Any] = {
+ "background": solved_result.background_tasks
+ }
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
- status_code if status_code else sub_response.status_code
+ status_code
+ if status_code
+ else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
- if sub_response.status_code:
- response_args["status_code"] = sub_response.status_code
+ if solved_result.response.status_code:
+ response_args["status_code"] = (
+ solved_result.response.status_code
+ )
content = await serialize_response(
field=response_field,
response_content=raw_response,
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
- response.headers.raw.extend(sub_response.headers.raw)
+ response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
)
- values, errors, _, _2, _3 = solved_result
- if errors:
- raise WebSocketRequestValidationError(_normalize_errors(errors))
+ if solved_result.errors:
+ raise WebSocketRequestValidationError(
+ _normalize_errors(solved_result.errors)
+ )
assert dependant.call is not None, "dependant.call must be a function"
- await dependant.call(**values)
+ await dependant.call(**solved_result.values)
return app