]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Admit valid types for Pydantic fields as responses models (#1017)
authorPatrick McKenna <patrick.b.mckenna@gmail.com>
Sat, 29 Feb 2020 13:04:35 +0000 (05:04 -0800)
committerGitHub <noreply@github.com>
Sat, 29 Feb 2020 13:04:35 +0000 (14:04 +0100)
fastapi/dependencies/utils.py
fastapi/exceptions.py
fastapi/routing.py
fastapi/utils.py
tests/test_response_model_invalid.py [new file with mode: 0644]
tests/test_response_model_sub_types.py [new file with mode: 0644]

index 33130a90ef33959872528e1d1dc791dcfa72af5a..543479be8814bbc1d226022f3eef8d31d45e02a0 100644 (file)
@@ -27,7 +27,12 @@ from fastapi.dependencies.models import Dependant, SecurityRequirement
 from fastapi.security.base import SecurityBase
 from fastapi.security.oauth2 import OAuth2, SecurityScopes
 from fastapi.security.open_id_connect_url import OpenIdConnect
-from fastapi.utils import PYDANTIC_1, get_field_info, get_path_param_names
+from fastapi.utils import (
+    PYDANTIC_1,
+    create_response_field,
+    get_field_info,
+    get_path_param_names,
+)
 from pydantic import BaseConfig, BaseModel, create_model
 from pydantic.error_wrappers import ErrorWrapper
 from pydantic.errors import MissingError
@@ -362,31 +367,15 @@ def get_param_field(
         alias = param.name.replace("_", "-")
     else:
         alias = field_info.alias or param.name
-    if PYDANTIC_1:
-        field = ModelField(
-            name=param.name,
-            type_=annotation,
-            default=None if required else default_value,
-            alias=alias,
-            required=required,
-            model_config=BaseConfig,
-            class_validators={},
-            field_info=field_info,
-        )
-        # TODO: remove when removing support for Pydantic < 1.2.0
-        field.required = required
-    else:  # pragma: nocover
-        field = ModelField(  # type: ignore
-            name=param.name,
-            type_=annotation,
-            default=None if required else default_value,
-            alias=alias,
-            required=required,
-            model_config=BaseConfig,
-            class_validators={},
-            schema=field_info,
-        )
-        field.required = required
+    field = create_response_field(
+        name=param.name,
+        type_=annotation,
+        default=None if required else default_value,
+        alias=alias,
+        required=required,
+        field_info=field_info,
+    )
+    field.required = required
     if not had_schema and not is_scalar_field(field=field):
         if PYDANTIC_1:
             field.field_info = params.Body(field_info.default)
@@ -694,28 +683,16 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
         use_type: type = bytes
         if field.shape in sequence_shapes:
             use_type = List[bytes]
-        if PYDANTIC_1:
-            out_field = ModelField(
-                name=field.name,
-                type_=use_type,
-                class_validators=field.class_validators,
-                model_config=field.model_config,
-                default=field.default,
-                required=field.required,
-                alias=field.alias,
-                field_info=field.field_info,
-            )
-        else:  # pragma: nocover
-            out_field = ModelField(  # type: ignore
-                name=field.name,
-                type_=use_type,
-                class_validators=field.class_validators,
-                model_config=field.model_config,
-                default=field.default,
-                required=field.required,
-                alias=field.alias,
-                schema=field.schema,  # type: ignore
-            )
+        out_field = create_response_field(
+            name=field.name,
+            type_=use_type,
+            class_validators=field.class_validators,
+            model_config=field.model_config,
+            default=field.default,
+            required=field.required,
+            alias=field.alias,
+            field_info=field.field_info if PYDANTIC_1 else field.schema,  # type: ignore
+        )
 
     return out_field
 
@@ -754,26 +731,10 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
         ]
         if len(set(body_param_media_types)) == 1:
             BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
-    if PYDANTIC_1:
-        field = ModelField(
-            name="body",
-            type_=BodyModel,
-            default=None,
-            required=required,
-            model_config=BaseConfig,
-            class_validators={},
-            alias="body",
-            field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
-        )
-    else:  # pragma: nocover
-        field = ModelField(  # type: ignore
-            name="body",
-            type_=BodyModel,
-            default=None,
-            required=required,
-            model_config=BaseConfig,
-            class_validators={},
-            alias="body",
-            schema=BodyFieldInfo(**BodyFieldInfo_kwargs),
-        )
-    return field
+    return create_response_field(
+        name="body",
+        type_=BodyModel,
+        required=required,
+        alias="body",
+        field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
+    )
index ac002205a3762a4eafbc9f259ff58642ee262e25..be196d0cb5e0b5e9e96c31ab74f7b3260faa05c7 100644 (file)
@@ -20,6 +20,12 @@ RequestErrorModel = create_model("Request")
 WebSocketErrorModel = create_model("WebSocket")
 
 
+class FastAPIError(RuntimeError):
+    """
+    A generic, FastAPI-specific error.
+    """
+
+
 class RequestValidationError(ValidationError):
     def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
         self.body = body
index bbc2b4133599c7f8fd84a327c78c04f0a8e95c10..7c626af410889d8cccda60e08e1bc6682aa9eb62 100644 (file)
@@ -17,13 +17,13 @@ from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
 from fastapi.utils import (
     PYDANTIC_1,
     create_cloned_field,
+    create_response_field,
     generate_operation_id_for_path,
     get_field_info,
     warning_response_model_skip_defaults_deprecated,
 )
-from pydantic import BaseConfig, BaseModel
+from pydantic import BaseModel
 from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.utils import lenient_issubclass
 from starlette import routing
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
@@ -243,26 +243,9 @@ class APIRoute(routing.Route):
                 status_code not in STATUS_CODES_WITH_NO_BODY
             ), f"Status code {status_code} must not have a response body"
             response_name = "Response_" + self.unique_id
-            if PYDANTIC_1:
-                self.response_field: Optional[ModelField] = ModelField(
-                    name=response_name,
-                    type_=self.response_model,
-                    class_validators={},
-                    default=None,
-                    required=False,
-                    model_config=BaseConfig,
-                    field_info=FieldInfo(None),
-                )
-            else:
-                self.response_field: Optional[ModelField] = ModelField(  # type: ignore  # pragma: nocover
-                    name=response_name,
-                    type_=self.response_model,
-                    class_validators={},
-                    default=None,
-                    required=False,
-                    model_config=BaseConfig,
-                    schema=FieldInfo(None),
-                )
+            self.response_field = create_response_field(
+                name=response_name, type_=self.response_model
+            )
             # Create a clone of the field, so that a Pydantic submodel is not returned
             # as is just because it's an instance of a subclass of a more limited class
             # e.g. UserInDB (containing hashed_password) could be a subclass of User
@@ -274,7 +257,7 @@ class APIRoute(routing.Route):
                 ModelField
             ] = create_cloned_field(self.response_field)
         else:
-            self.response_field = None
+            self.response_field = None  # type: ignore
             self.secure_cloned_response_field = None
         self.status_code = status_code
         self.tags = tags or []
@@ -297,30 +280,8 @@ class APIRoute(routing.Route):
                 assert (
                     additional_status_code not in STATUS_CODES_WITH_NO_BODY
                 ), f"Status code {additional_status_code} must not have a response body"
-                assert lenient_issubclass(
-                    model, BaseModel
-                ), "A response model must be a Pydantic model"
                 response_name = f"Response_{additional_status_code}_{self.unique_id}"
-                if PYDANTIC_1:
-                    response_field = ModelField(
-                        name=response_name,
-                        type_=model,
-                        class_validators=None,
-                        default=None,
-                        required=False,
-                        model_config=BaseConfig,
-                        field_info=FieldInfo(None),
-                    )
-                else:
-                    response_field = ModelField(  # type: ignore  # pragma: nocover
-                        name=response_name,
-                        type_=model,
-                        class_validators=None,
-                        default=None,
-                        required=False,
-                        model_config=BaseConfig,
-                        schema=FieldInfo(None),
-                    )
+                response_field = create_response_field(name=response_name, type_=model)
                 response_fields[additional_status_code] = response_field
         if response_fields:
             self.response_fields: Dict[Union[int, str], ModelField] = response_fields
index e7d3891f4bb70e93ed233164dada1e9223d7765d..f24f280739d19be2a872f04b4a6e32093e511d84 100644 (file)
@@ -1,17 +1,20 @@
+import functools
 import re
 from dataclasses import is_dataclass
-from typing import Any, Dict, List, Sequence, Set, Type, cast
+from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
 
+import fastapi
 from fastapi import routing
 from fastapi.logger import logger
 from fastapi.openapi.constants import REF_PREFIX
 from pydantic import BaseConfig, BaseModel, create_model
+from pydantic.class_validators import Validator
 from pydantic.schema import get_flat_models_from_fields, model_process_schema
 from pydantic.utils import lenient_issubclass
 from starlette.routing import BaseRoute
 
 try:
-    from pydantic.fields import FieldInfo, ModelField
+    from pydantic.fields import FieldInfo, ModelField, UndefinedType
 
     PYDANTIC_1 = True
 except ImportError:  # pragma: nocover
@@ -19,6 +22,10 @@ except ImportError:  # pragma: nocover
     from pydantic.fields import Field as ModelField  # type: ignore
     from pydantic import Schema as FieldInfo  # type: ignore
 
+    class UndefinedType:  # type: ignore
+        def __repr__(self) -> str:
+            return "PydanticUndefined"
+
     logger.warning(
         "Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
         "removed soon."
@@ -86,6 +93,44 @@ def get_path_param_names(path: str) -> Set[str]:
     return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
 
 
+def create_response_field(
+    name: str,
+    type_: Type[Any],
+    class_validators: Optional[Dict[str, Validator]] = None,
+    default: Optional[Any] = None,
+    required: Union[bool, UndefinedType] = False,
+    model_config: Type[BaseConfig] = BaseConfig,
+    field_info: Optional[FieldInfo] = None,
+    alias: Optional[str] = None,
+) -> ModelField:
+    """
+    Create a new response field. Raises if type_ is invalid.
+    """
+    class_validators = class_validators or {}
+    field_info = field_info or FieldInfo(None)
+
+    response_field = functools.partial(
+        ModelField,
+        name=name,
+        type_=type_,
+        class_validators=class_validators,
+        default=default,
+        required=required,
+        model_config=model_config,
+        alias=alias,
+    )
+
+    try:
+        if PYDANTIC_1:
+            return response_field(field_info=field_info)
+        else:  # pragma: nocover
+            return response_field(schema=field_info)
+    except RuntimeError:
+        raise fastapi.exceptions.FastAPIError(
+            f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
+        )
+
+
 def create_cloned_field(field: ModelField) -> ModelField:
     original_type = field.type_
     if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
@@ -96,26 +141,8 @@ def create_cloned_field(field: ModelField) -> ModelField:
         use_type = create_model(original_type.__name__, __base__=original_type)
         for f in original_type.__fields__.values():
             use_type.__fields__[f.name] = create_cloned_field(f)
-    if PYDANTIC_1:
-        new_field = ModelField(
-            name=field.name,
-            type_=use_type,
-            class_validators={},
-            default=None,
-            required=False,
-            model_config=BaseConfig,
-            field_info=FieldInfo(None),
-        )
-    else:  # pragma: nocover
-        new_field = ModelField(  # type: ignore
-            name=field.name,
-            type_=use_type,
-            class_validators={},
-            default=None,
-            required=False,
-            model_config=BaseConfig,
-            schema=FieldInfo(None),
-        )
+
+    new_field = create_response_field(name=field.name, type_=use_type)
     new_field.has_alias = field.has_alias
     new_field.alias = field.alias
     new_field.class_validators = field.class_validators
diff --git a/tests/test_response_model_invalid.py b/tests/test_response_model_invalid.py
new file mode 100644 (file)
index 0000000..88b55a4
--- /dev/null
@@ -0,0 +1,45 @@
+from typing import List
+
+import pytest
+from fastapi import FastAPI
+from fastapi.exceptions import FastAPIError
+
+
+class NonPydanticModel:
+    pass
+
+
+def test_invalid_response_model_raises():
+    with pytest.raises(FastAPIError):
+        app = FastAPI()
+
+        @app.get("/", response_model=NonPydanticModel)
+        def read_root():
+            pass  # pragma: nocover
+
+
+def test_invalid_response_model_sub_type_raises():
+    with pytest.raises(FastAPIError):
+        app = FastAPI()
+
+        @app.get("/", response_model=List[NonPydanticModel])
+        def read_root():
+            pass  # pragma: nocover
+
+
+def test_invalid_response_model_in_responses_raises():
+    with pytest.raises(FastAPIError):
+        app = FastAPI()
+
+        @app.get("/", responses={"500": {"model": NonPydanticModel}})
+        def read_root():
+            pass  # pragma: nocover
+
+
+def test_invalid_response_model_sub_type_in_responses_raises():
+    with pytest.raises(FastAPIError):
+        app = FastAPI()
+
+        @app.get("/", responses={"500": {"model": List[NonPydanticModel]}})
+        def read_root():
+            pass  # pragma: nocover
diff --git a/tests/test_response_model_sub_types.py b/tests/test_response_model_sub_types.py
new file mode 100644 (file)
index 0000000..ac12098
--- /dev/null
@@ -0,0 +1,160 @@
+from typing import List
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+
+class Model(BaseModel):
+    name: str
+
+
+app = FastAPI()
+
+
+@app.get("/valid1", responses={"500": {"model": int}})
+def valid1():
+    pass
+
+
+@app.get("/valid2", responses={"500": {"model": List[int]}})
+def valid2():
+    pass
+
+
+@app.get("/valid3", responses={"500": {"model": Model}})
+def valid3():
+    pass
+
+
+@app.get("/valid4", responses={"500": {"model": List[Model]}})
+def valid4():
+    pass
+
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "FastAPI", "version": "0.1.0"},
+    "paths": {
+        "/valid1": {
+            "get": {
+                "summary": "Valid1",
+                "operationId": "valid1_valid1_get",
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "500": {
+                        "description": "Internal Server Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "title": "Response 500 Valid1 Valid1 Get",
+                                    "type": "integer",
+                                }
+                            }
+                        },
+                    },
+                },
+            }
+        },
+        "/valid2": {
+            "get": {
+                "summary": "Valid2",
+                "operationId": "valid2_valid2_get",
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "500": {
+                        "description": "Internal Server Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "title": "Response 500 Valid2 Valid2 Get",
+                                    "type": "array",
+                                    "items": {"type": "integer"},
+                                }
+                            }
+                        },
+                    },
+                },
+            }
+        },
+        "/valid3": {
+            "get": {
+                "summary": "Valid3",
+                "operationId": "valid3_valid3_get",
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "500": {
+                        "description": "Internal Server Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {"$ref": "#/components/schemas/Model"}
+                            }
+                        },
+                    },
+                },
+            }
+        },
+        "/valid4": {
+            "get": {
+                "summary": "Valid4",
+                "operationId": "valid4_valid4_get",
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "500": {
+                        "description": "Internal Server Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "title": "Response 500 Valid4 Valid4 Get",
+                                    "type": "array",
+                                    "items": {"$ref": "#/components/schemas/Model"},
+                                }
+                            }
+                        },
+                    },
+                },
+            }
+        },
+    },
+    "components": {
+        "schemas": {
+            "Model": {
+                "title": "Model",
+                "required": ["name"],
+                "type": "object",
+                "properties": {"name": {"title": "Name", "type": "string"}},
+            }
+        }
+    },
+}
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_path_operations():
+    response = client.get("/valid1")
+    assert response.status_code == 200
+    response = client.get("/valid2")
+    assert response.status_code == 200
+    response = client.get("/valid3")
+    assert response.status_code == 200
+    response = client.get("/valid4")
+    assert response.status_code == 200