]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Update parameter names and order
authorSebastián Ramírez <tiangolo@gmail.com>
Fri, 7 Dec 2018 15:12:16 +0000 (19:12 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 7 Dec 2018 15:12:16 +0000 (19:12 +0400)
fix mypy types, refactor, lint

15 files changed:
fastapi/applications.py
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/encoders.py
fastapi/openapi/docs.py
fastapi/openapi/models.py
fastapi/openapi/utils.py
fastapi/params.py
fastapi/routing.py
fastapi/security/api_key.py
fastapi/security/base.py
fastapi/security/http.py
fastapi/security/oauth2.py
fastapi/security/open_id_connect_url.py
fastapi/utils.py

index bb21076dfd91b03bf35f65ed78afbe3e302d7f18..f1d40522186c5fa87f5ad8cbbd70ce98230780c8 100644 (file)
@@ -1,19 +1,19 @@
-from typing import Any, Callable, Dict, List, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
+from pydantic import BaseModel
 from starlette.applications import Starlette
 from starlette.exceptions import ExceptionMiddleware, HTTPException
 from starlette.middleware.errors import ServerErrorMiddleware
 from starlette.middleware.lifespan import LifespanMiddleware
-from starlette.responses import JSONResponse
-
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
 
 from fastapi import routing
-from fastapi.openapi.utils import get_openapi
 from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
+from fastapi.openapi.utils import get_openapi
 
 
-async def http_exception(request, exc: HTTPException):
-    print(exc)
+async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
     return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
 
 
@@ -31,7 +31,7 @@ class FastAPI(Starlette):
         **extra: Dict[str, Any],
     ) -> None:
         self._debug = debug
-        self.router = routing.APIRouter()
+        self.router: routing.APIRouter = routing.APIRouter()
         self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
         self.error_middleware = ServerErrorMiddleware(
             self.exception_middleware, debug=debug
@@ -56,33 +56,41 @@ class FastAPI(Starlette):
 
         if self.swagger_ui_url or self.redoc_url:
             assert self.openapi_url, "The openapi_url is required for the docs"
+        self.openapi_schema: Optional[Dict[str, Any]] = None
         self.setup()
 
-    def setup(self):
+    def openapi(self) -> Dict:
+        if not self.openapi_schema:
+            self.openapi_schema = get_openapi(
+                title=self.title,
+                version=self.version,
+                openapi_version=self.openapi_version,
+                description=self.description,
+                routes=self.routes,
+            )
+        return self.openapi_schema
+
+    def setup(self) -> None:
         if self.openapi_url:
             self.add_route(
                 self.openapi_url,
-                lambda req: JSONResponse(
-                    get_openapi(
-                        title=self.title,
-                        version=self.version,
-                        openapi_version=self.openapi_version,
-                        description=self.description,
-                        routes=self.routes,
-                    )
-                ),
+                lambda req: JSONResponse(self.openapi()),
                 include_in_schema=False,
             )
         if self.swagger_ui_url:
             self.add_route(
                 self.swagger_ui_url,
-                lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
+                lambda r: get_swagger_ui_html(
+                    openapi_url=self.openapi_url, title=self.title + " - Swagger UI"
+                ),
                 include_in_schema=False,
             )
         if self.redoc_url:
             self.add_route(
                 self.redoc_url,
-                lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
+                lambda r: get_redoc_html(
+                    openapi_url=self.openapi_url, title=self.title + " - ReDoc"
+                ),
                 include_in_schema=False,
             )
         self.add_exception_handler(HTTPException, http_exception)
@@ -91,311 +99,322 @@ class FastAPI(Starlette):
         self,
         path: str,
         endpoint: Callable,
-        methods: List[str] = None,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
+        deprecated: bool = None,
+        name: str = None,
+        methods: List[str] = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
     ) -> None:
         self.router.add_api_route(
             path,
             endpoint=endpoint,
-            methods=methods,
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=methods,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def api_route(
         self,
         path: str,
-        methods: List[str] = None,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
+        deprecated: bool = None,
+        name: str = None,
+        methods: List[str] = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
     ) -> Callable:
         def decorator(func: Callable) -> Callable:
             self.router.add_api_route(
                 path,
                 func,
-                methods=methods,
-                name=name,
-                include_in_schema=include_in_schema,
+                response_model=response_model,
+                status_code=status_code,
                 tags=tags or [],
                 summary=summary,
                 description=description,
-                operation_id=operation_id,
-                deprecated=deprecated,
-                response_type=response_type,
                 response_description=response_description,
-                response_code=response_code,
-                response_wrapper=response_wrapper,
+                deprecated=deprecated,
+                name=name,
+                methods=methods,
+                operation_id=operation_id,
+                include_in_schema=include_in_schema,
+                content_type=content_type,
             )
             return func
+
         return decorator
-    
-    def include_router(self, router: "APIRouter", *, prefix=""):
+
+    def include_router(self, router: routing.APIRouter, *, prefix: str = "") -> None:
         self.router.include_router(router, prefix=prefix)
 
     def get(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.get(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def put(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.put(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def post(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.post(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def delete(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.delete(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def options(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.options(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def head(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.head(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def patch(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.patch(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def trace(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.router.trace(
-            path=path,
-            name=name,
-            include_in_schema=include_in_schema,
+            path,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
index ad9419db5a3d864fc9edfca92752e0ad41dfe6ae..5857f9202aca792137583241f55d2fc51d0519ea 100644 (file)
@@ -1,14 +1,14 @@
 from typing import Any, Callable, Dict, List, Sequence, Tuple
 
-from starlette.concurrency import run_in_threadpool
-from starlette.requests import Request
-
-from fastapi.security.base import SecurityBase
 from pydantic import BaseConfig, Schema
 from pydantic.error_wrappers import ErrorWrapper
 from pydantic.errors import MissingError
 from pydantic.fields import Field, Required
 from pydantic.schema import get_annotation_from_schema
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+
+from fastapi.security.base import SecurityBase
 
 param_supported_types = (str, int, float, bool)
 
index 6e86de5a5e919f84cddc012de661c81b104ed3e0..834774e1bfe5e799217f2f1f5607f3769ed3c554 100644 (file)
@@ -1,8 +1,14 @@
 import asyncio
 import inspect
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Type
 
+from pydantic import BaseConfig, Schema, create_model
+from pydantic.error_wrappers import ErrorWrapper
+from pydantic.errors import MissingError
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+from pydantic.utils import lenient_issubclass
 from starlette.concurrency import run_in_threadpool
 from starlette.requests import Request
 
@@ -10,17 +16,11 @@ from fastapi import params
 from fastapi.dependencies.models import Dependant, SecurityRequirement
 from fastapi.security.base import SecurityBase
 from fastapi.utils import get_path_param_names
-from pydantic import BaseConfig, Schema, create_model
-from pydantic.error_wrappers import ErrorWrapper
-from pydantic.errors import MissingError
-from pydantic.fields import Field, Required
-from pydantic.schema import get_annotation_from_schema
-from pydantic.utils import lenient_issubclass
 
 param_supported_types = (str, int, float, bool)
 
 
-def get_sub_dependant(*, param: inspect.Parameter, path: str):
+def get_sub_dependant(*, param: inspect.Parameter, path: str) -> Dependant:
     depends: params.Depends = param.default
     if depends.dependency:
         dependency = depends.dependency
@@ -36,7 +36,7 @@ def get_sub_dependant(*, param: inspect.Parameter, path: str):
     return sub_dependant
 
 
-def get_flat_dependant(dependant: Dependant):
+def get_flat_dependant(dependant: Dependant) -> Dependant:
     flat_dependant = Dependant(
         path_params=dependant.path_params.copy(),
         query_params=dependant.query_params.copy(),
@@ -58,7 +58,7 @@ def get_flat_dependant(dependant: Dependant):
     return flat_dependant
 
 
-def get_dependant(*, path: str, call: Callable, name: str = None):
+def get_dependant(*, path: str, call: Callable, name: str = None) -> Dependant:
     path_param_names = get_path_param_names(path)
     endpoint_signature = inspect.signature(call)
     signature_params = endpoint_signature.parameters
@@ -73,9 +73,10 @@ def get_dependant(*, path: str, call: Callable, name: str = None):
         if (
             (param.default == param.empty) or isinstance(param.default, params.Path)
         ) and (param_name in path_param_names):
-            assert lenient_issubclass(
-                param.annotation, param_supported_types
-            ) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
+            assert (
+                lenient_issubclass(param.annotation, param_supported_types)
+                or param.annotation == param.empty
+            ), f"Path params must be of type str, int, float or boot: {param}"
             param = signature_params[param_name]
             add_param_to_fields(
                 param=param,
@@ -109,9 +110,9 @@ def add_param_to_fields(
     *,
     param: inspect.Parameter,
     dependant: Dependant,
-    default_schema=params.Param,
+    default_schema: Type[Schema] = params.Param,
     force_type: params.ParamTypes = None,
-):
+) -> None:
     default_value = Required
     if not param.default == param.empty:
         default_value = param.default
@@ -125,15 +126,19 @@ def add_param_to_fields(
     else:
         schema = default_schema(default_value)
     required = default_value == Required
-    annotation = Any
+    annotation: Type = Type[Any]
     if not param.annotation == param.empty:
         annotation = param.annotation
     annotation = get_annotation_from_schema(annotation, schema)
+    if not schema.alias and getattr(schema, "alias_underscore_to_hyphen", None):
+        alias = param.name.replace("_", "-")
+    else:
+        alias = schema.alias or param.name
     field = Field(
         name=param.name,
         type_=annotation,
         default=None if required else default_value,
-        alias=schema.alias or param.name,
+        alias=alias,
         required=required,
         model_config=BaseConfig(),
         class_validators=[],
@@ -152,7 +157,7 @@ def add_param_to_fields(
         dependant.cookie_params.append(field)
 
 
-def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
+def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
     default_value = Required
     if not param.default == param.empty:
         default_value = param.default
@@ -176,7 +181,7 @@ def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
     dependant.body_params.append(field)
 
 
-def is_coroutine_callable(call: Callable = None):
+def is_coroutine_callable(call: Callable = None) -> bool:
     if not call:
         return False
     if inspect.isfunction(call):
@@ -191,7 +196,7 @@ def is_coroutine_callable(call: Callable = None):
 
 async def solve_dependencies(
     *, request: Request, dependant: Dependant, body: Dict[str, Any] = None
-):
+) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
     values: Dict[str, Any] = {}
     errors: List[ErrorWrapper] = []
     for sub_dependant in dependant.dependencies:
@@ -200,13 +205,13 @@ async def solve_dependencies(
         )
         if sub_errors:
             return {}, errors
-        if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
+        assert sub_dependant.call is not None, "sub_dependant.call must be a function"
+        if is_coroutine_callable(sub_dependant.call):
             solved = await sub_dependant.call(**sub_values)
         else:
             solved = await run_in_threadpool(sub_dependant.call, **sub_values)
-        values[
-            sub_dependant.name
-        ] = solved  # type: ignore # Sub-dependants always have a name
+        assert sub_dependant.name is not None, "Subdependants always have a name"
+        values[sub_dependant.name] = solved
     path_values, path_errors = request_params_to_args(
         dependant.path_params, request.path_params
     )
@@ -236,7 +241,7 @@ async def solve_dependencies(
 
 
 def request_params_to_args(
-    required_params: List[Field], received_params: Dict[str, Any]
+    required_params: Sequence[Field], received_params: Mapping[str, Any]
 ) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
     values = {}
     errors = []
@@ -250,9 +255,9 @@ def request_params_to_args(
             else:
                 values[field.name] = deepcopy(field.default)
             continue
-        v_, errors_ = field.validate(
-            value, values, loc=(field.schema.in_.value, field.alias)
-        )
+        schema: params.Param = field.schema
+        assert isinstance(schema, params.Param), "Params must be subclasses of Param"
+        v_, errors_ = field.validate(value, values, loc=(schema.in_.value, field.alias))
         if isinstance(errors_, ErrorWrapper):
             errors.append(errors_)
         elif isinstance(errors_, list):
@@ -294,7 +299,7 @@ async def request_body_to_args(
     return values, errors
 
 
-def get_body_field(*, dependant: Dependant, name: str):
+def get_body_field(*, dependant: Dependant, name: str) -> Field:
     flat_dependant = get_flat_dependant(dependant)
     if not flat_dependant.body_params:
         return None
@@ -308,7 +313,7 @@ def get_body_field(*, dependant: Dependant, name: str):
         BodyModel.__fields__[f.name] = f
     required = any(True for f in flat_dependant.body_params if f.required)
     if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
-        BodySchema = params.File
+        BodySchema: Type[params.Body] = params.File
     elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
         BodySchema = params.Form
     else:
index 95ce4479e74b20579d44df680ec2a679737cafbb..3234f8927306febe96463313cdf437e2679d58e1 100644 (file)
@@ -1,18 +1,18 @@
 from enum import Enum
 from types import GeneratorType
-from typing import Set
+from typing import Any, Set
 
 from pydantic import BaseModel
 from pydantic.json import pydantic_encoder
 
 
 def jsonable_encoder(
-    obj,
+    obj: Any,
     include: Set[str] = None,
     exclude: Set[str] = set(),
     by_alias: bool = False,
-    include_none=True,
-):
+    include_none: bool = True,
+) -> Any:
     if isinstance(obj, BaseModel):
         return jsonable_encoder(
             obj.dict(include=include, exclude=exclude, by_alias=by_alias),
index c8a1d6178098d10fff718a31f59bd0386fb2072b..955a99f00229df48087cce905cce337a5352d00c 100644 (file)
@@ -1,6 +1,7 @@
 from starlette.responses import HTMLResponse
 
-def get_swagger_ui_html(*, openapi_url: str, title: str):
+
+def get_swagger_ui_html(*, openapi_url: str, title: str) -> HTMLResponse:
     return HTMLResponse(
         """
     <! doctype html>
@@ -35,12 +36,11 @@ def get_swagger_ui_html(*, openapi_url: str, title: str):
     </script>
     </body>
     </html>
-    """,
-        media_type="text/html",
+    """
     )
 
 
-def get_redoc_html(*, openapi_url: str, title: str):
+def get_redoc_html(*, openapi_url: str, title: str) -> HTMLResponse:
     return HTMLResponse(
         """
     <!DOCTYPE html>
@@ -73,6 +73,5 @@ def get_redoc_html(*, openapi_url: str, title: str):
     <script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
   </body>
 </html>
-    """,
-        media_type="text/html",
+    """
     )
index e3d96bd7f4c73b2157c94fb1ed789635a9e8e980..87eed07be2110c06cb797a393015c26cf4d8050c 100644 (file)
@@ -7,13 +7,13 @@ from pydantic.types import UrlStr
 
 try:
     import pydantic.types.EmailStr
-    from pydantic.types import EmailStr
+    from pydantic.types import EmailStr  # type: ignore
 except ImportError:
     logging.warning(
         "email-validator not installed, email fields will be treated as str"
     )
 
-    class EmailStr(str):
+    class EmailStr(str):  # type: ignore
         pass
 
 
@@ -50,7 +50,7 @@ class Server(BaseModel):
 
 
 class Reference(BaseModel):
-    ref: str = PSchema(..., alias="$ref")
+    ref: str = PSchema(..., alias="$ref")  # type: ignore
 
 
 class Discriminator(BaseModel):
@@ -72,28 +72,28 @@ class ExternalDocumentation(BaseModel):
 
 
 class SchemaBase(BaseModel):
-    ref: Optional[str] = PSchema(None, alias="$ref")
+    ref: Optional[str] = PSchema(None, alias="$ref")  # type: ignore
     title: Optional[str] = None
     multipleOf: Optional[float] = None
     maximum: Optional[float] = None
     exclusiveMaximum: Optional[float] = None
     minimum: Optional[float] = None
     exclusiveMinimum: Optional[float] = None
-    maxLength: Optional[int] = PSchema(None, gte=0)
-    minLength: Optional[int] = PSchema(None, gte=0)
+    maxLength: Optional[int] = PSchema(None, gte=0)  # type: ignore
+    minLength: Optional[int] = PSchema(None, gte=0)  # type: ignore
     pattern: Optional[str] = None
-    maxItems: Optional[int] = PSchema(None, gte=0)
-    minItems: Optional[int] = PSchema(None, gte=0)
+    maxItems: Optional[int] = PSchema(None, gte=0)  # type: ignore
+    minItems: Optional[int] = PSchema(None, gte=0)  # type: ignore
     uniqueItems: Optional[bool] = None
-    maxProperties: Optional[int] = PSchema(None, gte=0)
-    minProperties: Optional[int] = PSchema(None, gte=0)
+    maxProperties: Optional[int] = PSchema(None, gte=0)  # type: ignore
+    minProperties: Optional[int] = PSchema(None, gte=0)  # type: ignore
     required: Optional[List[str]] = None
     enum: Optional[List[str]] = None
     type: Optional[str] = None
     allOf: Optional[List[Any]] = None
     oneOf: Optional[List[Any]] = None
     anyOf: Optional[List[Any]] = None
-    not_: Optional[List[Any]] = PSchema(None, alias="not")
+    not_: Optional[List[Any]] = PSchema(None, alias="not")  # type: ignore
     items: Optional[Any] = None
     properties: Optional[Dict[str, Any]] = None
     additionalProperties: Optional[Union[bool, Any]] = None
@@ -114,7 +114,7 @@ class Schema(SchemaBase):
     allOf: Optional[List[SchemaBase]] = None
     oneOf: Optional[List[SchemaBase]] = None
     anyOf: Optional[List[SchemaBase]] = None
-    not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
+    not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")  # type: ignore
     items: Optional[SchemaBase] = None
     properties: Optional[Dict[str, SchemaBase]] = None
     additionalProperties: Optional[Union[bool, SchemaBase]] = None
@@ -144,7 +144,9 @@ class Encoding(BaseModel):
 
 
 class MediaType(BaseModel):
-    schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+    schema_: Optional[Union[Schema, Reference]] = PSchema(
+        None, alias="schema"
+    )  # type: ignore
     example: Optional[Any] = None
     examples: Optional[Dict[str, Union[Example, Reference]]] = None
     encoding: Optional[Dict[str, Encoding]] = None
@@ -158,7 +160,9 @@ class ParameterBase(BaseModel):
     style: Optional[str] = None
     explode: Optional[bool] = None
     allowReserved: Optional[bool] = None
-    schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+    schema_: Optional[Union[Schema, Reference]] = PSchema(
+        None, alias="schema"
+    )  # type: ignore
     example: Optional[Any] = None
     examples: Optional[Dict[str, Union[Example, Reference]]] = None
     # Serialization rules for more complex scenarios
@@ -167,7 +171,7 @@ class ParameterBase(BaseModel):
 
 class Parameter(ParameterBase):
     name: str
-    in_: ParameterInType = PSchema(..., alias="in")
+    in_: ParameterInType = PSchema(..., alias="in")  # type: ignore
 
 
 class Header(ParameterBase):
@@ -222,7 +226,7 @@ class Operation(BaseModel):
 
 
 class PathItem(BaseModel):
-    ref: Optional[str] = PSchema(None, alias="$ref")
+    ref: Optional[str] = PSchema(None, alias="$ref")  # type: ignore
     summary: Optional[str] = None
     description: Optional[str] = None
     get: Optional[Operation] = None
@@ -250,7 +254,7 @@ class SecuritySchemeType(Enum):
 
 
 class SecurityBase(BaseModel):
-    type_: SecuritySchemeType = PSchema(..., alias="type")
+    type_: SecuritySchemeType = PSchema(..., alias="type")  # type: ignore
     description: Optional[str] = None
 
 
@@ -261,13 +265,13 @@ class APIKeyIn(Enum):
 
 
 class APIKey(SecurityBase):
-    type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
-    in_: APIKeyIn = PSchema(..., alias="in")
+    type_ = PSchema(SecuritySchemeType.apiKey, alias="type")  # type: ignore
+    in_: APIKeyIn = PSchema(..., alias="in")  # type: ignore
     name: str
 
 
 class HTTPBase(SecurityBase):
-    type_ = PSchema(SecuritySchemeType.http, alias="type")
+    type_ = PSchema(SecuritySchemeType.http, alias="type")  # type: ignore
     scheme: str
 
 
@@ -306,12 +310,12 @@ class OAuthFlows(BaseModel):
 
 
 class OAuth2(SecurityBase):
-    type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
+    type_ = PSchema(SecuritySchemeType.oauth2, alias="type")  # type: ignore
     flows: OAuthFlows
 
 
 class OpenIdConnect(SecurityBase):
-    type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
+    type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")  # type: ignore
     openIdConnectUrl: str
 
 
index 7dbeece737e3e9012bbba01021fa1e6969d03962..1036d201241a92c0ed4b8eb9a87a904bd9896c85 100644 (file)
@@ -1,23 +1,21 @@
-from typing import Any, Dict, Sequence, Type, List
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
 
 from pydantic.fields import Field
-from pydantic.schema import field_schema, get_model_name_map
+from pydantic.schema import Schema, field_schema, get_model_name_map
 from pydantic.utils import lenient_issubclass
-
 from starlette.responses import HTMLResponse, JSONResponse
-from starlette.routing import BaseRoute
+from starlette.routing import BaseRoute, Route
 from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
 
 from fastapi import routing
 from fastapi.dependencies.models import Dependant
 from fastapi.dependencies.utils import get_flat_dependant
 from fastapi.encoders import jsonable_encoder
-from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
+from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
 from fastapi.openapi.models import OpenAPI
-from fastapi.params import Body
+from fastapi.params import Body, Param
 from fastapi.utils import get_flat_models_from_routes, get_model_definitions
 
-
 validation_error_definition = {
     "title": "ValidationError",
     "type": "object",
@@ -42,7 +40,7 @@ validation_error_response_definition = {
 }
 
 
-def get_openapi_params(dependant: Dependant):
+def get_openapi_params(dependant: Dependant) -> List[Field]:
     flat_dependant = get_flat_dependant(dependant)
     return (
         flat_dependant.path_params
@@ -52,7 +50,7 @@ def get_openapi_params(dependant: Dependant):
     )
 
 
-def get_openapi_security_definitions(flat_dependant: Dependant):
+def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
     security_definitions = {}
     operation_security = []
     for security_requirement in flat_dependant.security_requirements:
@@ -61,59 +59,60 @@ def get_openapi_security_definitions(flat_dependant: Dependant):
             by_alias=True,
             include_none=False,
         )
-        security_name = (
-            security_requirement.security_scheme.scheme_name
-            
-        )
+        security_name = security_requirement.security_scheme.scheme_name
         security_definitions[security_name] = security_definition
         operation_security.append({security_name: security_requirement.scopes})
     return security_definitions, operation_security
 
 
-def get_openapi_operation_parameters(all_route_params: List[Field]):
+def get_openapi_operation_parameters(
+    all_route_params: Sequence[Field]
+) -> Tuple[Dict[str, Dict], List[Dict[str, Any]]]:
     definitions: Dict[str, Dict] = {}
     parameters = []
     for param in all_route_params:
+        schema: Param = param.schema
         if "ValidationError" not in definitions:
             definitions["ValidationError"] = validation_error_definition
             definitions["HTTPValidationError"] = validation_error_response_definition
         parameter = {
             "name": param.alias,
-            "in": param.schema.in_.value,
+            "in": schema.in_.value,
             "required": param.required,
             "schema": field_schema(param, model_name_map={})[0],
         }
-        if param.schema.description:
-            parameter["description"] = param.schema.description
-        if param.schema.deprecated:
-            parameter["deprecated"] = param.schema.deprecated
+        if schema.description:
+            parameter["description"] = schema.description
+        if schema.deprecated:
+            parameter["deprecated"] = schema.deprecated
         parameters.append(parameter)
     return definitions, parameters
 
 
 def get_openapi_operation_request_body(
     *, body_field: Field, model_name_map: Dict[Type, str]
-):
+) -> Optional[Dict]:
     if not body_field:
         return None
     assert isinstance(body_field, Field)
     body_schema, _ = field_schema(
         body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
     )
-    if isinstance(body_field.schema, Body):
-        request_media_type = body_field.schema.media_type
+    schema: Schema = body_field.schema
+    if isinstance(schema, Body):
+        request_media_type = schema.media_type
     else:
         # Includes not declared media types (Schema)
         request_media_type = "application/json"
     required = body_field.required
-    request_body_oai = {}
+    request_body_oai: Dict[str, Any] = {}
     if required:
         request_body_oai["required"] = required
     request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
     return request_body_oai
 
 
-def generate_operation_id(*, route: routing.APIRoute, method: str):
+def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
     if route.operation_id:
         return route.operation_id
     path: str = route.path
@@ -123,12 +122,13 @@ def generate_operation_id(*, route: routing.APIRoute, method: str):
     return operation_id
 
 
-def generate_operation_summary(*, route: routing.APIRoute, method: str):
+def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
     if route.summary:
         return route.summary
     return method.title() + " " + route.name.replace("_", " ").title()
 
-def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
+
+def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
     operation: Dict[str, Any] = {}
     if route.tags:
         operation["tags"] = route.tags
@@ -141,12 +141,13 @@ def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
     return operation
 
 
-def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
-    if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
-        return None
+def get_openapi_path(
+    *, route: routing.APIRoute, model_name_map: Dict[Type, str]
+) -> Tuple[Dict, Dict, Dict]:
     path = {}
     security_schemes: Dict[str, Any] = {}
     definitions: Dict[str, Any] = {}
+    assert route.methods is not None, "Methods must be a list"
     for method in route.methods:
         operation = get_openapi_operation_metadata(route=route, method=method)
         parameters: List[Dict] = []
@@ -172,10 +173,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
             )
             if request_body_oai:
                 operation["requestBody"] = request_body_oai
-        response_code = str(route.response_code)
+        status_code = str(route.status_code)
         response_schema = {"type": "string"}
-        if lenient_issubclass(route.response_wrapper, JSONResponse):
-            response_media_type = "application/json"
+        if lenient_issubclass(route.content_type, JSONResponse):
             if route.response_field:
                 response_schema, _ = field_schema(
                     route.response_field,
@@ -184,16 +184,9 @@ def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
                 )
             else:
                 response_schema = {}
-        elif lenient_issubclass(route.response_wrapper, HTMLResponse):
-            response_media_type = "text/html"
-        else:
-            response_media_type = "text/plain"
-        content = {response_media_type: {"schema": response_schema}}
+        content = {route.content_type.media_type: {"schema": response_schema}}
         operation["responses"] = {
-            response_code: {
-                "description": route.response_description,
-                "content": content,
-            }
+            status_code: {"description": route.response_description, "content": content}
         }
         if all_route_params or route.body_field:
             operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
@@ -215,7 +208,7 @@ def get_openapi(
     openapi_version: str = "3.0.2",
     description: str = None,
     routes: Sequence[BaseRoute]
-):
+) -> Dict:
     info = {"title": title, "version": version}
     if description:
         info["description"] = description
@@ -228,15 +221,18 @@ def get_openapi(
         flat_models=flat_models, model_name_map=model_name_map
     )
     for route in routes:
-        result = get_openapi_path(route=route, model_name_map=model_name_map)
-        if result:
-            path, security_schemes, path_definitions = result
-            if path:
-                paths.setdefault(route.path, {}).update(path)
-            if security_schemes:
-                components.setdefault("securitySchemes", {}).update(security_schemes)
-            if path_definitions:
-                definitions.update(path_definitions)
+        if isinstance(route, routing.APIRoute):
+            result = get_openapi_path(route=route, model_name_map=model_name_map)
+            if result:
+                path, security_schemes, path_definitions = result
+                if path:
+                    paths.setdefault(route.path, {}).update(path)
+                if security_schemes:
+                    components.setdefault("securitySchemes", {}).update(
+                        security_schemes
+                    )
+                if path_definitions:
+                    definitions.update(path_definitions)
     if definitions:
         components.setdefault("schemas", {}).update(definitions)
     if components:
index abbce8aeb9825101d8168fc91654c1d297bd0944..8df0112c8e5cc4d219ef6a232245011a55b14e08 100644 (file)
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Sequence, Any, Dict
+from typing import Any, Callable, Sequence
 
 from pydantic import Schema
 
@@ -16,7 +16,7 @@ class Param(Schema):
 
     def __init__(
         self,
-        default,
+        default: Any,
         *,
         deprecated: bool = None,
         alias: str = None,
@@ -29,7 +29,7 @@ class Param(Schema):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         self.deprecated = deprecated
         super().__init__(
@@ -53,7 +53,7 @@ class Path(Param):
 
     def __init__(
         self,
-        default,
+        default: Any,
         *,
         deprecated: bool = None,
         alias: str = None,
@@ -66,7 +66,7 @@ class Path(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         self.description = description
         self.deprecated = deprecated
@@ -92,7 +92,7 @@ class Query(Param):
 
     def __init__(
         self,
-        default,
+        default: Any,
         *,
         deprecated: bool = None,
         alias: str = None,
@@ -105,7 +105,7 @@ class Query(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         self.description = description
         self.deprecated = deprecated
@@ -130,10 +130,11 @@ class Header(Param):
 
     def __init__(
         self,
-        default,
+        default: Any,
         *,
         deprecated: bool = None,
         alias: str = None,
+        alias_underscore_to_hyphen: bool = True,
         title: str = None,
         description: str = None,
         gt: float = None,
@@ -143,10 +144,11 @@ class Header(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         self.description = description
         self.deprecated = deprecated
+        self.alias_underscore_to_hyphen = alias_underscore_to_hyphen
         super().__init__(
             default,
             alias=alias,
@@ -168,7 +170,7 @@ class Cookie(Param):
 
     def __init__(
         self,
-        default,
+        default: Any,
         *,
         deprecated: bool = None,
         alias: str = None,
@@ -181,7 +183,7 @@ class Cookie(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         self.description = description
         self.deprecated = deprecated
@@ -204,9 +206,9 @@ class Cookie(Param):
 class Body(Schema):
     def __init__(
         self,
-        default,
+        default: Any,
         *,
-        embed=False,
+        embed: bool = False,
         media_type: str = "application/json",
         alias: str = None,
         title: str = None,
@@ -218,7 +220,7 @@ class Body(Schema):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         self.embed = embed
         self.media_type = media_type
@@ -241,9 +243,9 @@ class Body(Schema):
 class Form(Body):
     def __init__(
         self,
-        default,
+        default: Any,
         *,
-        sub_key=False,
+        sub_key: bool = False,
         media_type: str = "application/x-www-form-urlencoded",
         alias: str = None,
         title: str = None,
@@ -255,7 +257,7 @@ class Form(Body):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         super().__init__(
             default,
@@ -278,9 +280,9 @@ class Form(Body):
 class File(Form):
     def __init__(
         self,
-        default,
+        default: Any,
         *,
-        sub_key=False,
+        sub_key: bool = False,
         media_type: str = "multipart/form-data",
         alias: str = None,
         title: str = None,
@@ -292,7 +294,7 @@ class File(Form):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: Dict[str, Any],
+        **extra: Any,
     ):
         super().__init__(
             default,
@@ -313,11 +315,11 @@ class File(Form):
 
 
 class Depends:
-    def __init__(self, dependency=None):
+    def __init__(self, dependency: Callable = None):
         self.dependency = dependency
 
 
 class Security(Depends):
-    def __init__(self, dependency=None, scopes: Sequence[str] = None):
+    def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
         self.scopes = scopes or []
         super().__init__(dependency=dependency)
index 22a62a53a24eed6688b5c01a3754604bd91ed14b..8620db5db5fbbeabfccafdde5ced36dd6b68b5e8 100644 (file)
@@ -1,12 +1,11 @@
 import asyncio
 import inspect
-from typing import Callable, List, Type
+from typing import Any, Callable, List, Optional, Type
 
 from pydantic import BaseConfig, BaseModel, Schema
 from pydantic.error_wrappers import ErrorWrapper, ValidationError
 from pydantic.fields import Field
 from pydantic.utils import lenient_issubclass
-
 from starlette import routing
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
@@ -22,7 +21,7 @@ from fastapi.dependencies.utils import get_body_field, get_dependant, solve_depe
 from fastapi.encoders import jsonable_encoder
 
 
-def serialize_response(*, field: Field = None, response):
+def serialize_response(*, field: Field = None, response: Response) -> Any:
     if field:
         errors = []
         value, errors_ = field.validate(response, {}, loc=("response",))
@@ -40,11 +39,12 @@ def serialize_response(*, field: Field = None, response):
 def get_app(
     dependant: Dependant,
     body_field: Field = None,
-    response_code: str = 200,
-    response_wrapper: Type[Response] = JSONResponse,
-    response_field: Type[Field] = None,
-):
-    is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
+    status_code: int = 200,
+    content_type: Type[Response] = JSONResponse,
+    response_field: Field = None,
+) -> Callable:
+    assert dependant.call is not None, "dependant.call must me a function"
+    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
     is_body_form = body_field and isinstance(body_field.schema, params.Form)
 
     async def app(request: Request) -> Response:
@@ -69,6 +69,7 @@ def get_app(
                 status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
             )
         else:
+            assert dependant.call is not None, "dependant.call must me a function"
             if is_coroutine:
                 raw_response = await dependant.call(**values)
             else:
@@ -76,32 +77,32 @@ def get_app(
             if isinstance(raw_response, Response):
                 return raw_response
             if isinstance(raw_response, BaseModel):
-                return response_wrapper(
-                    content=jsonable_encoder(raw_response), status_code=response_code
+                return content_type(
+                    content=jsonable_encoder(raw_response), status_code=status_code
                 )
             errors = []
             try:
-                return response_wrapper(
+                return content_type(
                     content=serialize_response(
                         field=response_field, response=raw_response
                     ),
-                    status_code=response_code,
+                    status_code=status_code,
                 )
             except Exception as e:
                 errors.append(e)
             try:
                 response = dict(raw_response)
-                return response_wrapper(
+                return content_type(
                     content=serialize_response(field=response_field, response=response),
-                    status_code=response_code,
+                    status_code=status_code,
                 )
             except Exception as e:
                 errors.append(e)
             try:
                 response = vars(raw_response)
-                return response_wrapper(
+                return content_type(
                     content=serialize_response(field=response_field, response=response),
-                    status_code=response_code,
+                    status_code=status_code,
                 )
             except Exception as e:
                 errors.append(e)
@@ -116,43 +117,32 @@ class APIRoute(routing.Route):
         path: str,
         endpoint: Callable,
         *,
-        methods: List[str] = None,
-        name: str = None,
-        include_in_schema: bool = True,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
+        deprecated: bool = None,
+        name: str = None,
+        methods: List[str] = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
     ) -> None:
         assert path.startswith("/"), "Routed paths must always start with '/'"
         self.path = path
         self.endpoint = endpoint
         self.name = get_name(endpoint) if name is None else name
-        self.include_in_schema = include_in_schema
-        self.tags = tags or []
-        self.summary = summary
-        self.description = description or self.endpoint.__doc__
-        self.operation_id = operation_id
-        self.deprecated = deprecated
-        self.body_field: Field = None
-        self.response_description = response_description
-        self.response_code = response_code
-        self.response_wrapper = response_wrapper
-        self.response_field = None
-        if response_type:
+        self.response_model = response_model
+        if self.response_model:
             assert lenient_issubclass(
-                response_wrapper, JSONResponse
+                content_type, JSONResponse
             ), "To declare a type the response must be a JSON response"
-            self.response_type = response_type
             response_name = "Response_" + self.name
-            self.response_field = Field(
+            self.response_field: Optional[Field] = Field(
                 name=response_name,
-                type_=self.response_type,
+                type_=self.response_model,
                 class_validators=[],
                 default=None,
                 required=False,
@@ -160,25 +150,34 @@ class APIRoute(routing.Route):
                 schema=Schema(None),
             )
         else:
-            self.response_type = None
+            self.response_field = None
+        self.status_code = status_code
+        self.tags = tags or []
+        self.summary = summary
+        self.description = description or self.endpoint.__doc__
+        self.response_description = response_description
+        self.deprecated = deprecated
         if methods is None:
             methods = ["GET"]
         self.methods = methods
+        self.operation_id = operation_id
+        self.include_in_schema = include_in_schema
+        self.content_type = content_type
+
         self.path_regex, self.path_format, self.param_convertors = self.compile_path(
             path
         )
         assert inspect.isfunction(endpoint) or inspect.ismethod(
             endpoint
         ), f"An endpoint must be a function or method"
-
         self.dependant = get_dependant(path=path, call=self.endpoint)
         self.body_field = get_body_field(dependant=self.dependant, name=self.name)
         self.app = request_response(
             get_app(
                 dependant=self.dependant,
                 body_field=self.body_field,
-                response_code=self.response_code,
-                response_wrapper=self.response_wrapper,
+                status_code=self.status_code,
+                content_type=self.content_type,
                 response_field=self.response_field,
             )
         )
@@ -189,75 +188,77 @@ class APIRouter(routing.Router):
         self,
         path: str,
         endpoint: Callable,
-        methods: List[str] = None,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
+        deprecated: bool = None,
+        name: str = None,
+        methods: List[str] = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
     ) -> None:
         route = APIRoute(
             path,
             endpoint=endpoint,
-            methods=methods,
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=methods,
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
         self.routes.append(route)
 
     def api_route(
         self,
         path: str,
-        methods: List[str] = None,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
+        deprecated: bool = None,
+        name: str = None,
+        methods: List[str] = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
     ) -> Callable:
         def decorator(func: Callable) -> Callable:
             self.add_api_route(
                 path,
                 func,
-                methods=methods,
-                name=name,
-                include_in_schema=include_in_schema,
+                response_model=response_model,
+                status_code=status_code,
                 tags=tags or [],
                 summary=summary,
                 description=description,
-                operation_id=operation_id,
-                deprecated=deprecated,
-                response_type=response_type,
                 response_description=response_description,
-                response_code=response_code,
-                response_wrapper=response_wrapper,
+                deprecated=deprecated,
+                name=name,
+                methods=methods,
+                operation_id=operation_id,
+                include_in_schema=include_in_schema,
+                content_type=content_type,
             )
             return func
 
         return decorator
 
-    def include_router(self, router: "APIRouter", *, prefix=""):
+    def include_router(self, router: "APIRouter", *, prefix: str = "") -> None:
         if prefix:
             assert prefix.startswith("/"), "A path prefix must start with '/'"
             assert not prefix.endswith(
@@ -268,18 +269,18 @@ class APIRouter(routing.Router):
                 self.add_api_route(
                     prefix + route.path,
                     route.endpoint,
-                    methods=route.methods,
-                    name=route.name,
-                    include_in_schema=route.include_in_schema,
-                    tags=route.tags,
+                    response_model=route.response_model,
+                    status_code=route.status_code,
+                    tags=route.tags or [],
                     summary=route.summary,
                     description=route.description,
-                    operation_id=route.operation_id,
-                    deprecated=route.deprecated,
-                    response_type=route.response_type,
                     response_description=route.response_description,
-                    response_code=route.response_code,
-                    response_wrapper=route.response_wrapper,
+                    deprecated=route.deprecated,
+                    name=route.name,
+                    methods=route.methods,
+                    operation_id=route.operation_id,
+                    include_in_schema=route.include_in_schema,
+                    content_type=route.content_type,
                 )
             elif isinstance(route, routing.Route):
                 self.add_route(
@@ -293,247 +294,255 @@ class APIRouter(routing.Router):
     def get(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["GET"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["GET"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def put(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["PUT"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["PUT"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def post(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["POST"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["POST"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def delete(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["DELETE"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["DELETE"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def options(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["OPTIONS"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["OPTIONS"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def head(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["HEAD"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["HEAD"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def patch(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["PATCH"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["PATCH"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
 
     def trace(
         self,
         path: str,
-        name: str = None,
-        include_in_schema: bool = True,
+        *,
+        response_model: Type[BaseModel] = None,
+        status_code: int = 200,
         tags: List[str] = None,
         summary: str = None,
         description: str = None,
-        operation_id: str = None,
-        deprecated: bool = None,
-        response_type: Type = None,
         response_description: str = "Successful Response",
-        response_code=200,
-        response_wrapper=JSONResponse,
-    ):
+        deprecated: bool = None,
+        name: str = None,
+        operation_id: str = None,
+        include_in_schema: bool = True,
+        content_type: Type[Response] = JSONResponse,
+    ) -> Callable:
         return self.api_route(
             path=path,
-            methods=["TRACE"],
-            name=name,
-            include_in_schema=include_in_schema,
+            response_model=response_model,
+            status_code=status_code,
             tags=tags or [],
             summary=summary,
             description=description,
-            operation_id=operation_id,
-            deprecated=deprecated,
-            response_type=response_type,
             response_description=response_description,
-            response_code=response_code,
-            response_wrapper=response_wrapper,
+            deprecated=deprecated,
+            name=name,
+            methods=["TRACE"],
+            operation_id=operation_id,
+            include_in_schema=include_in_schema,
+            content_type=content_type,
         )
index 047898dfe21646ecdf0ce3187803f7ca9f22450e..c4b045b715bcc4f5471b93dd72a41581cb8d73b8 100644 (file)
@@ -1,18 +1,19 @@
 from starlette.requests import Request
 
-from .base import SecurityBase
-from fastapi.openapi.models import APIKeyIn, APIKey
+from fastapi.openapi.models import APIKey, APIKeyIn
+from fastapi.security.base import SecurityBase
+
 
 class APIKeyBase(SecurityBase):
     pass
 
-class APIKeyQuery(APIKeyBase):
 
+class APIKeyQuery(APIKeyBase):
     def __init__(self, *, name: str, scheme_name: str = None):
         self.model = APIKey(in_=APIKeyIn.query, name=name)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, requests: Request):
+    async def __call__(self, requests: Request) -> str:
         return requests.query_params.get(self.model.name)
 
 
@@ -21,7 +22,7 @@ class APIKeyHeader(APIKeyBase):
         self.model = APIKey(in_=APIKeyIn.header, name=name)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, requests: Request):
+    async def __call__(self, requests: Request) -> str:
         return requests.headers.get(self.model.name)
 
 
@@ -30,5 +31,5 @@ class APIKeyCookie(APIKeyBase):
         self.model = APIKey(in_=APIKeyIn.cookie, name=name)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, requests: Request):
+    async def __call__(self, requests: Request) -> str:
         return requests.cookies.get(self.model.name)
index 8589da0be0a35d2a9444b9b3e44182c415183994..c43555deb8ea83b14241a5631c9ea451c96f6e7f 100644 (file)
@@ -1,6 +1,6 @@
-from pydantic import BaseModel
-
 from fastapi.openapi.models import SecurityBase as SecurityBaseModel
 
+
 class SecurityBase:
-    pass
+    model: SecurityBaseModel
+    scheme_name: str
index cee42b8687b8b2ff48cfe9d493202baac665a14f..480a1ae546f5e883e5abab9282f7222728e8fbf8 100644 (file)
@@ -1,7 +1,10 @@
 from starlette.requests import Request
 
-from .base import SecurityBase
-from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
+from fastapi.openapi.models import (
+    HTTPBase as HTTPBaseModel,
+    HTTPBearer as HTTPBearerModel,
+)
+from fastapi.security.base import SecurityBase
 
 
 class HTTPBase(SecurityBase):
@@ -9,7 +12,7 @@ class HTTPBase(SecurityBase):
         self.model = HTTPBaseModel(scheme=scheme)
         self.scheme_name = scheme_name or self.__class__.__name__
 
-    async def __call__(self, request: Request):
+    async def __call__(self, request: Request) -> str:
         return request.headers.get("Authorization")
 
 
@@ -17,8 +20,8 @@ class HTTPBasic(HTTPBase):
     def __init__(self, *, scheme_name: str = None):
         self.model = HTTPBaseModel(scheme="basic")
         self.scheme_name = scheme_name or self.__class__.__name__
-    
-    async def __call__(self, request: Request):
+
+    async def __call__(self, request: Request) -> str:
         return request.headers.get("Authorization")
 
 
@@ -26,8 +29,8 @@ class HTTPBearer(HTTPBase):
     def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
         self.model = HTTPBearerModel(bearerFormat=bearerFormat)
         self.scheme_name = scheme_name or self.__class__.__name__
-    
-    async def __call__(self, request: Request):
+
+    async def __call__(self, request: Request) -> str:
         return request.headers.get("Authorization")
 
 
@@ -35,6 +38,6 @@ class HTTPDigest(HTTPBase):
     def __init__(self, *, scheme_name: str = None):
         self.model = HTTPBaseModel(scheme="digest")
         self.scheme_name = scheme_name or self.__class__.__name__
-    
-    async def __call__(self, request: Request):
+
+    async def __call__(self, request: Request) -> str:
         return request.headers.get("Authorization")
index 65517e962de7d9f064d830a952a592973743be76..90838fdad08aaf5319718e444a9b21abb4cc458e 100644 (file)
@@ -1,13 +1,15 @@
 from starlette.requests import Request
 
-from .base import SecurityBase
 from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
+from fastapi.security.base import SecurityBase
 
 
 class OAuth2(SecurityBase):
-    def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
+    def __init__(
+        self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None
+    ):
         self.model = OAuth2Model(flows=flows)
         self.scheme_name = scheme_name or self.__class__.__name__
-    
-    async def __call__(self, request: Request):
+
+    async def __call__(self, request: Request) -> str:
         return request.headers.get("Authorization")
index 49c5aae2d86d800c5c3c188271d91936d612becf..b6c0a32dc412f42a9abb78366248f6e09795d373 100644 (file)
@@ -1,13 +1,13 @@
 from starlette.requests import Request
 
-from .base import SecurityBase
 from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
+from fastapi.security.base import SecurityBase
 
 
 class OpenIdConnect(SecurityBase):
     def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
         self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
         self.scheme_name = scheme_name or self.__class__.__name__
-    
-    async def __call__(self, request: Request):
+
+    async def __call__(self, request: Request) -> str:
         return request.headers.get("Authorization")
index 091f868fe975b7b1c4c3feb06f1b7fd524305e35..81ca910cfaa5f3c319cfbcaa64780b039057f391 100644 (file)
@@ -1,20 +1,24 @@
 import re
-from typing import Dict, Sequence, Set, Type
+from typing import Any, Dict, List, Sequence, Set, Type
 
+from pydantic import BaseModel
+from pydantic.fields import Field
+from pydantic.schema import get_flat_models_from_fields, model_process_schema
 from starlette.routing import BaseRoute
 
 from fastapi import routing
 from fastapi.openapi.constants import REF_PREFIX
-from pydantic import BaseModel
-from pydantic.fields import Field
-from pydantic.schema import get_flat_models_from_fields, model_process_schema
 
 
-def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
-    body_fields_from_routes = []
-    responses_from_routes = []
+def get_flat_models_from_routes(
+    routes: Sequence[Type[BaseRoute]]
+) -> Set[Type[BaseModel]]:
+    body_fields_from_routes: List[Field] = []
+    responses_from_routes: List[Field] = []
     for route in routes:
-        if route.include_in_schema and isinstance(route, routing.APIRoute):
+        if getattr(route, "include_in_schema", None) and isinstance(
+            route, routing.APIRoute
+        ):
             if route.body_field:
                 assert isinstance(
                     route.body_field, Field
@@ -30,7 +34,7 @@ def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
 
 def get_model_definitions(
     *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
-):
+) -> Dict[str, Any]:
     definitions: Dict[str, Dict] = {}
     for model in flat_models:
         m_schema, m_definitions = model_process_schema(
@@ -42,5 +46,5 @@ def get_model_definitions(
     return definitions
 
 
-def get_path_param_names(path: str):
+def get_path_param_names(path: str) -> Set[str]:
     return {item.strip("{}") for item in re.findall("{[^}]*}", path)}