]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: use media_type from Body params for OpenAPI requestBody (Fixes: #431) (#439)
authorZoltan Papp <divums@users.noreply.github.com>
Fri, 30 Aug 2019 22:32:39 +0000 (01:32 +0300)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 30 Aug 2019 22:32:39 +0000 (17:32 -0500)
fastapi/dependencies/utils.py
tests/test_request_body_parameters_media_type.py [new file with mode: 0644]

index f9e42d0a82f922c6c0c397a1386d9dc7a21228b4..7f0f59092233ef7a890f01bed4aad8abb889bbfb 100644 (file)
@@ -559,6 +559,8 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
     for f in flat_dependant.body_params:
         BodyModel.__fields__[f.name] = get_schema_compatible_field(field=f)
     required = any(True for f in flat_dependant.body_params if f.required)
+
+    BodySchema_kwargs: Dict[str, Any] = dict(default=None)
     if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
         BodySchema: Type[params.Body] = params.File
     elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
@@ -566,6 +568,14 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
     else:
         BodySchema = params.Body
 
+        body_param_media_types = [
+            getattr(f.schema, "media_type")
+            for f in flat_dependant.body_params
+            if isinstance(f.schema, params.Body)
+        ]
+        if len(set(body_param_media_types)) == 1:
+            BodySchema_kwargs["media_type"] = body_param_media_types[0]
+
     field = Field(
         name="body",
         type_=BodyModel,
@@ -574,6 +584,6 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
         model_config=BaseConfig,
         class_validators={},
         alias="body",
-        schema=BodySchema(None),
+        schema=BodySchema(**BodySchema_kwargs),
     )
     return field
diff --git a/tests/test_request_body_parameters_media_type.py b/tests/test_request_body_parameters_media_type.py
new file mode 100644 (file)
index 0000000..89b98b2
--- /dev/null
@@ -0,0 +1,67 @@
+import typing
+
+from fastapi import Body, FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+media_type = "application/vnd.api+json"
+
+# NOTE: These are not valid JSON:API resources
+# but they are fine for testing requestBody with custom media_type
+class Product(BaseModel):
+    name: str
+    price: float
+
+
+class Shop(BaseModel):
+    name: str
+
+
+@app.post("/products")
+async def create_product(data: Product = Body(..., media_type=media_type, embed=True)):
+    pass  # pragma: no cover
+
+
+@app.post("/shops")
+async def create_shop(
+    data: Shop = Body(..., media_type=media_type),
+    included: typing.List[Product] = Body([], media_type=media_type),
+):
+    pass  # pragma: no cover
+
+
+create_product_request_body = {
+    "content": {
+        "application/vnd.api+json": {
+            "schema": {"$ref": "#/components/schemas/Body_create_product_products_post"}
+        }
+    },
+    "required": True,
+}
+
+create_shop_request_body = {
+    "content": {
+        "application/vnd.api+json": {
+            "schema": {"$ref": "#/components/schemas/Body_create_shop_shops_post"}
+        }
+    },
+    "required": True,
+}
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    openapi_schema = response.json()
+    assert (
+        openapi_schema["paths"]["/products"]["post"]["requestBody"]
+        == create_product_request_body
+    )
+    assert (
+        openapi_schema["paths"]["/shops"]["post"]["requestBody"]
+        == create_shop_request_body
+    )