from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool
-from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
+from starlette.datastructures import (
+ FormData,
+ Headers,
+ ImmutableMultiDict,
+ QueryParams,
+ UploadFile,
+)
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
continue
assert param_details.field is not None
- if is_body_param(param_field=param_details.field, is_path_param=is_path_param):
+ if isinstance(param_details.field.field_info, params.Body):
dependant.body_params.append(param_details.field)
else:
add_param_to_fields(field=param_details.field, dependant=dependant)
required=field_info.default in (Required, Undefined),
field_info=field_info,
)
+ if is_path_param:
+ assert is_scalar_field(
+ field=field
+ ), "Path params must be of one of the supported types"
+ elif isinstance(field_info, params.Query):
+ assert is_scalar_field(field) or is_scalar_sequence_field(field)
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
-def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
- if is_path_param:
- assert is_scalar_field(
- field=param_field
- ), "Path params must be of one of the supported types"
- return False
- elif is_scalar_field(field=param_field):
- return False
- elif isinstance(
- param_field.field_info, (params.Query, params.Header)
- ) and is_scalar_sequence_field(param_field):
- return False
- else:
- assert isinstance(
- param_field.field_info, params.Body
- ), f"Param: {param_field.name} can only be a request body, using Body()"
- return True
-
-
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
field_info = field.field_info
field_info_in = getattr(field_info, "in_", None)
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
async_exit_stack: AsyncExitStack,
+ embed_body_fields: bool,
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
+ embed_body_fields=embed_body_fields,
)
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
- required_params=dependant.body_params, received_body=body
+ body_fields=dependant.body_params,
+ received_body=body,
+ embed_body_fields=embed_body_fields,
)
values.update(body_values)
errors.extend(body_errors)
)
+def _validate_value_with_model_field(
+ *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
+) -> Tuple[Any, List[Any]]:
+ if value is None:
+ if field.required:
+ return None, [get_missing_field_error(loc=loc)]
+ else:
+ return deepcopy(field.default), []
+ v_, errors_ = field.validate(value, values, loc=loc)
+ if isinstance(errors_, ErrorWrapper):
+ return None, [errors_]
+ elif isinstance(errors_, list):
+ new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
+ return None, new_errors
+ else:
+ return v_, []
+
+
+def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any:
+ if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
+ value = values.getlist(field.alias)
+ else:
+ value = values.get(field.alias, None)
+ if (
+ value is None
+ or (
+ isinstance(field.field_info, params.Form)
+ and isinstance(value, str) # For type checks
+ and value == ""
+ )
+ or (is_sequence_field(field) and len(value) == 0)
+ ):
+ if field.required:
+ return
+ else:
+ return deepcopy(field.default)
+ return value
+
+
def request_params_to_args(
- required_params: Sequence[ModelField],
+ fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
- values = {}
+ values: Dict[str, Any] = {}
errors = []
- for field in required_params:
- if is_scalar_sequence_field(field) and isinstance(
- received_params, (QueryParams, Headers)
- ):
- value = received_params.getlist(field.alias) or field.default
- else:
- value = received_params.get(field.alias)
+ for field in fields:
+ value = _get_multidict_value(field, received_params)
field_info = field.field_info
assert isinstance(
field_info, params.Param
), "Params must be subclasses of Param"
loc = (field_info.in_.value, field.alias)
- if value is None:
- if field.required:
- errors.append(get_missing_field_error(loc=loc))
- else:
- values[field.name] = deepcopy(field.default)
- continue
- v_, errors_ = field.validate(value, values, loc=loc)
- if isinstance(errors_, ErrorWrapper):
- errors.append(errors_)
- elif isinstance(errors_, list):
- new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
- errors.extend(new_errors)
+ v_, errors_ = _validate_value_with_model_field(
+ field=field, value=value, values=values, loc=loc
+ )
+ if errors_:
+ errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
+def _should_embed_body_fields(fields: List[ModelField]) -> bool:
+ if not fields:
+ return False
+ # More than one dependency could have the same field, it would show up as multiple
+ # fields but it's the same one, so count them by name
+ body_param_names_set = {field.name for field in fields}
+ # A top level field has to be a single field, not multiple
+ if len(body_param_names_set) > 1:
+ return True
+ first_field = fields[0]
+ # If it explicitly specifies it is embedded, it has to be embedded
+ if getattr(first_field.field_info, "embed", None):
+ return True
+ # If it's a Form (or File) field, it has to be a BaseModel to be top level
+ # otherwise it has to be embedded, so that the key value pair can be extracted
+ if isinstance(first_field.field_info, params.Form):
+ return True
+ return False
+
+
+async def _extract_form_body(
+ body_fields: List[ModelField],
+ received_body: FormData,
+) -> Dict[str, Any]:
+ values = {}
+ first_field = body_fields[0]
+ first_field_info = first_field.field_info
+
+ for field in body_fields:
+ value = _get_multidict_value(field, received_body)
+ if (
+ isinstance(first_field_info, params.File)
+ and is_bytes_field(field)
+ and isinstance(value, UploadFile)
+ ):
+ value = await value.read()
+ elif (
+ is_bytes_sequence_field(field)
+ and isinstance(first_field_info, params.File)
+ and value_is_sequence(value)
+ ):
+ # For types
+ assert isinstance(value, sequence_types) # type: ignore[arg-type]
+ results: List[Union[bytes, str]] = []
+
+ async def process_fn(
+ fn: Callable[[], Coroutine[Any, Any, Any]],
+ ) -> None:
+ result = await fn()
+ results.append(result) # noqa: B023
+
+ async with anyio.create_task_group() as tg:
+ for sub_value in value:
+ tg.start_soon(process_fn, sub_value.read)
+ value = serialize_sequence_value(field=field, value=results)
+ values[field.name] = value
+ return values
+
+
async def request_body_to_args(
- required_params: List[ModelField],
+ body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
+ embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
- values = {}
+ values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
- if required_params:
- field = required_params[0]
- field_info = field.field_info
- embed = getattr(field_info, "embed", None)
- field_alias_omitted = len(required_params) == 1 and not embed
- if field_alias_omitted:
- received_body = {field.alias: received_body}
-
- for field in required_params:
- loc: Tuple[str, ...]
- if field_alias_omitted:
- loc = ("body",)
- else:
- loc = ("body", field.alias)
-
- value: Optional[Any] = None
- if received_body is not None:
- if (is_sequence_field(field)) and isinstance(received_body, FormData):
- value = received_body.getlist(field.alias)
- else:
- try:
- value = received_body.get(field.alias)
- except AttributeError:
- errors.append(get_missing_field_error(loc))
- continue
- if (
- value is None
- or (isinstance(field_info, params.Form) and value == "")
- or (
- isinstance(field_info, params.Form)
- and is_sequence_field(field)
- and len(value) == 0
- )
- ):
- if field.required:
- errors.append(get_missing_field_error(loc))
- else:
- values[field.name] = deepcopy(field.default)
+ assert body_fields, "request_body_to_args() should be called with fields"
+ single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
+ first_field = body_fields[0]
+ body_to_process = received_body
+ if isinstance(received_body, FormData):
+ body_to_process = await _extract_form_body(body_fields, received_body)
+
+ if single_not_embedded_field:
+ loc: Tuple[str, ...] = ("body",)
+ v_, errors_ = _validate_value_with_model_field(
+ field=first_field, value=body_to_process, values=values, loc=loc
+ )
+ return {first_field.name: v_}, errors_
+ for field in body_fields:
+ loc = ("body", field.alias)
+ value: Optional[Any] = None
+ if body_to_process is not None:
+ try:
+ value = body_to_process.get(field.alias)
+ # If the received body is a list, not a dict
+ except AttributeError:
+ errors.append(get_missing_field_error(loc))
continue
- if (
- isinstance(field_info, params.File)
- and is_bytes_field(field)
- and isinstance(value, UploadFile)
- ):
- value = await value.read()
- elif (
- is_bytes_sequence_field(field)
- and isinstance(field_info, params.File)
- and value_is_sequence(value)
- ):
- # For types
- assert isinstance(value, sequence_types) # type: ignore[arg-type]
- results: List[Union[bytes, str]] = []
-
- async def process_fn(
- fn: Callable[[], Coroutine[Any, Any, Any]],
- ) -> None:
- result = await fn()
- results.append(result) # noqa: B023
-
- async with anyio.create_task_group() as tg:
- for sub_value in value:
- tg.start_soon(process_fn, sub_value.read)
- value = serialize_sequence_value(field=field, value=results)
-
- v_, errors_ = field.validate(value, values, loc=loc)
-
- if isinstance(errors_, list):
- errors.extend(errors_)
- elif errors_:
- errors.append(errors_)
- else:
- values[field.name] = v_
+ v_, errors_ = _validate_value_with_model_field(
+ field=field, value=value, values=values, loc=loc
+ )
+ if errors_:
+ errors.extend(errors_)
+ else:
+ values[field.name] = v_
return values, errors
-def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
- flat_dependant = get_flat_dependant(dependant)
+def get_body_field(
+ *, flat_dependant: Dependant, name: str, embed_body_fields: bool
+) -> Optional[ModelField]:
+ """
+ Get a ModelField representing the request body for a path operation, combining
+ all body parameters into a single field if necessary.
+
+ Used to check if it's form data (with `isinstance(body_field, params.Form)`)
+ or JSON and to generate the JSON Schema for a request body.
+
+ This is **not** used to validate/parse the request body, that's done with each
+ individual body parameter.
+ """
if not flat_dependant.body_params:
return None
first_param = flat_dependant.body_params[0]
- field_info = first_param.field_info
- embed = getattr(field_info, "embed", None)
- body_param_names_set = {param.name for param in flat_dependant.body_params}
- if len(body_param_names_set) == 1 and not embed:
+ if not embed_body_fields:
return first_param
- # If one field requires to embed, all have to be embedded
- # in case a sub-dependency is evaluated with a single unique body field
- # That is combined (embedded) with other body fields
- for param in flat_dependant.body_params:
- setattr(param.field_info, "embed", True) # noqa: B010
model_name = "Body_" + name
BodyModel = create_body_model(
fields=flat_dependant.body_params, model_name=model_name
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
+ _should_embed_body_fields,
get_body_field,
get_dependant,
+ get_flat_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
solve_dependencies,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None,
+ embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
body=body,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
+ embed_body_fields=embed_body_fields,
)
errors = solved_result.errors
if not errors:
def get_websocket_app(
- dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
+ dependant: Dependant,
+ dependency_overrides_provider: Optional[Any] = None,
+ embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
async with AsyncExitStack() as async_exit_stack:
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
+ embed_body_fields=embed_body_fields,
)
if solved_result.errors:
raise WebSocketRequestValidationError(
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
-
+ self._flat_dependant = get_flat_dependant(self.dependant)
+ self._embed_body_fields = _should_embed_body_fields(
+ self._flat_dependant.body_params
+ )
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
+ embed_body_fields=self._embed_body_fields,
)
)
0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
- self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
+ self._flat_dependant = get_flat_dependant(self.dependant)
+ self._embed_body_fields = _should_embed_body_fields(
+ self._flat_dependant.body_params
+ )
+ self.body_field = get_body_field(
+ flat_dependant=self._flat_dependant,
+ name=self.unique_id,
+ embed_body_fields=self._embed_body_fields,
+ )
self.app = request_response(self.get_route_handler())
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
response_model_exclude_defaults=self.response_model_exclude_defaults,
response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider,
+ embed_body_fields=self._embed_body_fields,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]: