import http.client
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
+from enum import Enum
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from fastapi import routing
from fastapi.dependencies.models import Dependant
-from fastapi.dependencies.utils import get_flat_dependant
+from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
from fastapi.utils import (
generate_operation_id_for_path,
get_field_info,
- get_flat_models_from_routes,
get_model_definitions,
)
from pydantic import BaseModel
-from pydantic.schema import field_schema, get_model_name_map
+from pydantic.schema import (
+ field_schema,
+ get_flat_models_from_fields,
+ get_model_name_map,
+)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
}
-def get_openapi_params(dependant: Dependant) -> List[ModelField]:
- flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
- return (
- flat_dependant.path_params
- + flat_dependant.query_params
- + flat_dependant.header_params
- + flat_dependant.cookie_params
- )
-
-
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {}
operation_security = []
def get_openapi_operation_parameters(
+ *,
all_route_params: Sequence[ModelField],
+ model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = get_field_info(param)
field_info = cast(Param, field_info)
+ # ignore mypy error until enum schemas are released
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
- "schema": field_schema(param, model_name_map={})[0],
+ "schema": field_schema(
+ param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
+ )[0],
}
if field_info.description:
parameter["description"] = field_info.description
def get_openapi_operation_request_body(
- *, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
+ *,
+ body_field: Optional[ModelField],
+ model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
) -> Optional[Dict]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
+ # ignore mypy error until enum schemas are released
body_schema, _, _ = field_schema(
- body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+ body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
field_info = cast(Body, get_field_info(body_field))
request_media_type = field_info.media_type
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
- all_route_params = get_openapi_params(route.dependant)
- operation_parameters = get_openapi_operation_parameters(all_route_params)
+ all_route_params = get_flat_params(route.dependant)
+ operation_parameters = get_openapi_operation_parameters(
+ all_route_params=all_route_params, model_name_map=model_name_map
+ )
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = list(
return path, security_schemes, definitions
+def get_flat_models_from_routes(
+ routes: Sequence[BaseRoute],
+) -> Set[Union[Type[BaseModel], Type[Enum]]]:
+ body_fields_from_routes: List[ModelField] = []
+ responses_from_routes: List[ModelField] = []
+ request_fields_from_routes: List[ModelField] = []
+ callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
+ for route in routes:
+ if getattr(route, "include_in_schema", None) and isinstance(
+ route, routing.APIRoute
+ ):
+ if route.body_field:
+ assert isinstance(
+ route.body_field, ModelField
+ ), "A request body must be a Pydantic Field"
+ body_fields_from_routes.append(route.body_field)
+ if route.response_field:
+ responses_from_routes.append(route.response_field)
+ if route.response_fields:
+ responses_from_routes.extend(route.response_fields.values())
+ if route.callbacks:
+ callback_flat_models |= get_flat_models_from_routes(route.callbacks)
+ params = get_flat_params(route.dependant)
+ request_fields_from_routes.extend(params)
+
+ flat_models = callback_flat_models | get_flat_models_from_fields(
+ body_fields_from_routes + responses_from_routes + request_fields_from_routes,
+ known_models=set(),
+ )
+ return flat_models
+
+
def get_openapi(
*,
title: str,
components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {}
flat_models = get_flat_models_from_routes(routes)
- model_name_map = get_model_name_map(flat_models)
+ # ignore mypy error until enum schemas are released
+ model_name_map = get_model_name_map(flat_models) # type: ignore
+ # ignore mypy error until enum schemas are released
definitions = get_model_definitions(
- flat_models=flat_models, model_name_map=model_name_map
+ flat_models=flat_models, model_name_map=model_name_map # type: ignore
)
for route in routes:
if isinstance(route, routing.APIRoute):
import functools
import re
from dataclasses import is_dataclass
-from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
+from enum import Enum
+from typing import Any, Dict, Optional, Set, Type, Union, cast
import fastapi
-from fastapi import routing
from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
-from pydantic.schema import get_flat_models_from_fields, model_process_schema
+from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass
-from starlette.routing import BaseRoute
try:
from pydantic.fields import FieldInfo, ModelField, UndefinedType
)
-def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
- body_fields_from_routes: List[ModelField] = []
- responses_from_routes: List[ModelField] = []
- callback_flat_models: Set[Type[BaseModel]] = set()
- for route in routes:
- if getattr(route, "include_in_schema", None) and isinstance(
- route, routing.APIRoute
- ):
- if route.body_field:
- assert isinstance(
- route.body_field, ModelField
- ), "A request body must be a Pydantic Field"
- body_fields_from_routes.append(route.body_field)
- if route.response_field:
- responses_from_routes.append(route.response_field)
- if route.response_fields:
- responses_from_routes.extend(route.response_fields.values())
- if route.callbacks:
- callback_flat_models |= get_flat_models_from_routes(route.callbacks)
- flat_models = callback_flat_models | get_flat_models_from_fields(
- body_fields_from_routes + responses_from_routes, known_models=set()
- )
- return flat_models
-
-
def get_model_definitions(
- *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
+ *,
+ flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
+ model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {}
for model in flat_models:
+ # ignore mypy error until enum schemas are released
m_schema, m_definitions, m_nested_models = model_process_schema(
- model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+ model, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
definitions.update(m_definitions)
model_name = model_name_map[model]