for f in flat_dependant.body_params:
BodyModel.__fields__[f.name] = get_schema_compatible_field(field=f)
required = any(True for f in flat_dependant.body_params if f.required)
+
+ BodySchema_kwargs: Dict[str, Any] = dict(default=None)
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
BodySchema: Type[params.Body] = params.File
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
else:
BodySchema = params.Body
+ body_param_media_types = [
+ getattr(f.schema, "media_type")
+ for f in flat_dependant.body_params
+ if isinstance(f.schema, params.Body)
+ ]
+ if len(set(body_param_media_types)) == 1:
+ BodySchema_kwargs["media_type"] = body_param_media_types[0]
+
field = Field(
name="body",
type_=BodyModel,
model_config=BaseConfig,
class_validators={},
alias="body",
- schema=BodySchema(None),
+ schema=BodySchema(**BodySchema_kwargs),
)
return field
--- /dev/null
+import typing
+
+from fastapi import Body, FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+media_type = "application/vnd.api+json"
+
+# NOTE: These are not valid JSON:API resources
+# but they are fine for testing requestBody with custom media_type
+class Product(BaseModel):
+ name: str
+ price: float
+
+
+class Shop(BaseModel):
+ name: str
+
+
+@app.post("/products")
+async def create_product(data: Product = Body(..., media_type=media_type, embed=True)):
+ pass # pragma: no cover
+
+
+@app.post("/shops")
+async def create_shop(
+ data: Shop = Body(..., media_type=media_type),
+ included: typing.List[Product] = Body([], media_type=media_type),
+):
+ pass # pragma: no cover
+
+
+create_product_request_body = {
+ "content": {
+ "application/vnd.api+json": {
+ "schema": {"$ref": "#/components/schemas/Body_create_product_products_post"}
+ }
+ },
+ "required": True,
+}
+
+create_shop_request_body = {
+ "content": {
+ "application/vnd.api+json": {
+ "schema": {"$ref": "#/components/schemas/Body_create_shop_shops_post"}
+ }
+ },
+ "required": True,
+}
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+ response = client.get("/openapi.json")
+ assert response.status_code == 200
+ openapi_schema = response.json()
+ assert (
+ openapi_schema["paths"]["/products"]["post"]["requestBody"]
+ == create_product_request_body
+ )
+ assert (
+ openapi_schema["paths"]["/shops"]["post"]["requestBody"]
+ == create_shop_request_body
+ )