from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
-from fastapi.utils import PYDANTIC_1, get_field_info, get_path_param_names
+from fastapi.utils import (
+ PYDANTIC_1,
+ create_response_field,
+ get_field_info,
+ get_path_param_names,
+)
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
alias = param.name.replace("_", "-")
else:
alias = field_info.alias or param.name
- if PYDANTIC_1:
- field = ModelField(
- name=param.name,
- type_=annotation,
- default=None if required else default_value,
- alias=alias,
- required=required,
- model_config=BaseConfig,
- class_validators={},
- field_info=field_info,
- )
- # TODO: remove when removing support for Pydantic < 1.2.0
- field.required = required
- else: # pragma: nocover
- field = ModelField( # type: ignore
- name=param.name,
- type_=annotation,
- default=None if required else default_value,
- alias=alias,
- required=required,
- model_config=BaseConfig,
- class_validators={},
- schema=field_info,
- )
- field.required = required
+ field = create_response_field(
+ name=param.name,
+ type_=annotation,
+ default=None if required else default_value,
+ alias=alias,
+ required=required,
+ field_info=field_info,
+ )
+ field.required = required
if not had_schema and not is_scalar_field(field=field):
if PYDANTIC_1:
field.field_info = params.Body(field_info.default)
use_type: type = bytes
if field.shape in sequence_shapes:
use_type = List[bytes]
- if PYDANTIC_1:
- out_field = ModelField(
- name=field.name,
- type_=use_type,
- class_validators=field.class_validators,
- model_config=field.model_config,
- default=field.default,
- required=field.required,
- alias=field.alias,
- field_info=field.field_info,
- )
- else: # pragma: nocover
- out_field = ModelField( # type: ignore
- name=field.name,
- type_=use_type,
- class_validators=field.class_validators,
- model_config=field.model_config,
- default=field.default,
- required=field.required,
- alias=field.alias,
- schema=field.schema, # type: ignore
- )
+ out_field = create_response_field(
+ name=field.name,
+ type_=use_type,
+ class_validators=field.class_validators,
+ model_config=field.model_config,
+ default=field.default,
+ required=field.required,
+ alias=field.alias,
+ field_info=field.field_info if PYDANTIC_1 else field.schema, # type: ignore
+ )
return out_field
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
- if PYDANTIC_1:
- field = ModelField(
- name="body",
- type_=BodyModel,
- default=None,
- required=required,
- model_config=BaseConfig,
- class_validators={},
- alias="body",
- field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
- )
- else: # pragma: nocover
- field = ModelField( # type: ignore
- name="body",
- type_=BodyModel,
- default=None,
- required=required,
- model_config=BaseConfig,
- class_validators={},
- alias="body",
- schema=BodyFieldInfo(**BodyFieldInfo_kwargs),
- )
- return field
+ return create_response_field(
+ name="body",
+ type_=BodyModel,
+ required=required,
+ alias="body",
+ field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
+ )
WebSocketErrorModel = create_model("WebSocket")
+class FastAPIError(RuntimeError):
+ """
+ A generic, FastAPI-specific error.
+ """
+
+
class RequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
self.body = body
from fastapi.utils import (
PYDANTIC_1,
create_cloned_field,
+ create_response_field,
generate_operation_id_for_path,
get_field_info,
warning_response_model_skip_defaults_deprecated,
)
-from pydantic import BaseConfig, BaseModel
+from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.utils import lenient_issubclass
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {status_code} must not have a response body"
response_name = "Response_" + self.unique_id
- if PYDANTIC_1:
- self.response_field: Optional[ModelField] = ModelField(
- name=response_name,
- type_=self.response_model,
- class_validators={},
- default=None,
- required=False,
- model_config=BaseConfig,
- field_info=FieldInfo(None),
- )
- else:
- self.response_field: Optional[ModelField] = ModelField( # type: ignore # pragma: nocover
- name=response_name,
- type_=self.response_model,
- class_validators={},
- default=None,
- required=False,
- model_config=BaseConfig,
- schema=FieldInfo(None),
- )
+ self.response_field = create_response_field(
+ name=response_name, type_=self.response_model
+ )
# Create a clone of the field, so that a Pydantic submodel is not returned
# as is just because it's an instance of a subclass of a more limited class
# e.g. UserInDB (containing hashed_password) could be a subclass of User
ModelField
] = create_cloned_field(self.response_field)
else:
- self.response_field = None
+ self.response_field = None # type: ignore
self.secure_cloned_response_field = None
self.status_code = status_code
self.tags = tags or []
assert (
additional_status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {additional_status_code} must not have a response body"
- assert lenient_issubclass(
- model, BaseModel
- ), "A response model must be a Pydantic model"
response_name = f"Response_{additional_status_code}_{self.unique_id}"
- if PYDANTIC_1:
- response_field = ModelField(
- name=response_name,
- type_=model,
- class_validators=None,
- default=None,
- required=False,
- model_config=BaseConfig,
- field_info=FieldInfo(None),
- )
- else:
- response_field = ModelField( # type: ignore # pragma: nocover
- name=response_name,
- type_=model,
- class_validators=None,
- default=None,
- required=False,
- model_config=BaseConfig,
- schema=FieldInfo(None),
- )
+ response_field = create_response_field(name=response_name, type_=model)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
+import functools
import re
from dataclasses import is_dataclass
-from typing import Any, Dict, List, Sequence, Set, Type, cast
+from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
+import fastapi
from fastapi import routing
from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
+from pydantic.class_validators import Validator
from pydantic.schema import get_flat_models_from_fields, model_process_schema
from pydantic.utils import lenient_issubclass
from starlette.routing import BaseRoute
try:
- from pydantic.fields import FieldInfo, ModelField
+ from pydantic.fields import FieldInfo, ModelField, UndefinedType
PYDANTIC_1 = True
except ImportError: # pragma: nocover
from pydantic.fields import Field as ModelField # type: ignore
from pydantic import Schema as FieldInfo # type: ignore
+ class UndefinedType: # type: ignore
+ def __repr__(self) -> str:
+ return "PydanticUndefined"
+
logger.warning(
"Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
"removed soon."
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
+def create_response_field(
+ name: str,
+ type_: Type[Any],
+ class_validators: Optional[Dict[str, Validator]] = None,
+ default: Optional[Any] = None,
+ required: Union[bool, UndefinedType] = False,
+ model_config: Type[BaseConfig] = BaseConfig,
+ field_info: Optional[FieldInfo] = None,
+ alias: Optional[str] = None,
+) -> ModelField:
+ """
+ Create a new response field. Raises if type_ is invalid.
+ """
+ class_validators = class_validators or {}
+ field_info = field_info or FieldInfo(None)
+
+ response_field = functools.partial(
+ ModelField,
+ name=name,
+ type_=type_,
+ class_validators=class_validators,
+ default=default,
+ required=required,
+ model_config=model_config,
+ alias=alias,
+ )
+
+ try:
+ if PYDANTIC_1:
+ return response_field(field_info=field_info)
+ else: # pragma: nocover
+ return response_field(schema=field_info)
+ except RuntimeError:
+ raise fastapi.exceptions.FastAPIError(
+ f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
+ )
+
+
def create_cloned_field(field: ModelField) -> ModelField:
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
use_type = create_model(original_type.__name__, __base__=original_type)
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(f)
- if PYDANTIC_1:
- new_field = ModelField(
- name=field.name,
- type_=use_type,
- class_validators={},
- default=None,
- required=False,
- model_config=BaseConfig,
- field_info=FieldInfo(None),
- )
- else: # pragma: nocover
- new_field = ModelField( # type: ignore
- name=field.name,
- type_=use_type,
- class_validators={},
- default=None,
- required=False,
- model_config=BaseConfig,
- schema=FieldInfo(None),
- )
+
+ new_field = create_response_field(name=field.name, type_=use_type)
new_field.has_alias = field.has_alias
new_field.alias = field.alias
new_field.class_validators = field.class_validators
--- /dev/null
+from typing import List
+
+import pytest
+from fastapi import FastAPI
+from fastapi.exceptions import FastAPIError
+
+
+class NonPydanticModel:
+ pass
+
+
+def test_invalid_response_model_raises():
+ with pytest.raises(FastAPIError):
+ app = FastAPI()
+
+ @app.get("/", response_model=NonPydanticModel)
+ def read_root():
+ pass # pragma: nocover
+
+
+def test_invalid_response_model_sub_type_raises():
+ with pytest.raises(FastAPIError):
+ app = FastAPI()
+
+ @app.get("/", response_model=List[NonPydanticModel])
+ def read_root():
+ pass # pragma: nocover
+
+
+def test_invalid_response_model_in_responses_raises():
+ with pytest.raises(FastAPIError):
+ app = FastAPI()
+
+ @app.get("/", responses={"500": {"model": NonPydanticModel}})
+ def read_root():
+ pass # pragma: nocover
+
+
+def test_invalid_response_model_sub_type_in_responses_raises():
+ with pytest.raises(FastAPIError):
+ app = FastAPI()
+
+ @app.get("/", responses={"500": {"model": List[NonPydanticModel]}})
+ def read_root():
+ pass # pragma: nocover
--- /dev/null
+from typing import List
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+
+class Model(BaseModel):
+ name: str
+
+
+app = FastAPI()
+
+
+@app.get("/valid1", responses={"500": {"model": int}})
+def valid1():
+ pass
+
+
+@app.get("/valid2", responses={"500": {"model": List[int]}})
+def valid2():
+ pass
+
+
+@app.get("/valid3", responses={"500": {"model": Model}})
+def valid3():
+ pass
+
+
+@app.get("/valid4", responses={"500": {"model": List[Model]}})
+def valid4():
+ pass
+
+
+openapi_schema = {
+ "openapi": "3.0.2",
+ "info": {"title": "FastAPI", "version": "0.1.0"},
+ "paths": {
+ "/valid1": {
+ "get": {
+ "summary": "Valid1",
+ "operationId": "valid1_valid1_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "title": "Response 500 Valid1 Valid1 Get",
+ "type": "integer",
+ }
+ }
+ },
+ },
+ },
+ }
+ },
+ "/valid2": {
+ "get": {
+ "summary": "Valid2",
+ "operationId": "valid2_valid2_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "title": "Response 500 Valid2 Valid2 Get",
+ "type": "array",
+ "items": {"type": "integer"},
+ }
+ }
+ },
+ },
+ },
+ }
+ },
+ "/valid3": {
+ "get": {
+ "summary": "Valid3",
+ "operationId": "valid3_valid3_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "content": {
+ "application/json": {
+ "schema": {"$ref": "#/components/schemas/Model"}
+ }
+ },
+ },
+ },
+ }
+ },
+ "/valid4": {
+ "get": {
+ "summary": "Valid4",
+ "operationId": "valid4_valid4_get",
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "title": "Response 500 Valid4 Valid4 Get",
+ "type": "array",
+ "items": {"$ref": "#/components/schemas/Model"},
+ }
+ }
+ },
+ },
+ },
+ }
+ },
+ },
+ "components": {
+ "schemas": {
+ "Model": {
+ "title": "Model",
+ "required": ["name"],
+ "type": "object",
+ "properties": {"name": {"title": "Name", "type": "string"}},
+ }
+ }
+ },
+}
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+ response = client.get("/openapi.json")
+ assert response.status_code == 200
+ assert response.json() == openapi_schema
+
+
+def test_path_operations():
+ response = client.get("/valid1")
+ assert response.status_code == 200
+ response = client.get("/valid2")
+ assert response.status_code == 200
+ response = client.get("/valid3")
+ assert response.status_code == 200
+ response = client.get("/valid4")
+ assert response.status_code == 200