]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
♻️ Refactor and simplify internal data from `solve_dependencies()` using dataclasses...
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 31 Aug 2024 20:52:06 +0000 (22:52 +0200)
committerGitHub <noreply@github.com>
Sat, 31 Aug 2024 20:52:06 +0000 (22:52 +0200)
fastapi/dependencies/utils.py
fastapi/routing.py

index 5ebdddaf658cedf28d1f50ef01635477f98a5815..ed03df88bd1101d75f4b30a1fff6fe6c7107e9b9 100644 (file)
@@ -529,6 +529,15 @@ async def solve_generator(
     return await stack.enter_async_context(cm)
 
 
+@dataclass
+class SolvedDependency:
+    values: Dict[str, Any]
+    errors: List[Any]
+    background_tasks: Optional[StarletteBackgroundTasks]
+    response: Response
+    dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
+
+
 async def solve_dependencies(
     *,
     request: Union[Request, WebSocket],
@@ -539,13 +548,7 @@ async def solve_dependencies(
     dependency_overrides_provider: Optional[Any] = None,
     dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
     async_exit_stack: AsyncExitStack,
-) -> Tuple[
-    Dict[str, Any],
-    List[Any],
-    Optional[StarletteBackgroundTasks],
-    Response,
-    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
-]:
+) -> SolvedDependency:
     values: Dict[str, Any] = {}
     errors: List[Any] = []
     if response is None:
@@ -587,27 +590,21 @@ async def solve_dependencies(
             dependency_cache=dependency_cache,
             async_exit_stack=async_exit_stack,
         )
-        (
-            sub_values,
-            sub_errors,
-            background_tasks,
-            _,  # the subdependency returns the same response we have
-            sub_dependency_cache,
-        ) = solved_result
-        dependency_cache.update(sub_dependency_cache)
-        if sub_errors:
-            errors.extend(sub_errors)
+        background_tasks = solved_result.background_tasks
+        dependency_cache.update(solved_result.dependency_cache)
+        if solved_result.errors:
+            errors.extend(solved_result.errors)
             continue
         if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
             solved = dependency_cache[sub_dependant.cache_key]
         elif is_gen_callable(call) or is_async_gen_callable(call):
             solved = await solve_generator(
-                call=call, stack=async_exit_stack, sub_values=sub_values
+                call=call, stack=async_exit_stack, sub_values=solved_result.values
             )
         elif is_coroutine_callable(call):
-            solved = await call(**sub_values)
+            solved = await call(**solved_result.values)
         else:
-            solved = await run_in_threadpool(call, **sub_values)
+            solved = await run_in_threadpool(call, **solved_result.values)
         if sub_dependant.name is not None:
             values[sub_dependant.name] = solved
         if sub_dependant.cache_key not in dependency_cache:
@@ -654,7 +651,13 @@ async def solve_dependencies(
         values[dependant.security_scopes_param_name] = SecurityScopes(
             scopes=dependant.security_scopes
         )
-    return values, errors, background_tasks, response, dependency_cache
+    return SolvedDependency(
+        values=values,
+        errors=errors,
+        background_tasks=background_tasks,
+        response=response,
+        dependency_cache=dependency_cache,
+    )
 
 
 def request_params_to_args(
index 49f1b60138f3de9a36567a155a6d2799e0249f67..c46772017d68490b4a64b78433ceede8b434b91d 100644 (file)
@@ -292,26 +292,34 @@ def get_request_handler(
                     dependency_overrides_provider=dependency_overrides_provider,
                     async_exit_stack=async_exit_stack,
                 )
-                values, errors, background_tasks, sub_response, _ = solved_result
+                errors = solved_result.errors
                 if not errors:
                     raw_response = await run_endpoint_function(
-                        dependant=dependant, values=values, is_coroutine=is_coroutine
+                        dependant=dependant,
+                        values=solved_result.values,
+                        is_coroutine=is_coroutine,
                     )
                     if isinstance(raw_response, Response):
                         if raw_response.background is None:
-                            raw_response.background = background_tasks
+                            raw_response.background = solved_result.background_tasks
                         response = raw_response
                     else:
-                        response_args: Dict[str, Any] = {"background": background_tasks}
+                        response_args: Dict[str, Any] = {
+                            "background": solved_result.background_tasks
+                        }
                         # If status_code was set, use it, otherwise use the default from the
                         # response class, in the case of redirect it's 307
                         current_status_code = (
-                            status_code if status_code else sub_response.status_code
+                            status_code
+                            if status_code
+                            else solved_result.response.status_code
                         )
                         if current_status_code is not None:
                             response_args["status_code"] = current_status_code
-                        if sub_response.status_code:
-                            response_args["status_code"] = sub_response.status_code
+                        if solved_result.response.status_code:
+                            response_args["status_code"] = (
+                                solved_result.response.status_code
+                            )
                         content = await serialize_response(
                             field=response_field,
                             response_content=raw_response,
@@ -326,7 +334,7 @@ def get_request_handler(
                         response = actual_response_class(content, **response_args)
                         if not is_body_allowed_for_status_code(response.status_code):
                             response.body = b""
-                        response.headers.raw.extend(sub_response.headers.raw)
+                        response.headers.raw.extend(solved_result.response.headers.raw)
             if errors:
                 validation_error = RequestValidationError(
                     _normalize_errors(errors), body=body
@@ -360,11 +368,12 @@ def get_websocket_app(
                 dependency_overrides_provider=dependency_overrides_provider,
                 async_exit_stack=async_exit_stack,
             )
-            values, errors, _, _2, _3 = solved_result
-            if errors:
-                raise WebSocketRequestValidationError(_normalize_errors(errors))
+            if solved_result.errors:
+                raise WebSocketRequestValidationError(
+                    _normalize_errors(solved_result.errors)
+                )
             assert dependant.call is not None, "dependant.call must be a function"
-            await dependant.call(**values)
+            await dependant.call(**solved_result.values)
 
     return app