]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
Additional Responses implementation
authorMohammed <barsintod@gmail.com>
Fri, 22 Mar 2019 19:40:07 +0000 (22:40 +0300)
committerMohammed <barsintod@gmail.com>
Fri, 22 Mar 2019 19:40:07 +0000 (22:40 +0300)
fastapi/applications.py
fastapi/openapi/models.py
fastapi/openapi/utils.py
fastapi/routing.py
fastapi/utils.py

index 6f54df70671ad8e90537ee5ac89918c5aba6aaa0..7c29141ec082bf4fe0a522de14a25a6d810c8aeb 100644 (file)
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
 from fastapi import routing
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
 from fastapi.openapi.utils import get_openapi
+from fastapi.openapi.models import AdditionalResponse, AdditionalResponseDescription
 from pydantic import BaseModel
 from starlette.applications import Starlette
 from starlette.exceptions import ExceptionMiddleware, HTTPException
@@ -104,22 +105,23 @@ class FastAPI(Starlette):
         self.add_exception_handler(HTTPException, http_exception)
 
     def add_api_route(
-        self,
-        path: str,
-        endpoint: Callable,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        methods: List[str] = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            endpoint: Callable,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            methods: List[str] = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> None:
         self.router.add_api_route(
             path,
@@ -130,6 +132,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=methods,
             operation_id=operation_id,
@@ -139,21 +142,22 @@ class FastAPI(Starlette):
         )
 
     def api_route(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        methods: List[str] = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            methods: List[str] = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         def decorator(func: Callable) -> Callable:
             self.router.add_api_route(
@@ -165,6 +169,7 @@ class FastAPI(Starlette):
                 summary=summary,
                 description=description,
                 response_description=response_description,
+                additional_responses=additional_responses,
                 deprecated=deprecated,
                 methods=methods,
                 operation_id=operation_id,
@@ -177,25 +182,31 @@ class FastAPI(Starlette):
         return decorator
 
     def include_router(
-        self, router: routing.APIRouter, *, prefix: str = "", tags: List[str] = None
+            self,
+            router: routing.APIRouter,
+            *,
+            prefix: str = "",
+            tags: List[str] = None,
+            additional_responses: AdditionalResponse = [],
     ) -> None:
-        self.router.include_router(router, prefix=prefix, tags=tags)
+        self.router.include_router(router, prefix=prefix, tags=tags, additional_responses=additional_responses,)
 
     def get(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.get(
             path,
@@ -205,6 +216,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -213,20 +225,21 @@ class FastAPI(Starlette):
         )
 
     def put(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.put(
             path,
@@ -236,6 +249,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -244,20 +258,21 @@ class FastAPI(Starlette):
         )
 
     def post(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.post(
             path,
@@ -267,6 +282,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -275,20 +291,21 @@ class FastAPI(Starlette):
         )
 
     def delete(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.delete(
             path,
@@ -298,6 +315,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -306,20 +324,21 @@ class FastAPI(Starlette):
         )
 
     def options(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.options(
             path,
@@ -329,6 +348,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -337,20 +357,21 @@ class FastAPI(Starlette):
         )
 
     def head(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.head(
             path,
@@ -360,6 +381,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -368,20 +390,21 @@ class FastAPI(Starlette):
         )
 
     def patch(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.patch(
             path,
@@ -391,6 +414,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
@@ -399,20 +423,21 @@ class FastAPI(Starlette):
         )
 
     def trace(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.router.trace(
             path,
@@ -422,6 +447,7 @@ class FastAPI(Starlette):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             operation_id=operation_id,
             include_in_schema=include_in_schema,
index 6572c7c07239939e414561ff6e1859410d3f123f..e94e087fa7ee62f5fd05bbc97133ae5d4964e38f 100644 (file)
@@ -1,9 +1,10 @@
 import logging
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Type, ClassVar, Callable
 
 from pydantic import BaseModel, Schema as PSchema
 from pydantic.types import UrlStr
+from pydantic.fields import Field
 
 try:
     import email_validator
@@ -343,6 +344,35 @@ class Tag(BaseModel):
     externalDocs: Optional[ExternalDocumentation] = None
 
 
+class BaseAdditionalResponse(BaseModel):
+    description: str
+    content_type: str = None
+
+
+class AdditionalResponse(BaseAdditionalResponse):
+    status_code: int = PSchema(
+        ...,
+        ge=100,
+        le=540,
+        title='Status Code',
+        description='HTTP status code',
+    )
+    # NOTE: waiting for pydantic to allow `typing.Type[BasicModel]` type
+    # so, going for `Any` and extra validation on
+    # routing methods
+    models: Optional[List[Any]] = PSchema(
+        [],
+        title='Additional Response Models',
+    )
+
+
+class AdditionalResponseDescription(BaseAdditionalResponse):
+    schema_field: Optional[Field] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
 class OpenAPI(BaseModel):
     openapi: str
     info: Info
index 4a603aa857731bc7b85f9d57dae1e8252c331f48..024fb50c8dee648dcf7d2d6b6b9c5bb589e3a327 100644 (file)
@@ -205,6 +205,26 @@ def get_openapi_path(
                         }
                     },
                 }
+            for add_response_code, add_response in route.additional_responses.items():
+                add_response_schema = {}
+                if (add_response.content_type or route.content_type.media_type
+                    ) == 'application/json' and add_response.schema_field is not None:
+                    add_response_schema, _ = field_schema(
+                        add_response.schema_field,
+                        model_name_map=model_name_map,
+                        ref_prefix=REF_PREFIX,
+                    )
+                add_content = {
+                    add_response.content_type or
+                    route.content_type.media_type: {
+                        "schema": add_response_schema,
+                    },
+                }
+                operation["responses"][str(add_response_code)] = \
+                    {
+                        "description": add_response.description,
+                        "content": add_content,
+                    }
             path[method.lower()] = operation
     return path, security_schemes, definitions
 
index 6d252d817cdb14a2d98da1aa979caeed7df089a3..b59d1eb95eb5847aab2c4a97cc0807855bf3b0b5 100644 (file)
@@ -1,13 +1,14 @@
 import asyncio
 import inspect
 import logging
-from typing import Any, Callable, List, Optional, Type
+from typing import Any, Callable, List, Optional, Type, Dict, Union
 
 from fastapi import params
 from fastapi.dependencies.models import Dependant
 from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
 from fastapi.encoders import jsonable_encoder
 from fastapi.utils import UnconstrainedConfig
+from fastapi.openapi.models import AdditionalResponse, AdditionalResponseDescription
 from pydantic import BaseModel, Schema
 from pydantic.error_wrappers import ErrorWrapper, ValidationError
 from pydantic.fields import Field
@@ -104,6 +105,7 @@ class APIRoute(routing.Route):
         summary: str = None,
         description: str = None,
         response_description: str = "Successful Response",
+        additional_responses: AdditionalResponse = [],
         deprecated: bool = None,
         name: str = None,
         methods: List[str] = None,
@@ -137,6 +139,56 @@ class APIRoute(routing.Route):
         self.summary = summary
         self.description = description or self.endpoint.__doc__
         self.response_description = response_description
+        self.additional_responses: Dict[int, AdditionalResponseDescription] = {}
+        existed_codes = [self.status_code, 422]
+        if isinstance(additional_responses, dict):
+            self.additional_responses = additional_responses.copy()
+        for add_response in additional_responses:
+            if isinstance(add_response, int):
+                continue
+            assert add_response.status_code not in existed_codes, f"(Duplicated Status Code): Response with status code [{add_response.status_code}] already defined!"
+            existed_codes.append(add_response.status_code)
+            response_models = [
+                m for m in\
+                    add_response.models
+            ]
+            valid_response_models = True
+            try:
+                valid_response_models = all([
+                    issubclass(m, BaseModel)
+                        for m in response_models
+                ])
+            except TypeError as te:
+                valid_response_models = False
+            if not valid_response_models:
+                raise ValueError(
+                    "All response models must be "
+                    "a subclass of `pydantic.BaseModel` "
+                    "model.",
+                )
+            if (add_response.content_type == 'application/json' or lenient_issubclass(
+                    content_type, JSONResponse)):
+                if len(response_models):
+                    schema_field = Field(
+                        name=f'Additional_response_{add_response.status_code}',
+                        type_=Union[tuple(response_models)],
+                        class_validators=[],
+                        default=None,
+                        required=False,
+                        model_config=UnconstrainedConfig,
+                        schema=Schema(None),
+                    )
+                else:
+                    schema_field = None
+            else:
+                schema_field = None
+            add_resp_description = AdditionalResponseDescription(
+                description=add_response.description,
+                content_type=add_response.content_type,
+                schema_field=schema_field,
+                )
+            self.additional_responses[add_response.status_code] = \
+                add_resp_description
         self.deprecated = deprecated
         if methods is None:
             methods = ["GET"]
@@ -164,22 +216,23 @@ class APIRoute(routing.Route):
 
 class APIRouter(routing.Router):
     def add_api_route(
-        self,
-        path: str,
-        endpoint: Callable,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        methods: List[str] = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            endpoint: Callable,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            methods: List[str] = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> None:
         route = APIRoute(
             path,
@@ -190,6 +243,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=methods,
             operation_id=operation_id,
@@ -200,21 +254,22 @@ class APIRouter(routing.Router):
         self.routes.append(route)
 
     def api_route(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        methods: List[str] = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            methods: List[str] = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         def decorator(func: Callable) -> Callable:
             self.add_api_route(
@@ -226,6 +281,7 @@ class APIRouter(routing.Router):
                 summary=summary,
                 description=description,
                 response_description=response_description,
+                additional_responses=additional_responses,
                 deprecated=deprecated,
                 methods=methods,
                 operation_id=operation_id,
@@ -238,7 +294,12 @@ class APIRouter(routing.Router):
         return decorator
 
     def include_router(
-        self, router: "APIRouter", *, prefix: str = "", tags: List[str] = None
+            self,
+            router: "APIRouter",
+            *,
+            prefix: str = "",
+            tags: List[str] = None,
+            additional_responses: AdditionalResponse = [],
     ) -> None:
         if prefix:
             assert prefix.startswith("/"), "A path prefix must start with '/'"
@@ -247,6 +308,53 @@ class APIRouter(routing.Router):
             ), "A path prefix must not end with '/', as the routes will start with '/'"
         for route in router.routes:
             if isinstance(route, APIRoute):
+                # really ugly hack and repitition
+                prev_add_resp = route.additional_responses
+                existed_codes = [422, route.status_code
+                                 ] + [int(c) for c in prev_add_resp.keys()]
+                for add_response in additional_responses:
+                    assert add_response.status_code not in existed_codes, f"(Duplicated Status Code): Response with status code [{add_response.status_code}] already defined!"
+                    existed_codes.append(add_response.status_code)
+                    response_models = [
+                        m for m in\
+                            add_response.models
+                    ]
+                    valid_response_models = True
+                    try:
+                        valid_response_models = all([
+                            issubclass(m, BaseModel) for m in response_models
+                        ])
+                    except AttributeError as ae:
+                        valid_response_models = False
+                    if not valid_response_models:
+                        raise ValueError(
+                            "All response models must be"
+                            "a subclass of `pydantic.BaseModel`"
+                            "model."
+                        )
+                    if (add_response.content_type == 'application/json' or lenient_issubclass(
+                    route.content_type, JSONResponse)):
+                        if len(response_models):
+                            schema_field = Field(
+                                name=f'Additional_response_{add_response.status_code}',
+                                type_=Union[tuple(response_models)],
+                                class_validators=[],
+                                default=None,
+                                required=False,
+                                model_config=UnconstrainedConfig,
+                                schema=Schema(None),
+                            )
+                        else:
+                            schema_field = None
+                    else:
+                        schema_field = None
+                    add_resp_description = AdditionalResponseDescription(
+                        description=add_response.description,
+                        content_type=add_response.content_type,
+                        schema_field=schema_field,
+                        )
+                    route.additional_responses[add_response.status_code] = \
+                        add_resp_description
                 self.add_api_route(
                     prefix + route.path,
                     route.endpoint,
@@ -256,6 +364,7 @@ class APIRouter(routing.Router):
                     summary=route.summary,
                     description=route.description,
                     response_description=route.response_description,
+                    additional_responses=route.additional_responses,
                     deprecated=route.deprecated,
                     methods=route.methods,
                     operation_id=route.operation_id,
@@ -273,20 +382,21 @@ class APIRouter(routing.Router):
                 )
 
     def get(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -296,6 +406,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["GET"],
             operation_id=operation_id,
@@ -305,20 +416,21 @@ class APIRouter(routing.Router):
         )
 
     def put(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -328,6 +440,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["PUT"],
             operation_id=operation_id,
@@ -337,20 +450,21 @@ class APIRouter(routing.Router):
         )
 
     def post(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -360,6 +474,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["POST"],
             operation_id=operation_id,
@@ -369,20 +484,21 @@ class APIRouter(routing.Router):
         )
 
     def delete(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -392,6 +508,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["DELETE"],
             operation_id=operation_id,
@@ -401,20 +518,21 @@ class APIRouter(routing.Router):
         )
 
     def options(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -424,6 +542,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["OPTIONS"],
             operation_id=operation_id,
@@ -433,20 +552,21 @@ class APIRouter(routing.Router):
         )
 
     def head(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -456,6 +576,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["HEAD"],
             operation_id=operation_id,
@@ -465,20 +586,21 @@ class APIRouter(routing.Router):
         )
 
     def patch(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -488,6 +610,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["PATCH"],
             operation_id=operation_id,
@@ -497,20 +620,21 @@ class APIRouter(routing.Router):
         )
 
     def trace(
-        self,
-        path: str,
-        *,
-        response_model: Type[BaseModel] = None,
-        status_code: int = 200,
-        tags: List[str] = None,
-        summary: str = None,
-        description: str = None,
-        response_description: str = "Successful Response",
-        deprecated: bool = None,
-        operation_id: str = None,
-        include_in_schema: bool = True,
-        content_type: Type[Response] = JSONResponse,
-        name: str = None,
+            self,
+            path: str,
+            *,
+            response_model: Type[BaseModel] = None,
+            status_code: int = 200,
+            tags: List[str] = None,
+            summary: str = None,
+            description: str = None,
+            response_description: str = "Successful Response",
+            additional_responses: AdditionalResponse = [],
+            deprecated: bool = None,
+            operation_id: str = None,
+            include_in_schema: bool = True,
+            content_type: Type[Response] = JSONResponse,
+            name: str = None,
     ) -> Callable:
         return self.api_route(
             path=path,
@@ -520,6 +644,7 @@ class APIRouter(routing.Router):
             summary=summary,
             description=description,
             response_description=response_description,
+            additional_responses=additional_responses,
             deprecated=deprecated,
             methods=["TRACE"],
             operation_id=operation_id,
index 395ec9f5bcf907b47abe32c50e09c33534856aaf..aa25cfabe4f3c216dc85de604406b7d92df5b49c 100644 (file)
@@ -30,6 +30,12 @@ def get_flat_models_from_routes(
                 body_fields_from_routes.append(route.body_field)
             if route.response_field:
                 responses_from_routes.append(route.response_field)
+            if route.additional_responses:
+                for _, add_response in route.additional_responses.items():
+                    if add_response.schema_field is not None:
+                        responses_from_routes.append(
+                            add_response.schema_field,
+                        )
     flat_models = get_flat_models_from_fields(
         body_fields_from_routes + responses_from_routes
     )