]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
♻️ Refactor deciding if `embed` body fields, do not overwrite fields, compute once...
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 5 Sep 2024 11:24:36 +0000 (13:24 +0200)
committerGitHub <noreply@github.com>
Thu, 5 Sep 2024 11:24:36 +0000 (13:24 +0200)
fastapi/_compat.py
fastapi/dependencies/utils.py
fastapi/param_functions.py
fastapi/params.py
fastapi/routing.py
tests/test_compat.py
tests/test_forms_single_param.py [new file with mode: 0644]

index 06b847b4f3818b5a1ae7a183209877bffaa5c1b5..f940d6597322325270a9a32944e312fd86a4c8f2 100644 (file)
@@ -279,6 +279,12 @@ if PYDANTIC_V2:
         BodyModel: Type[BaseModel] = create_model(model_name, **field_params)  # type: ignore[call-overload]
         return BodyModel
 
+    def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
+        return [
+            ModelField(field_info=field_info, name=name)
+            for name, field_info in model.model_fields.items()
+        ]
+
 else:
     from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
     from pydantic import AnyUrl as Url  # noqa: F401
@@ -513,6 +519,9 @@ else:
             BodyModel.__fields__[f.name] = f  # type: ignore[index]
         return BodyModel
 
+    def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
+        return list(model.__fields__.values())  # type: ignore[attr-defined]
+
 
 def _regenerate_error_with_loc(
     *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@@ -532,6 +541,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
 
 
 def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
+    origin = get_origin(annotation)
+    if origin is Union or origin is UnionType:
+        for arg in get_args(annotation):
+            if field_annotation_is_sequence(arg):
+                return True
+        return False
     return _annotation_is_sequence(annotation) or _annotation_is_sequence(
         get_origin(annotation)
     )
index 0dcba62f130a9f153dceaebc3b8fc4d31a5236e4..7ac18d941cd8524425cd266f2536d108c1fc02d9 100644 (file)
@@ -59,7 +59,13 @@ from fastapi.utils import create_model_field, get_path_param_names
 from pydantic.fields import FieldInfo
 from starlette.background import BackgroundTasks as StarletteBackgroundTasks
 from starlette.concurrency import run_in_threadpool
-from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
+from starlette.datastructures import (
+    FormData,
+    Headers,
+    ImmutableMultiDict,
+    QueryParams,
+    UploadFile,
+)
 from starlette.requests import HTTPConnection, Request
 from starlette.responses import Response
 from starlette.websockets import WebSocket
@@ -282,7 +288,7 @@ def get_dependant(
             ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
             continue
         assert param_details.field is not None
-        if is_body_param(param_field=param_details.field, is_path_param=is_path_param):
+        if isinstance(param_details.field.field_info, params.Body):
             dependant.body_params.append(param_details.field)
         else:
             add_param_to_fields(field=param_details.field, dependant=dependant)
@@ -466,29 +472,16 @@ def analyze_param(
             required=field_info.default in (Required, Undefined),
             field_info=field_info,
         )
+        if is_path_param:
+            assert is_scalar_field(
+                field=field
+            ), "Path params must be of one of the supported types"
+        elif isinstance(field_info, params.Query):
+            assert is_scalar_field(field) or is_scalar_sequence_field(field)
 
     return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
 
 
-def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
-    if is_path_param:
-        assert is_scalar_field(
-            field=param_field
-        ), "Path params must be of one of the supported types"
-        return False
-    elif is_scalar_field(field=param_field):
-        return False
-    elif isinstance(
-        param_field.field_info, (params.Query, params.Header)
-    ) and is_scalar_sequence_field(param_field):
-        return False
-    else:
-        assert isinstance(
-            param_field.field_info, params.Body
-        ), f"Param: {param_field.name} can only be a request body, using Body()"
-        return True
-
-
 def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
     field_info = field.field_info
     field_info_in = getattr(field_info, "in_", None)
@@ -557,6 +550,7 @@ async def solve_dependencies(
     dependency_overrides_provider: Optional[Any] = None,
     dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
     async_exit_stack: AsyncExitStack,
+    embed_body_fields: bool,
 ) -> SolvedDependency:
     values: Dict[str, Any] = {}
     errors: List[Any] = []
@@ -598,6 +592,7 @@ async def solve_dependencies(
             dependency_overrides_provider=dependency_overrides_provider,
             dependency_cache=dependency_cache,
             async_exit_stack=async_exit_stack,
+            embed_body_fields=embed_body_fields,
         )
         background_tasks = solved_result.background_tasks
         dependency_cache.update(solved_result.dependency_cache)
@@ -640,7 +635,9 @@ async def solve_dependencies(
             body_values,
             body_errors,
         ) = await request_body_to_args(  # body_params checked above
-            required_params=dependant.body_params, received_body=body
+            body_fields=dependant.body_params,
+            received_body=body,
+            embed_body_fields=embed_body_fields,
         )
         values.update(body_values)
         errors.extend(body_errors)
@@ -669,138 +666,185 @@ async def solve_dependencies(
     )
 
 
+def _validate_value_with_model_field(
+    *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
+) -> Tuple[Any, List[Any]]:
+    if value is None:
+        if field.required:
+            return None, [get_missing_field_error(loc=loc)]
+        else:
+            return deepcopy(field.default), []
+    v_, errors_ = field.validate(value, values, loc=loc)
+    if isinstance(errors_, ErrorWrapper):
+        return None, [errors_]
+    elif isinstance(errors_, list):
+        new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
+        return None, new_errors
+    else:
+        return v_, []
+
+
+def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any:
+    if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
+        value = values.getlist(field.alias)
+    else:
+        value = values.get(field.alias, None)
+    if (
+        value is None
+        or (
+            isinstance(field.field_info, params.Form)
+            and isinstance(value, str)  # For type checks
+            and value == ""
+        )
+        or (is_sequence_field(field) and len(value) == 0)
+    ):
+        if field.required:
+            return
+        else:
+            return deepcopy(field.default)
+    return value
+
+
 def request_params_to_args(
-    required_params: Sequence[ModelField],
+    fields: Sequence[ModelField],
     received_params: Union[Mapping[str, Any], QueryParams, Headers],
 ) -> Tuple[Dict[str, Any], List[Any]]:
-    values = {}
+    values: Dict[str, Any] = {}
     errors = []
-    for field in required_params:
-        if is_scalar_sequence_field(field) and isinstance(
-            received_params, (QueryParams, Headers)
-        ):
-            value = received_params.getlist(field.alias) or field.default
-        else:
-            value = received_params.get(field.alias)
+    for field in fields:
+        value = _get_multidict_value(field, received_params)
         field_info = field.field_info
         assert isinstance(
             field_info, params.Param
         ), "Params must be subclasses of Param"
         loc = (field_info.in_.value, field.alias)
-        if value is None:
-            if field.required:
-                errors.append(get_missing_field_error(loc=loc))
-            else:
-                values[field.name] = deepcopy(field.default)
-            continue
-        v_, errors_ = field.validate(value, values, loc=loc)
-        if isinstance(errors_, ErrorWrapper):
-            errors.append(errors_)
-        elif isinstance(errors_, list):
-            new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
-            errors.extend(new_errors)
+        v_, errors_ = _validate_value_with_model_field(
+            field=field, value=value, values=values, loc=loc
+        )
+        if errors_:
+            errors.extend(errors_)
         else:
             values[field.name] = v_
     return values, errors
 
 
+def _should_embed_body_fields(fields: List[ModelField]) -> bool:
+    if not fields:
+        return False
+    # More than one dependency could have the same field, it would show up as multiple
+    # fields but it's the same one, so count them by name
+    body_param_names_set = {field.name for field in fields}
+    # A top level field has to be a single field, not multiple
+    if len(body_param_names_set) > 1:
+        return True
+    first_field = fields[0]
+    # If it explicitly specifies it is embedded, it has to be embedded
+    if getattr(first_field.field_info, "embed", None):
+        return True
+    # If it's a Form (or File) field, it has to be a BaseModel to be top level
+    # otherwise it has to be embedded, so that the key value pair can be extracted
+    if isinstance(first_field.field_info, params.Form):
+        return True
+    return False
+
+
+async def _extract_form_body(
+    body_fields: List[ModelField],
+    received_body: FormData,
+) -> Dict[str, Any]:
+    values = {}
+    first_field = body_fields[0]
+    first_field_info = first_field.field_info
+
+    for field in body_fields:
+        value = _get_multidict_value(field, received_body)
+        if (
+            isinstance(first_field_info, params.File)
+            and is_bytes_field(field)
+            and isinstance(value, UploadFile)
+        ):
+            value = await value.read()
+        elif (
+            is_bytes_sequence_field(field)
+            and isinstance(first_field_info, params.File)
+            and value_is_sequence(value)
+        ):
+            # For types
+            assert isinstance(value, sequence_types)  # type: ignore[arg-type]
+            results: List[Union[bytes, str]] = []
+
+            async def process_fn(
+                fn: Callable[[], Coroutine[Any, Any, Any]],
+            ) -> None:
+                result = await fn()
+                results.append(result)  # noqa: B023
+
+            async with anyio.create_task_group() as tg:
+                for sub_value in value:
+                    tg.start_soon(process_fn, sub_value.read)
+            value = serialize_sequence_value(field=field, value=results)
+        values[field.name] = value
+    return values
+
+
 async def request_body_to_args(
-    required_params: List[ModelField],
+    body_fields: List[ModelField],
     received_body: Optional[Union[Dict[str, Any], FormData]],
+    embed_body_fields: bool,
 ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
-    values = {}
+    values: Dict[str, Any] = {}
     errors: List[Dict[str, Any]] = []
-    if required_params:
-        field = required_params[0]
-        field_info = field.field_info
-        embed = getattr(field_info, "embed", None)
-        field_alias_omitted = len(required_params) == 1 and not embed
-        if field_alias_omitted:
-            received_body = {field.alias: received_body}
-
-        for field in required_params:
-            loc: Tuple[str, ...]
-            if field_alias_omitted:
-                loc = ("body",)
-            else:
-                loc = ("body", field.alias)
-
-            value: Optional[Any] = None
-            if received_body is not None:
-                if (is_sequence_field(field)) and isinstance(received_body, FormData):
-                    value = received_body.getlist(field.alias)
-                else:
-                    try:
-                        value = received_body.get(field.alias)
-                    except AttributeError:
-                        errors.append(get_missing_field_error(loc))
-                        continue
-            if (
-                value is None
-                or (isinstance(field_info, params.Form) and value == "")
-                or (
-                    isinstance(field_info, params.Form)
-                    and is_sequence_field(field)
-                    and len(value) == 0
-                )
-            ):
-                if field.required:
-                    errors.append(get_missing_field_error(loc))
-                else:
-                    values[field.name] = deepcopy(field.default)
+    assert body_fields, "request_body_to_args() should be called with fields"
+    single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
+    first_field = body_fields[0]
+    body_to_process = received_body
+    if isinstance(received_body, FormData):
+        body_to_process = await _extract_form_body(body_fields, received_body)
+
+    if single_not_embedded_field:
+        loc: Tuple[str, ...] = ("body",)
+        v_, errors_ = _validate_value_with_model_field(
+            field=first_field, value=body_to_process, values=values, loc=loc
+        )
+        return {first_field.name: v_}, errors_
+    for field in body_fields:
+        loc = ("body", field.alias)
+        value: Optional[Any] = None
+        if body_to_process is not None:
+            try:
+                value = body_to_process.get(field.alias)
+            # If the received body is a list, not a dict
+            except AttributeError:
+                errors.append(get_missing_field_error(loc))
                 continue
-            if (
-                isinstance(field_info, params.File)
-                and is_bytes_field(field)
-                and isinstance(value, UploadFile)
-            ):
-                value = await value.read()
-            elif (
-                is_bytes_sequence_field(field)
-                and isinstance(field_info, params.File)
-                and value_is_sequence(value)
-            ):
-                # For types
-                assert isinstance(value, sequence_types)  # type: ignore[arg-type]
-                results: List[Union[bytes, str]] = []
-
-                async def process_fn(
-                    fn: Callable[[], Coroutine[Any, Any, Any]],
-                ) -> None:
-                    result = await fn()
-                    results.append(result)  # noqa: B023
-
-                async with anyio.create_task_group() as tg:
-                    for sub_value in value:
-                        tg.start_soon(process_fn, sub_value.read)
-                value = serialize_sequence_value(field=field, value=results)
-
-            v_, errors_ = field.validate(value, values, loc=loc)
-
-            if isinstance(errors_, list):
-                errors.extend(errors_)
-            elif errors_:
-                errors.append(errors_)
-            else:
-                values[field.name] = v_
+        v_, errors_ = _validate_value_with_model_field(
+            field=field, value=value, values=values, loc=loc
+        )
+        if errors_:
+            errors.extend(errors_)
+        else:
+            values[field.name] = v_
     return values, errors
 
 
-def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
-    flat_dependant = get_flat_dependant(dependant)
+def get_body_field(
+    *, flat_dependant: Dependant, name: str, embed_body_fields: bool
+) -> Optional[ModelField]:
+    """
+    Get a ModelField representing the request body for a path operation, combining
+    all body parameters into a single field if necessary.
+
+    Used to check if it's form data (with `isinstance(body_field, params.Form)`)
+    or JSON and to generate the JSON Schema for a request body.
+
+    This is **not** used to validate/parse the request body, that's done with each
+    individual body parameter.
+    """
     if not flat_dependant.body_params:
         return None
     first_param = flat_dependant.body_params[0]
-    field_info = first_param.field_info
-    embed = getattr(field_info, "embed", None)
-    body_param_names_set = {param.name for param in flat_dependant.body_params}
-    if len(body_param_names_set) == 1 and not embed:
+    if not embed_body_fields:
         return first_param
-    # If one field requires to embed, all have to be embedded
-    # in case a sub-dependency is evaluated with a single unique body field
-    # That is combined (embedded) with other body fields
-    for param in flat_dependant.body_params:
-        setattr(param.field_info, "embed", True)  # noqa: B010
     model_name = "Body_" + name
     BodyModel = create_body_model(
         fields=flat_dependant.body_params, model_name=model_name
index 0d5f27af48a74613b2b7c19244767c47d49fbccb..7ddaace25a936dcb30a82dbf28e8a0d70b72e5df 100644 (file)
@@ -1282,7 +1282,7 @@ def Body(  # noqa: N802
         ),
     ] = _Unset,
     embed: Annotated[
-        bool,
+        Union[bool, None],
         Doc(
             """
             When `embed` is `True`, the parameter will be expected in a JSON body as a
@@ -1294,7 +1294,7 @@ def Body(  # noqa: N802
             [FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter).
             """
         ),
-    ] = False,
+    ] = None,
     media_type: Annotated[
         str,
         Doc(
index cc2a5c13c02618105be450f05b34bca9df2c41aa..3dfa5a1a381e551f3b70874f6ce5ba81d9b66218 100644 (file)
@@ -479,7 +479,7 @@ class Body(FieldInfo):
         *,
         default_factory: Union[Callable[[], Any], None] = _Unset,
         annotation: Optional[Any] = None,
-        embed: bool = False,
+        embed: Union[bool, None] = None,
         media_type: str = "application/json",
         alias: Optional[str] = None,
         alias_priority: Union[int, None] = _Unset,
@@ -642,7 +642,6 @@ class Form(Body):
             default=default,
             default_factory=default_factory,
             annotation=annotation,
-            embed=True,
             media_type=media_type,
             alias=alias,
             alias_priority=alias_priority,
index 61a112fc47ea69236942708861ca1cb92bb8c481..86e30360216cb050b3debc81cf73537327e4b4fd 100644 (file)
@@ -33,8 +33,10 @@ from fastapi._compat import (
 from fastapi.datastructures import Default, DefaultPlaceholder
 from fastapi.dependencies.models import Dependant
 from fastapi.dependencies.utils import (
+    _should_embed_body_fields,
     get_body_field,
     get_dependant,
+    get_flat_dependant,
     get_parameterless_sub_dependant,
     get_typed_return_annotation,
     solve_dependencies,
@@ -225,6 +227,7 @@ def get_request_handler(
     response_model_exclude_defaults: bool = False,
     response_model_exclude_none: bool = False,
     dependency_overrides_provider: Optional[Any] = None,
+    embed_body_fields: bool = False,
 ) -> Callable[[Request], Coroutine[Any, Any, Response]]:
     assert dependant.call is not None, "dependant.call must be a function"
     is_coroutine = asyncio.iscoroutinefunction(dependant.call)
@@ -291,6 +294,7 @@ def get_request_handler(
                     body=body,
                     dependency_overrides_provider=dependency_overrides_provider,
                     async_exit_stack=async_exit_stack,
+                    embed_body_fields=embed_body_fields,
                 )
                 errors = solved_result.errors
                 if not errors:
@@ -354,7 +358,9 @@ def get_request_handler(
 
 
 def get_websocket_app(
-    dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
+    dependant: Dependant,
+    dependency_overrides_provider: Optional[Any] = None,
+    embed_body_fields: bool = False,
 ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
     async def app(websocket: WebSocket) -> None:
         async with AsyncExitStack() as async_exit_stack:
@@ -367,6 +373,7 @@ def get_websocket_app(
                 dependant=dependant,
                 dependency_overrides_provider=dependency_overrides_provider,
                 async_exit_stack=async_exit_stack,
+                embed_body_fields=embed_body_fields,
             )
             if solved_result.errors:
                 raise WebSocketRequestValidationError(
@@ -399,11 +406,15 @@ class APIWebSocketRoute(routing.WebSocketRoute):
                 0,
                 get_parameterless_sub_dependant(depends=depends, path=self.path_format),
             )
-
+        self._flat_dependant = get_flat_dependant(self.dependant)
+        self._embed_body_fields = _should_embed_body_fields(
+            self._flat_dependant.body_params
+        )
         self.app = websocket_session(
             get_websocket_app(
                 dependant=self.dependant,
                 dependency_overrides_provider=dependency_overrides_provider,
+                embed_body_fields=self._embed_body_fields,
             )
         )
 
@@ -544,7 +555,15 @@ class APIRoute(routing.Route):
                 0,
                 get_parameterless_sub_dependant(depends=depends, path=self.path_format),
             )
-        self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
+        self._flat_dependant = get_flat_dependant(self.dependant)
+        self._embed_body_fields = _should_embed_body_fields(
+            self._flat_dependant.body_params
+        )
+        self.body_field = get_body_field(
+            flat_dependant=self._flat_dependant,
+            name=self.unique_id,
+            embed_body_fields=self._embed_body_fields,
+        )
         self.app = request_response(self.get_route_handler())
 
     def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
@@ -561,6 +580,7 @@ class APIRoute(routing.Route):
             response_model_exclude_defaults=self.response_model_exclude_defaults,
             response_model_exclude_none=self.response_model_exclude_none,
             dependency_overrides_provider=self.dependency_overrides_provider,
+            embed_body_fields=self._embed_body_fields,
         )
 
     def matches(self, scope: Scope) -> Tuple[Match, Scope]:
index bf268b860b1dcdee16afe92db9cbc683c0969236..270475bf3a42a74f4de7fce6bfd1d9a4d554c9bb 100644 (file)
@@ -1,11 +1,13 @@
-from typing import List, Union
+from typing import Any, Dict, List, Union
 
 from fastapi import FastAPI, UploadFile
 from fastapi._compat import (
     ModelField,
     Undefined,
     _get_model_config,
+    get_model_fields,
     is_bytes_sequence_annotation,
+    is_scalar_field,
     is_uploadfile_sequence_annotation,
 )
 from fastapi.testclient import TestClient
@@ -91,3 +93,12 @@ def test_is_uploadfile_sequence_annotation():
     # and other types, but I'm not even sure it's a good idea to support it as a first
     # class "feature"
     assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]])
+
+
+def test_is_pv1_scalar_field():
+    # For coverage
+    class Model(BaseModel):
+        foo: Union[str, Dict[str, Any]]
+
+    fields = get_model_fields(Model)
+    assert not is_scalar_field(fields[0])
diff --git a/tests/test_forms_single_param.py b/tests/test_forms_single_param.py
new file mode 100644 (file)
index 0000000..3bb9514
--- /dev/null
@@ -0,0 +1,99 @@
+from fastapi import FastAPI, Form
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+app = FastAPI()
+
+
+@app.post("/form/")
+def post_form(username: Annotated[str, Form()]):
+    return username
+
+
+client = TestClient(app)
+
+
+def test_single_form_field():
+    response = client.post("/form/", data={"username": "Rick"})
+    assert response.status_code == 200, response.text
+    assert response.json() == "Rick"
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200, response.text
+    assert response.json() == {
+        "openapi": "3.1.0",
+        "info": {"title": "FastAPI", "version": "0.1.0"},
+        "paths": {
+            "/form/": {
+                "post": {
+                    "summary": "Post Form",
+                    "operationId": "post_form_form__post",
+                    "requestBody": {
+                        "content": {
+                            "application/x-www-form-urlencoded": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/Body_post_form_form__post"
+                                }
+                            }
+                        },
+                        "required": True,
+                    },
+                    "responses": {
+                        "200": {
+                            "description": "Successful Response",
+                            "content": {"application/json": {"schema": {}}},
+                        },
+                        "422": {
+                            "description": "Validation Error",
+                            "content": {
+                                "application/json": {
+                                    "schema": {
+                                        "$ref": "#/components/schemas/HTTPValidationError"
+                                    }
+                                }
+                            },
+                        },
+                    },
+                }
+            }
+        },
+        "components": {
+            "schemas": {
+                "Body_post_form_form__post": {
+                    "properties": {"username": {"type": "string", "title": "Username"}},
+                    "type": "object",
+                    "required": ["username"],
+                    "title": "Body_post_form_form__post",
+                },
+                "HTTPValidationError": {
+                    "properties": {
+                        "detail": {
+                            "items": {"$ref": "#/components/schemas/ValidationError"},
+                            "type": "array",
+                            "title": "Detail",
+                        }
+                    },
+                    "type": "object",
+                    "title": "HTTPValidationError",
+                },
+                "ValidationError": {
+                    "properties": {
+                        "loc": {
+                            "items": {
+                                "anyOf": [{"type": "string"}, {"type": "integer"}]
+                            },
+                            "type": "array",
+                            "title": "Location",
+                        },
+                        "msg": {"type": "string", "title": "Message"},
+                        "type": {"type": "string", "title": "Error Type"},
+                    },
+                    "type": "object",
+                    "required": ["loc", "msg", "type"],
+                    "title": "ValidationError",
+                },
+            }
+        },
+    }