-from typing import Any, Callable, Dict, List, Type
+from typing import Any, Callable, Dict, List, Optional, Type
+from pydantic import BaseModel
from starlette.applications import Starlette
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.lifespan import LifespanMiddleware
-from starlette.responses import JSONResponse
-
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
from fastapi import routing
-from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
+from fastapi.openapi.utils import get_openapi
-async def http_exception(request, exc: HTTPException):
- print(exc)
+async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
**extra: Dict[str, Any],
) -> None:
self._debug = debug
- self.router = routing.APIRouter()
+ self.router: routing.APIRouter = routing.APIRouter()
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
if self.swagger_ui_url or self.redoc_url:
assert self.openapi_url, "The openapi_url is required for the docs"
+ self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()
- def setup(self):
+ def openapi(self) -> Dict:
+ if not self.openapi_schema:
+ self.openapi_schema = get_openapi(
+ title=self.title,
+ version=self.version,
+ openapi_version=self.openapi_version,
+ description=self.description,
+ routes=self.routes,
+ )
+ return self.openapi_schema
+
+ def setup(self) -> None:
if self.openapi_url:
self.add_route(
self.openapi_url,
- lambda req: JSONResponse(
- get_openapi(
- title=self.title,
- version=self.version,
- openapi_version=self.openapi_version,
- description=self.description,
- routes=self.routes,
- )
- ),
+ lambda req: JSONResponse(self.openapi()),
include_in_schema=False,
)
if self.swagger_ui_url:
self.add_route(
self.swagger_ui_url,
- lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
+ lambda r: get_swagger_ui_html(
+ openapi_url=self.openapi_url, title=self.title + " - Swagger UI"
+ ),
include_in_schema=False,
)
if self.redoc_url:
self.add_route(
self.redoc_url,
- lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
+ lambda r: get_redoc_html(
+ openapi_url=self.openapi_url, title=self.title + " - ReDoc"
+ ),
include_in_schema=False,
)
self.add_exception_handler(HTTPException, http_exception)
self,
path: str,
endpoint: Callable,
- methods: List[str] = None,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
+ deprecated: bool = None,
+ name: str = None,
+ methods: List[str] = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
) -> None:
self.router.add_api_route(
path,
endpoint=endpoint,
- methods=methods,
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=methods,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def api_route(
self,
path: str,
- methods: List[str] = None,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
+ deprecated: bool = None,
+ name: str = None,
+ methods: List[str] = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
) -> Callable:
def decorator(func: Callable) -> Callable:
self.router.add_api_route(
path,
func,
- methods=methods,
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=methods,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
return func
+
return decorator
-
- def include_router(self, router: "APIRouter", *, prefix=""):
+
+ def include_router(self, router: routing.APIRouter, *, prefix: str = "") -> None:
self.router.include_router(router, prefix=prefix)
def get(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.get(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def put(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.put(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def post(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.post(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def delete(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.delete(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def options(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.options(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def head(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.head(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def patch(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.patch(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def trace(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.router.trace(
- path=path,
- name=name,
- include_in_schema=include_in_schema,
+ path,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
from typing import Any, Callable, Dict, List, Sequence, Tuple
-from starlette.concurrency import run_in_threadpool
-from starlette.requests import Request
-
-from fastapi.security.base import SecurityBase
from pydantic import BaseConfig, Schema
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required
from pydantic.schema import get_annotation_from_schema
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+
+from fastapi.security.base import SecurityBase
param_supported_types = (str, int, float, bool)
import asyncio
import inspect
from copy import deepcopy
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Type
+from pydantic import BaseConfig, Schema, create_model
+from pydantic.error_wrappers import ErrorWrapper
+from pydantic.errors import MissingError
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+from pydantic.utils import lenient_issubclass
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.security.base import SecurityBase
from fastapi.utils import get_path_param_names
-from pydantic import BaseConfig, Schema, create_model
-from pydantic.error_wrappers import ErrorWrapper
-from pydantic.errors import MissingError
-from pydantic.fields import Field, Required
-from pydantic.schema import get_annotation_from_schema
-from pydantic.utils import lenient_issubclass
param_supported_types = (str, int, float, bool)
-def get_sub_dependant(*, param: inspect.Parameter, path: str):
+def get_sub_dependant(*, param: inspect.Parameter, path: str) -> Dependant:
depends: params.Depends = param.default
if depends.dependency:
dependency = depends.dependency
return sub_dependant
-def get_flat_dependant(dependant: Dependant):
+def get_flat_dependant(dependant: Dependant) -> Dependant:
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
return flat_dependant
-def get_dependant(*, path: str, call: Callable, name: str = None):
+def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters
if (
(param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names):
- assert lenient_issubclass(
- param.annotation, param_supported_types
- ) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
+ assert (
+ lenient_issubclass(param.annotation, param_supported_types)
+ or param.annotation == param.empty
+ ), f"Path params must be of type str, int, float or boot: {param}"
param = signature_params[param_name]
add_param_to_fields(
param=param,
*,
param: inspect.Parameter,
dependant: Dependant,
- default_schema=params.Param,
+ default_schema: Type[Schema] = params.Param,
force_type: params.ParamTypes = None,
-):
+) -> None:
default_value = Required
if not param.default == param.empty:
default_value = param.default
else:
schema = default_schema(default_value)
required = default_value == Required
- annotation = Any
+ annotation: Type = Type[Any]
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_schema(annotation, schema)
+ if not schema.alias and getattr(schema, "alias_underscore_to_hyphen", None):
+ alias = param.name.replace("_", "-")
+ else:
+ alias = schema.alias or param.name
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
- alias=schema.alias or param.name,
+ alias=alias,
required=required,
model_config=BaseConfig(),
class_validators=[],
dependant.cookie_params.append(field)
-def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
+def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
default_value = Required
if not param.default == param.empty:
default_value = param.default
dependant.body_params.append(field)
-def is_coroutine_callable(call: Callable = None):
+def is_coroutine_callable(call: Callable = None) -> bool:
if not call:
return False
if inspect.isfunction(call):
async def solve_dependencies(
*, request: Request, dependant: Dependant, body: Dict[str, Any] = None
-):
+) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
for sub_dependant in dependant.dependencies:
)
if sub_errors:
return {}, errors
- if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
+ assert sub_dependant.call is not None, "sub_dependant.call must be a function"
+ if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values)
else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
- values[
- sub_dependant.name
- ] = solved # type: ignore # Sub-dependants always have a name
+ assert sub_dependant.name is not None, "Subdependants always have a name"
+ values[sub_dependant.name] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
def request_params_to_args(
- required_params: List[Field], received_params: Dict[str, Any]
+ required_params: Sequence[Field], received_params: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
else:
values[field.name] = deepcopy(field.default)
continue
- v_, errors_ = field.validate(
- value, values, loc=(field.schema.in_.value, field.alias)
- )
+ schema: params.Param = field.schema
+ assert isinstance(schema, params.Param), "Params must be subclasses of Param"
+ v_, errors_ = field.validate(value, values, loc=(schema.in_.value, field.alias))
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
return values, errors
-def get_body_field(*, dependant: Dependant, name: str):
+def get_body_field(*, dependant: Dependant, name: str) -> Field:
flat_dependant = get_flat_dependant(dependant)
if not flat_dependant.body_params:
return None
BodyModel.__fields__[f.name] = f
required = any(True for f in flat_dependant.body_params if f.required)
if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
- BodySchema = params.File
+ BodySchema: Type[params.Body] = params.File
elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
BodySchema = params.Form
else:
from enum import Enum
from types import GeneratorType
-from typing import Set
+from typing import Any, Set
from pydantic import BaseModel
from pydantic.json import pydantic_encoder
def jsonable_encoder(
- obj,
+ obj: Any,
include: Set[str] = None,
exclude: Set[str] = set(),
by_alias: bool = False,
- include_none=True,
-):
+ include_none: bool = True,
+) -> Any:
if isinstance(obj, BaseModel):
return jsonable_encoder(
obj.dict(include=include, exclude=exclude, by_alias=by_alias),
from starlette.responses import HTMLResponse
-def get_swagger_ui_html(*, openapi_url: str, title: str):
+
+def get_swagger_ui_html(*, openapi_url: str, title: str) -> HTMLResponse:
return HTMLResponse(
"""
<! doctype html>
</script>
</body>
</html>
- """,
- media_type="text/html",
+ """
)
-def get_redoc_html(*, openapi_url: str, title: str):
+def get_redoc_html(*, openapi_url: str, title: str) -> HTMLResponse:
return HTMLResponse(
"""
<!DOCTYPE html>
<script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
</body>
</html>
- """,
- media_type="text/html",
+ """
)
try:
import pydantic.types.EmailStr
- from pydantic.types import EmailStr
+ from pydantic.types import EmailStr # type: ignore
except ImportError:
logging.warning(
"email-validator not installed, email fields will be treated as str"
)
- class EmailStr(str):
+ class EmailStr(str): # type: ignore
pass
class Reference(BaseModel):
- ref: str = PSchema(..., alias="$ref")
+ ref: str = PSchema(..., alias="$ref") # type: ignore
class Discriminator(BaseModel):
class SchemaBase(BaseModel):
- ref: Optional[str] = PSchema(None, alias="$ref")
+ ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
title: Optional[str] = None
multipleOf: Optional[float] = None
maximum: Optional[float] = None
exclusiveMaximum: Optional[float] = None
minimum: Optional[float] = None
exclusiveMinimum: Optional[float] = None
- maxLength: Optional[int] = PSchema(None, gte=0)
- minLength: Optional[int] = PSchema(None, gte=0)
+ maxLength: Optional[int] = PSchema(None, gte=0) # type: ignore
+ minLength: Optional[int] = PSchema(None, gte=0) # type: ignore
pattern: Optional[str] = None
- maxItems: Optional[int] = PSchema(None, gte=0)
- minItems: Optional[int] = PSchema(None, gte=0)
+ maxItems: Optional[int] = PSchema(None, gte=0) # type: ignore
+ minItems: Optional[int] = PSchema(None, gte=0) # type: ignore
uniqueItems: Optional[bool] = None
- maxProperties: Optional[int] = PSchema(None, gte=0)
- minProperties: Optional[int] = PSchema(None, gte=0)
+ maxProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
+ minProperties: Optional[int] = PSchema(None, gte=0) # type: ignore
required: Optional[List[str]] = None
enum: Optional[List[str]] = None
type: Optional[str] = None
allOf: Optional[List[Any]] = None
oneOf: Optional[List[Any]] = None
anyOf: Optional[List[Any]] = None
- not_: Optional[List[Any]] = PSchema(None, alias="not")
+ not_: Optional[List[Any]] = PSchema(None, alias="not") # type: ignore
items: Optional[Any] = None
properties: Optional[Dict[str, Any]] = None
additionalProperties: Optional[Union[bool, Any]] = None
allOf: Optional[List[SchemaBase]] = None
oneOf: Optional[List[SchemaBase]] = None
anyOf: Optional[List[SchemaBase]] = None
- not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
+ not_: Optional[List[SchemaBase]] = PSchema(None, alias="not") # type: ignore
items: Optional[SchemaBase] = None
properties: Optional[Dict[str, SchemaBase]] = None
additionalProperties: Optional[Union[bool, SchemaBase]] = None
class MediaType(BaseModel):
- schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+ schema_: Optional[Union[Schema, Reference]] = PSchema(
+ None, alias="schema"
+ ) # type: ignore
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
- schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+ schema_: Optional[Union[Schema, Reference]] = PSchema(
+ None, alias="schema"
+ ) # type: ignore
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios
class Parameter(ParameterBase):
name: str
- in_: ParameterInType = PSchema(..., alias="in")
+ in_: ParameterInType = PSchema(..., alias="in") # type: ignore
class Header(ParameterBase):
class PathItem(BaseModel):
- ref: Optional[str] = PSchema(None, alias="$ref")
+ ref: Optional[str] = PSchema(None, alias="$ref") # type: ignore
summary: Optional[str] = None
description: Optional[str] = None
get: Optional[Operation] = None
class SecurityBase(BaseModel):
- type_: SecuritySchemeType = PSchema(..., alias="type")
+ type_: SecuritySchemeType = PSchema(..., alias="type") # type: ignore
description: Optional[str] = None
class APIKey(SecurityBase):
- type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
- in_: APIKeyIn = PSchema(..., alias="in")
+ type_ = PSchema(SecuritySchemeType.apiKey, alias="type") # type: ignore
+ in_: APIKeyIn = PSchema(..., alias="in") # type: ignore
name: str
class HTTPBase(SecurityBase):
- type_ = PSchema(SecuritySchemeType.http, alias="type")
+ type_ = PSchema(SecuritySchemeType.http, alias="type") # type: ignore
scheme: str
class OAuth2(SecurityBase):
- type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
+ type_ = PSchema(SecuritySchemeType.oauth2, alias="type") # type: ignore
flows: OAuthFlows
class OpenIdConnect(SecurityBase):
- type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
+ type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type") # type: ignore
openIdConnectUrl: str
-from typing import Any, Dict, Sequence, Type, List
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from pydantic.fields import Field
-from pydantic.schema import field_schema, get_model_name_map
+from pydantic.schema import Schema, field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
-
from starlette.responses import HTMLResponse, JSONResponse
-from starlette.routing import BaseRoute
+from starlette.routing import BaseRoute, Route
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from fastapi import routing
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant
from fastapi.encoders import jsonable_encoder
-from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
+from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
-from fastapi.params import Body
+from fastapi.params import Body, Param
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
-
validation_error_definition = {
"title": "ValidationError",
"type": "object",
}
-def get_openapi_params(dependant: Dependant):
+def get_openapi_params(dependant: Dependant) -> List[Field]:
flat_dependant = get_flat_dependant(dependant)
return (
flat_dependant.path_params
)
-def get_openapi_security_definitions(flat_dependant: Dependant):
+def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
by_alias=True,
include_none=False,
)
- security_name = (
- security_requirement.security_scheme.scheme_name
-
- )
+ security_name = security_requirement.security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
-def get_openapi_operation_parameters(all_route_params: List[Field]):
+def get_openapi_operation_parameters(
+ all_route_params: Sequence[Field]
+) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]:
definitions: Dict[str, Dict] = {}
parameters = []
for param in all_route_params:
+ schema: Param = param.schema
if "ValidationError" not in definitions:
definitions["ValidationError"] = validation_error_definition
definitions["HTTPValidationError"] = validation_error_response_definition
parameter = {
"name": param.alias,
- "in": param.schema.in_.value,
+ "in": schema.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map={})[0],
}
- if param.schema.description:
- parameter["description"] = param.schema.description
- if param.schema.deprecated:
- parameter["deprecated"] = param.schema.deprecated
+ if schema.description:
+ parameter["description"] = schema.description
+ if schema.deprecated:
+ parameter["deprecated"] = schema.deprecated
parameters.append(parameter)
return definitions, parameters
def get_openapi_operation_request_body(
*, body_field: Field, model_name_map: Dict[Type, str]
-):
+) -> Optional[Dict]:
if not body_field:
return None
assert isinstance(body_field, Field)
body_schema, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
- if isinstance(body_field.schema, Body):
- request_media_type = body_field.schema.media_type
+ schema: Schema = body_field.schema
+ if isinstance(schema, Body):
+ request_media_type = schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
required = body_field.required
- request_body_oai = {}
+ request_body_oai: Dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
return request_body_oai
-def generate_operation_id(*, route: routing.APIRoute, method: str):
+def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
if route.operation_id:
return route.operation_id
path: str = route.path
return operation_id
-def generate_operation_summary(*, route: routing.APIRoute, method: str):
+def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary:
return route.summary
return method.title() + " " + route.name.replace("_", " ").title()
-def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
+
+def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
return operation
-def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
- if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
- return None
+def get_openapi_path(
+ *, route: routing.APIRoute, model_name_map: Dict[Type, str]
+) -> Tuple[Dict, Dict, Dict]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
+ assert route.methods is not None, "Methods must be a list"
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = []
)
if request_body_oai:
operation["requestBody"] = request_body_oai
- response_code = str(route.response_code)
+ status_code = str(route.status_code)
response_schema = {"type": "string"}
- if lenient_issubclass(route.response_wrapper, JSONResponse):
- response_media_type = "application/json"
+ if lenient_issubclass(route.content_type, JSONResponse):
if route.response_field:
response_schema, _ = field_schema(
route.response_field,
)
else:
response_schema = {}
- elif lenient_issubclass(route.response_wrapper, HTMLResponse):
- response_media_type = "text/html"
- else:
- response_media_type = "text/plain"
- content = {response_media_type: {"schema": response_schema}}
+ content = {route.content_type.media_type: {"schema": response_schema}}
operation["responses"] = {
- response_code: {
- "description": route.response_description,
- "content": content,
- }
+ status_code: {"description": route.response_description, "content": content}
}
if all_route_params or route.body_field:
operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
openapi_version: str = "3.0.2",
description: str = None,
routes: Sequence[BaseRoute]
-):
+) -> Dict:
info = {"title": title, "version": version}
if description:
info["description"] = description
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
- result = get_openapi_path(route=route, model_name_map=model_name_map)
- if result:
- path, security_schemes, path_definitions = result
- if path:
- paths.setdefault(route.path, {}).update(path)
- if security_schemes:
- components.setdefault("securitySchemes", {}).update(security_schemes)
- if path_definitions:
- definitions.update(path_definitions)
+ if isinstance(route, routing.APIRoute):
+ result = get_openapi_path(route=route, model_name_map=model_name_map)
+ if result:
+ path, security_schemes, path_definitions = result
+ if path:
+ paths.setdefault(route.path, {}).update(path)
+ if security_schemes:
+ components.setdefault("securitySchemes", {}).update(
+ security_schemes
+ )
+ if path_definitions:
+ definitions.update(path_definitions)
if definitions:
components.setdefault("schemas", {}).update(definitions)
if components:
from enum import Enum
-from typing import Sequence, Any, Dict
+from typing import Any, Callable, Sequence
from pydantic import Schema
def __init__(
self,
- default,
+ default: Any,
*,
deprecated: bool = None,
alias: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
self.deprecated = deprecated
super().__init__(
def __init__(
self,
- default,
+ default: Any,
*,
deprecated: bool = None,
alias: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
self.description = description
self.deprecated = deprecated
def __init__(
self,
- default,
+ default: Any,
*,
deprecated: bool = None,
alias: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
self.description = description
self.deprecated = deprecated
def __init__(
self,
- default,
+ default: Any,
*,
deprecated: bool = None,
alias: str = None,
+ alias_underscore_to_hyphen: bool = True,
title: str = None,
description: str = None,
gt: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
self.description = description
self.deprecated = deprecated
+ self.alias_underscore_to_hyphen = alias_underscore_to_hyphen
super().__init__(
default,
alias=alias,
def __init__(
self,
- default,
+ default: Any,
*,
deprecated: bool = None,
alias: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
self.description = description
self.deprecated = deprecated
class Body(Schema):
def __init__(
self,
- default,
+ default: Any,
*,
- embed=False,
+ embed: bool = False,
media_type: str = "application/json",
alias: str = None,
title: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
self.embed = embed
self.media_type = media_type
class Form(Body):
def __init__(
self,
- default,
+ default: Any,
*,
- sub_key=False,
+ sub_key: bool = False,
media_type: str = "application/x-www-form-urlencoded",
alias: str = None,
title: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
super().__init__(
default,
class File(Form):
def __init__(
self,
- default,
+ default: Any,
*,
- sub_key=False,
+ sub_key: bool = False,
media_type: str = "multipart/form-data",
alias: str = None,
title: str = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
- **extra: Dict[str, Any],
+ **extra: Any,
):
super().__init__(
default,
class Depends:
- def __init__(self, dependency=None):
+ def __init__(self, dependency: Callable = None):
self.dependency = dependency
class Security(Depends):
- def __init__(self, dependency=None, scopes: Sequence[str] = None):
+ def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
self.scopes = scopes or []
super().__init__(dependency=dependency)
import asyncio
import inspect
-from typing import Callable, List, Type
+from typing import Any, Callable, List, Optional, Type
from pydantic import BaseConfig, BaseModel, Schema
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import Field
from pydantic.utils import lenient_issubclass
-
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from fastapi.encoders import jsonable_encoder
-def serialize_response(*, field: Field = None, response):
+def serialize_response(*, field: Field = None, response: Response) -> Any:
if field:
errors = []
value, errors_ = field.validate(response, {}, loc=("response",))
def get_app(
dependant: Dependant,
body_field: Field = None,
- response_code: str = 200,
- response_wrapper: Type[Response] = JSONResponse,
- response_field: Type[Field] = None,
-):
- is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
+ status_code: int = 200,
+ content_type: Type[Response] = JSONResponse,
+ response_field: Field = None,
+) -> Callable:
+ assert dependant.call is not None, "dependant.call must me a function"
+ is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(body_field.schema, params.Form)
async def app(request: Request) -> Response:
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
)
else:
+ assert dependant.call is not None, "dependant.call must me a function"
if is_coroutine:
raw_response = await dependant.call(**values)
else:
if isinstance(raw_response, Response):
return raw_response
if isinstance(raw_response, BaseModel):
- return response_wrapper(
- content=jsonable_encoder(raw_response), status_code=response_code
+ return content_type(
+ content=jsonable_encoder(raw_response), status_code=status_code
)
errors = []
try:
- return response_wrapper(
+ return content_type(
content=serialize_response(
field=response_field, response=raw_response
),
- status_code=response_code,
+ status_code=status_code,
)
except Exception as e:
errors.append(e)
try:
response = dict(raw_response)
- return response_wrapper(
+ return content_type(
content=serialize_response(field=response_field, response=response),
- status_code=response_code,
+ status_code=status_code,
)
except Exception as e:
errors.append(e)
try:
response = vars(raw_response)
- return response_wrapper(
+ return content_type(
content=serialize_response(field=response_field, response=response),
- status_code=response_code,
+ status_code=status_code,
)
except Exception as e:
errors.append(e)
path: str,
endpoint: Callable,
*,
- methods: List[str] = None,
- name: str = None,
- include_in_schema: bool = True,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
+ deprecated: bool = None,
+ name: str = None,
+ methods: List[str] = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
) -> None:
assert path.startswith("/"), "Routed paths must always start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
- self.include_in_schema = include_in_schema
- self.tags = tags or []
- self.summary = summary
- self.description = description or self.endpoint.__doc__
- self.operation_id = operation_id
- self.deprecated = deprecated
- self.body_field: Field = None
- self.response_description = response_description
- self.response_code = response_code
- self.response_wrapper = response_wrapper
- self.response_field = None
- if response_type:
+ self.response_model = response_model
+ if self.response_model:
assert lenient_issubclass(
- response_wrapper, JSONResponse
+ content_type, JSONResponse
), "To declare a type the response must be a JSON response"
- self.response_type = response_type
response_name = "Response_" + self.name
- self.response_field = Field(
+ self.response_field: Optional[Field] = Field(
name=response_name,
- type_=self.response_type,
+ type_=self.response_model,
class_validators=[],
default=None,
required=False,
schema=Schema(None),
)
else:
- self.response_type = None
+ self.response_field = None
+ self.status_code = status_code
+ self.tags = tags or []
+ self.summary = summary
+ self.description = description or self.endpoint.__doc__
+ self.response_description = response_description
+ self.deprecated = deprecated
if methods is None:
methods = ["GET"]
self.methods = methods
+ self.operation_id = operation_id
+ self.include_in_schema = include_in_schema
+ self.content_type = content_type
+
self.path_regex, self.path_format, self.param_convertors = self.compile_path(
path
)
assert inspect.isfunction(endpoint) or inspect.ismethod(
endpoint
), f"An endpoint must be a function or method"
-
self.dependant = get_dependant(path=path, call=self.endpoint)
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
self.app = request_response(
get_app(
dependant=self.dependant,
body_field=self.body_field,
- response_code=self.response_code,
- response_wrapper=self.response_wrapper,
+ status_code=self.status_code,
+ content_type=self.content_type,
response_field=self.response_field,
)
)
self,
path: str,
endpoint: Callable,
- methods: List[str] = None,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
+ deprecated: bool = None,
+ name: str = None,
+ methods: List[str] = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
) -> None:
route = APIRoute(
path,
endpoint=endpoint,
- methods=methods,
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=methods,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
self.routes.append(route)
def api_route(
self,
path: str,
- methods: List[str] = None,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
+ deprecated: bool = None,
+ name: str = None,
+ methods: List[str] = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_route(
path,
func,
- methods=methods,
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=methods,
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
return func
return decorator
- def include_router(self, router: "APIRouter", *, prefix=""):
+ def include_router(self, router: "APIRouter", *, prefix: str = "") -> None:
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
self.add_api_route(
prefix + route.path,
route.endpoint,
- methods=route.methods,
- name=route.name,
- include_in_schema=route.include_in_schema,
- tags=route.tags,
+ response_model=route.response_model,
+ status_code=route.status_code,
+ tags=route.tags or [],
summary=route.summary,
description=route.description,
- operation_id=route.operation_id,
- deprecated=route.deprecated,
- response_type=route.response_type,
response_description=route.response_description,
- response_code=route.response_code,
- response_wrapper=route.response_wrapper,
+ deprecated=route.deprecated,
+ name=route.name,
+ methods=route.methods,
+ operation_id=route.operation_id,
+ include_in_schema=route.include_in_schema,
+ content_type=route.content_type,
)
elif isinstance(route, routing.Route):
self.add_route(
def get(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["GET"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["GET"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def put(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["PUT"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["PUT"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def post(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["POST"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["POST"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def delete(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["DELETE"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["DELETE"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def options(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["OPTIONS"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["OPTIONS"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def head(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["HEAD"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["HEAD"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def patch(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["PATCH"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["PATCH"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
def trace(
self,
path: str,
- name: str = None,
- include_in_schema: bool = True,
+ *,
+ response_model: Type[BaseModel] = None,
+ status_code: int = 200,
tags: List[str] = None,
summary: str = None,
description: str = None,
- operation_id: str = None,
- deprecated: bool = None,
- response_type: Type = None,
response_description: str = "Successful Response",
- response_code=200,
- response_wrapper=JSONResponse,
- ):
+ deprecated: bool = None,
+ name: str = None,
+ operation_id: str = None,
+ include_in_schema: bool = True,
+ content_type: Type[Response] = JSONResponse,
+ ) -> Callable:
return self.api_route(
path=path,
- methods=["TRACE"],
- name=name,
- include_in_schema=include_in_schema,
+ response_model=response_model,
+ status_code=status_code,
tags=tags or [],
summary=summary,
description=description,
- operation_id=operation_id,
- deprecated=deprecated,
- response_type=response_type,
response_description=response_description,
- response_code=response_code,
- response_wrapper=response_wrapper,
+ deprecated=deprecated,
+ name=name,
+ methods=["TRACE"],
+ operation_id=operation_id,
+ include_in_schema=include_in_schema,
+ content_type=content_type,
)
from starlette.requests import Request
-from .base import SecurityBase
-from fastapi.openapi.models import APIKeyIn, APIKey
+from fastapi.openapi.models import APIKey, APIKeyIn
+from fastapi.security.base import SecurityBase
+
class APIKeyBase(SecurityBase):
pass
-class APIKeyQuery(APIKeyBase):
+class APIKeyQuery(APIKeyBase):
def __init__(self, *, name: str, scheme_name: str = None):
self.model = APIKey(in_=APIKeyIn.query, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
- async def __call__(self, requests: Request):
+ async def __call__(self, requests: Request) -> str:
return requests.query_params.get(self.model.name)
self.model = APIKey(in_=APIKeyIn.header, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
- async def __call__(self, requests: Request):
+ async def __call__(self, requests: Request) -> str:
return requests.headers.get(self.model.name)
self.model = APIKey(in_=APIKeyIn.cookie, name=name)
self.scheme_name = scheme_name or self.__class__.__name__
- async def __call__(self, requests: Request):
+ async def __call__(self, requests: Request) -> str:
return requests.cookies.get(self.model.name)
-from pydantic import BaseModel
-
from fastapi.openapi.models import SecurityBase as SecurityBaseModel
+
class SecurityBase:
- pass
+ model: SecurityBaseModel
+ scheme_name: str
from starlette.requests import Request
-from .base import SecurityBase
-from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
+from fastapi.openapi.models import (
+ HTTPBase as HTTPBaseModel,
+ HTTPBearer as HTTPBearerModel,
+)
+from fastapi.security.base import SecurityBase
class HTTPBase(SecurityBase):
self.model = HTTPBaseModel(scheme=scheme)
self.scheme_name = scheme_name or self.__class__.__name__
- async def __call__(self, request: Request):
+ async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="basic")
self.scheme_name = scheme_name or self.__class__.__name__
-
- async def __call__(self, request: Request):
+
+ async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
self.model = HTTPBearerModel(bearerFormat=bearerFormat)
self.scheme_name = scheme_name or self.__class__.__name__
-
- async def __call__(self, request: Request):
+
+ async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
def __init__(self, *, scheme_name: str = None):
self.model = HTTPBaseModel(scheme="digest")
self.scheme_name = scheme_name or self.__class__.__name__
-
- async def __call__(self, request: Request):
+
+ async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
from starlette.requests import Request
-from .base import SecurityBase
from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
+from fastapi.security.base import SecurityBase
class OAuth2(SecurityBase):
- def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
+ def __init__(
+ self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None
+ ):
self.model = OAuth2Model(flows=flows)
self.scheme_name = scheme_name or self.__class__.__name__
-
- async def __call__(self, request: Request):
+
+ async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
from starlette.requests import Request
-from .base import SecurityBase
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
+from fastapi.security.base import SecurityBase
class OpenIdConnect(SecurityBase):
def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
self.scheme_name = scheme_name or self.__class__.__name__
-
- async def __call__(self, request: Request):
+
+ async def __call__(self, request: Request) -> str:
return request.headers.get("Authorization")
import re
-from typing import Dict, Sequence, Set, Type
+from typing import Any, Dict, List, Sequence, Set, Type
+from pydantic import BaseModel
+from pydantic.fields import Field
+from pydantic.schema import get_flat_models_from_fields, model_process_schema
from starlette.routing import BaseRoute
from fastapi import routing
from fastapi.openapi.constants import REF_PREFIX
-from pydantic import BaseModel
-from pydantic.fields import Field
-from pydantic.schema import get_flat_models_from_fields, model_process_schema
-def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
- body_fields_from_routes = []
- responses_from_routes = []
+def get_flat_models_from_routes(
+ routes: Sequence[Type[BaseRoute]]
+) -> Set[Type[BaseModel]]:
+ body_fields_from_routes: List[Field] = []
+ responses_from_routes: List[Field] = []
for route in routes:
- if route.include_in_schema and isinstance(route, routing.APIRoute):
+ if getattr(route, "include_in_schema", None) and isinstance(
+ route, routing.APIRoute
+ ):
if route.body_field:
assert isinstance(
route.body_field, Field
def get_model_definitions(
*, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
-):
+) -> Dict[str, Any]:
definitions: Dict[str, Dict] = {}
for model in flat_models:
m_schema, m_definitions = model_process_schema(
return definitions
-def get_path_param_names(path: str):
+def get_path_param_names(path: str) -> Set[str]:
return {item.strip("{}") for item in re.findall("{[^}]*}", path)}