]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Allow using custom 422 validation error and use media type from response...
authorZoltan Papp <divums@users.noreply.github.com>
Fri, 30 Aug 2019 21:46:05 +0000 (00:46 +0300)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 30 Aug 2019 21:46:05 +0000 (16:46 -0500)
* media_type of additional responses from the response_class

* Use HTTPValidationError only if a custom one is not defined (Fixes: #429)

fastapi/openapi/utils.py
tests/test_additional_responses_custom_validationerror.py [new file with mode: 0644]
tests/test_additional_responses_default_validationerror.py [new file with mode: 0644]
tests/test_additional_responses_response_class.py [new file with mode: 0644]

index c3cc120fd6eca2fd453e391285ed4c40995bfc88..6c987a29fa9155e74d9c007a51bc0a4df8d27ab4 100644 (file)
@@ -80,15 +80,11 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L
 
 def get_openapi_operation_parameters(
     all_route_params: Sequence[Field]
-) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]:
-    definitions: Dict[str, Dict] = {}
+) -> List[Dict[str, Any]]:
     parameters = []
     for param in all_route_params:
         schema = param.schema
         schema = cast(Param, schema)
-        if "ValidationError" not in definitions:
-            definitions["ValidationError"] = validation_error_definition
-            definitions["HTTPValidationError"] = validation_error_response_definition
         parameter = {
             "name": param.alias,
             "in": schema.in_.value,
@@ -100,7 +96,7 @@ def get_openapi_operation_parameters(
         if schema.deprecated:
             parameter["deprecated"] = schema.deprecated
         parameters.append(parameter)
-    return definitions, parameters
+    return parameters
 
 
 def get_openapi_operation_request_body(
@@ -168,10 +164,7 @@ def get_openapi_path(
             if security_definitions:
                 security_schemes.update(security_definitions)
             all_route_params = get_openapi_params(route.dependant)
-            validation_definitions, operation_parameters = get_openapi_operation_parameters(
-                all_route_params=all_route_params
-            )
-            definitions.update(validation_definitions)
+            operation_parameters = get_openapi_operation_parameters(all_route_params)
             parameters.extend(operation_parameters)
             if parameters:
                 operation["parameters"] = parameters
@@ -181,11 +174,6 @@ def get_openapi_path(
                 )
                 if request_body_oai:
                     operation["requestBody"] = request_body_oai
-                    if "ValidationError" not in definitions:
-                        definitions["ValidationError"] = validation_error_definition
-                        definitions[
-                            "HTTPValidationError"
-                        ] = validation_error_response_definition
             if route.responses:
                 for (additional_status_code, response) in route.responses.items():
                     assert isinstance(
@@ -197,7 +185,7 @@ def get_openapi_path(
                             field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                         )
                         response.setdefault("content", {}).setdefault(
-                            "application/json", {}
+                            route.response_class.media_type, {}
                         )["schema"] = response_schema
                     status_text: Optional[str] = status_code_ranges.get(
                         str(additional_status_code).upper()
@@ -228,8 +216,15 @@ def get_openapi_path(
             ).setdefault("content", {}).setdefault(route.response_class.media_type, {})[
                 "schema"
             ] = response_schema
-            if all_route_params or route.body_field:
-                operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
+
+            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
+            if (all_route_params or route.body_field) and not any(
+                [
+                    status in operation["responses"]
+                    for status in [http422, "4xx", "default"]
+                ]
+            ):
+                operation["responses"][http422] = {
                     "description": "Validation Error",
                     "content": {
                         "application/json": {
@@ -237,6 +232,13 @@ def get_openapi_path(
                         }
                     },
                 }
+                if "ValidationError" not in definitions:
+                    definitions.update(
+                        {
+                            "ValidationError": validation_error_definition,
+                            "HTTPValidationError": validation_error_response_definition,
+                        }
+                    )
             path[method.lower()] = operation
     return path, security_schemes, definitions
 
diff --git a/tests/test_additional_responses_custom_validationerror.py b/tests/test_additional_responses_custom_validationerror.py
new file mode 100644 (file)
index 0000000..37982ee
--- /dev/null
@@ -0,0 +1,100 @@
+import typing
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.responses import JSONResponse
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class JsonApiResponse(JSONResponse):
+    media_type = "application/vnd.api+json"
+
+
+class Error(BaseModel):
+    status: str
+    title: str
+
+
+class JsonApiError(BaseModel):
+    errors: typing.List[Error]
+
+
+@app.get(
+    "/a/{id}",
+    response_class=JsonApiResponse,
+    responses={422: {"description": "Error", "model": JsonApiError}},
+)
+async def a(id):
+    pass  # pragma: no cover
+
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/a/{id}": {
+            "get": {
+                "responses": {
+                    "422": {
+                        "description": "Error",
+                        "content": {
+                            "application/vnd.api+json": {
+                                "schema": {"$ref": "#/components/schemas/JsonApiError"}
+                            }
+                        },
+                    },
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/vnd.api+json": {"schema": {}}},
+                    },
+                },
+                "summary": "A",
+                "operationId": "a_a__id__get",
+                "parameters": [
+                    {
+                        "required": True,
+                        "schema": {"title": "Id"},
+                        "name": "id",
+                        "in": "path",
+                    }
+                ],
+            }
+        }
+    },
+    "components": {
+        "schemas": {
+            "Error": {
+                "title": "Error",
+                "required": ["status", "title"],
+                "type": "object",
+                "properties": {
+                    "status": {"title": "Status", "type": "string"},
+                    "title": {"title": "Title", "type": "string"},
+                },
+            },
+            "JsonApiError": {
+                "title": "JsonApiError",
+                "required": ["errors"],
+                "type": "object",
+                "properties": {
+                    "errors": {
+                        "title": "Errors",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/Error"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
diff --git a/tests/test_additional_responses_default_validationerror.py b/tests/test_additional_responses_default_validationerror.py
new file mode 100644 (file)
index 0000000..ac22bf5
--- /dev/null
@@ -0,0 +1,85 @@
+from fastapi import FastAPI
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+@app.get("/a/{id}")
+async def a(id):
+    pass  # pragma: no cover
+
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/a/{id}": {
+            "get": {
+                "responses": {
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                },
+                "summary": "A",
+                "operationId": "a_a__id__get",
+                "parameters": [
+                    {
+                        "required": True,
+                        "schema": {"title": "Id"},
+                        "name": "id",
+                        "in": "path",
+                    }
+                ],
+            }
+        }
+    },
+    "components": {
+        "schemas": {
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
diff --git a/tests/test_additional_responses_response_class.py b/tests/test_additional_responses_response_class.py
new file mode 100644 (file)
index 0000000..81c28e3
--- /dev/null
@@ -0,0 +1,117 @@
+import typing
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.responses import JSONResponse
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class JsonApiResponse(JSONResponse):
+    media_type = "application/vnd.api+json"
+
+
+class Error(BaseModel):
+    status: str
+    title: str
+
+
+class JsonApiError(BaseModel):
+    errors: typing.List[Error]
+
+
+@app.get(
+    "/a",
+    response_class=JsonApiResponse,
+    responses={500: {"description": "Error", "model": JsonApiError}},
+)
+async def a():
+    pass  # pragma: no cover
+
+
+@app.get("/b", responses={500: {"description": "Error", "model": Error}})
+async def b():
+    pass  # pragma: no cover
+
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/a": {
+            "get": {
+                "responses": {
+                    "500": {
+                        "description": "Error",
+                        "content": {
+                            "application/vnd.api+json": {
+                                "schema": {"$ref": "#/components/schemas/JsonApiError"}
+                            }
+                        },
+                    },
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/vnd.api+json": {"schema": {}}},
+                    },
+                },
+                "summary": "A",
+                "operationId": "a_a_get",
+            }
+        },
+        "/b": {
+            "get": {
+                "responses": {
+                    "500": {
+                        "description": "Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {"$ref": "#/components/schemas/Error"}
+                            }
+                        },
+                    },
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                },
+                "summary": "B",
+                "operationId": "b_b_get",
+            }
+        },
+    },
+    "components": {
+        "schemas": {
+            "Error": {
+                "title": "Error",
+                "required": ["status", "title"],
+                "type": "object",
+                "properties": {
+                    "status": {"title": "Status", "type": "string"},
+                    "title": {"title": "Title", "type": "string"},
+                },
+            },
+            "JsonApiError": {
+                "title": "JsonApiError",
+                "required": ["errors"],
+                "type": "object",
+                "properties": {
+                    "errors": {
+                        "title": "Errors",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/Error"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema