]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix OpenAPI generation when using callbacks with routers including Pydantic models...
authorNik <sidnev.nick@gmail.com>
Fri, 12 Jun 2020 20:35:59 +0000 (23:35 +0300)
committerGitHub <noreply@github.com>
Fri, 12 Jun 2020 20:35:59 +0000 (22:35 +0200)
* drop model class from additional responses when generating openapi

* ♻️ Copy response to be mutated early in get_openapi_path

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/openapi/utils.py
tests/test_additional_responses_custom_model_in_callback.py [new file with mode: 0644]

index b5778327bbad182db019f684dd5ecb0d67d4fd57..bb2e7dff74475e9180fab4914b2562f3ec19ee3c 100644 (file)
@@ -203,27 +203,31 @@ def get_openapi_path(
                 operation["callbacks"] = callbacks
             if route.responses:
                 for (additional_status_code, response) in route.responses.items():
+                    process_response = response.copy()
                     assert isinstance(
-                        response, dict
+                        process_response, dict
                     ), "An additional response must be a dict"
                     field = route.response_fields.get(additional_status_code)
                     if field:
                         response_schema, _, _ = field_schema(
                             field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                         )
-                        response.setdefault("content", {}).setdefault(
+                        process_response.setdefault("content", {}).setdefault(
                             route_response_media_type or "application/json", {}
                         )["schema"] = response_schema
                     status_text: Optional[str] = status_code_ranges.get(
                         str(additional_status_code).upper()
                     ) or http.client.responses.get(int(additional_status_code))
-                    response.setdefault(
+                    process_response.setdefault(
                         "description", status_text or "Additional Response"
                     )
                     status_code_key = str(additional_status_code).upper()
                     if status_code_key == "DEFAULT":
                         status_code_key = "default"
-                    operation.setdefault("responses", {})[status_code_key] = response
+                    process_response.pop("model", None)
+                    operation.setdefault("responses", {})[
+                        status_code_key
+                    ] = process_response
             status_code = str(route.status_code)
             operation.setdefault("responses", {}).setdefault(status_code, {})[
                 "description"
diff --git a/tests/test_additional_responses_custom_model_in_callback.py b/tests/test_additional_responses_custom_model_in_callback.py
new file mode 100644 (file)
index 0000000..36dd0d6
--- /dev/null
@@ -0,0 +1,138 @@
+from fastapi import APIRouter, FastAPI
+from fastapi.testclient import TestClient
+from pydantic import BaseModel, HttpUrl
+from starlette.responses import JSONResponse
+
+
+class CustomModel(BaseModel):
+    a: int
+
+
+app = FastAPI()
+
+callback_router = APIRouter(default_response_class=JSONResponse)
+
+
+@callback_router.get(
+    "{$callback_url}/callback/", responses={400: {"model": CustomModel}}
+)
+def callback_route():
+    pass  # pragma: no cover
+
+
+@app.post("/", callbacks=callback_router.routes)
+def main_route(callback_url: HttpUrl):
+    pass  # pragma: no cover
+
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "FastAPI", "version": "0.1.0"},
+    "paths": {
+        "/": {
+            "post": {
+                "summary": "Main Route",
+                "operationId": "main_route__post",
+                "parameters": [
+                    {
+                        "required": True,
+                        "schema": {
+                            "title": "Callback Url",
+                            "maxLength": 2083,
+                            "minLength": 1,
+                            "type": "string",
+                            "format": "uri",
+                        },
+                        "name": "callback_url",
+                        "in": "query",
+                    }
+                ],
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "callbacks": {
+                    "callback_route": {
+                        "{$callback_url}/callback/": {
+                            "get": {
+                                "summary": "Callback Route",
+                                "operationId": "callback_route__callback_url__callback__get",
+                                "responses": {
+                                    "400": {
+                                        "content": {
+                                            "application/json": {
+                                                "schema": {
+                                                    "$ref": "#/components/schemas/CustomModel"
+                                                }
+                                            }
+                                        },
+                                        "description": "Bad Request",
+                                    },
+                                    "200": {
+                                        "description": "Successful Response",
+                                        "content": {"application/json": {"schema": {}}},
+                                    },
+                                },
+                            }
+                        }
+                    }
+                },
+            }
+        }
+    },
+    "components": {
+        "schemas": {
+            "CustomModel": {
+                "title": "CustomModel",
+                "required": ["a"],
+                "type": "object",
+                "properties": {"a": {"title": "A", "type": "integer"}},
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+        }
+    },
+}
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200, response.text
+    assert response.json() == openapi_schema