]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Improve type annotations, add support for mypy --strict, internally and for externa...
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 20 Dec 2020 18:50:00 +0000 (19:50 +0100)
committerGitHub <noreply@github.com>
Sun, 20 Dec 2020 18:50:00 +0000 (19:50 +0100)
43 files changed:
docs_src/openapi_callbacks/tutorial001.py
fastapi/__init__.py
fastapi/applications.py
fastapi/background.py
fastapi/concurrency.py
fastapi/datastructures.py
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/encoders.py
fastapi/middleware/__init__.py
fastapi/middleware/cors.py
fastapi/middleware/gzip.py
fastapi/middleware/httpsredirect.py
fastapi/middleware/trustedhost.py
fastapi/middleware/wsgi.py
fastapi/openapi/docs.py
fastapi/openapi/models.py
fastapi/openapi/utils.py
fastapi/param_functions.py
fastapi/params.py
fastapi/responses.py
fastapi/routing.py
fastapi/security/__init__.py
fastapi/security/oauth2.py
fastapi/staticfiles.py
fastapi/templating.py
fastapi/testclient.py
fastapi/types.py [new file with mode: 0644]
fastapi/utils.py
fastapi/websockets.py
mypy.ini
pyproject.toml
tests/test_custom_route_class.py
tests/test_get_request_body.py
tests/test_include_router_defaults_overrides.py
tests/test_inherited_custom_class.py
tests/test_jsonable_encoder.py
tests/test_local_docs.py
tests/test_multi_body_errors.py
tests/test_param_class.py
tests/test_params_repr.py
tests/test_starlette_urlconvertors.py
tests/test_sub_callbacks.py

index f04fec4d7bf87c84523f842ff4c72d0aaa50f445..2fb8367515db146d18fdab02773049d14184b032 100644 (file)
@@ -26,7 +26,7 @@ invoices_callback_router = APIRouter()
 
 
 @invoices_callback_router.post(
-    "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
+    "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
 )
 def invoice_notification(body: InvoiceEvent):
     pass
index 3d1a699d9fbf25e74241c5065a0593755ce732f0..858da48c5c7e281a1a292ee14ba9abdc4ed0230f 100644 (file)
@@ -2,24 +2,23 @@
 
 __version__ = "0.62.0"
 
-from starlette import status
+from starlette import status as status
 
-from .applications import FastAPI
-from .background import BackgroundTasks
-from .datastructures import UploadFile
-from .exceptions import HTTPException
-from .param_functions import (
-    Body,
-    Cookie,
-    Depends,
-    File,
-    Form,
-    Header,
-    Path,
-    Query,
-    Security,
-)
-from .requests import Request
-from .responses import Response
-from .routing import APIRouter
-from .websockets import WebSocket, WebSocketDisconnect
+from .applications import FastAPI as FastAPI
+from .background import BackgroundTasks as BackgroundTasks
+from .datastructures import UploadFile as UploadFile
+from .exceptions import HTTPException as HTTPException
+from .param_functions import Body as Body
+from .param_functions import Cookie as Cookie
+from .param_functions import Depends as Depends
+from .param_functions import File as File
+from .param_functions import Form as Form
+from .param_functions import Header as Header
+from .param_functions import Path as Path
+from .param_functions import Query as Query
+from .param_functions import Security as Security
+from .requests import Request as Request
+from .responses import Response as Response
+from .routing import APIRouter as APIRouter
+from .websockets import WebSocket as WebSocket
+from .websockets import WebSocketDisconnect as WebSocketDisconnect
index 519dc74aeb8e5219f6ac25fe02bf26824a574022..92d041c5cf3f654925e07d8b8cdb1e74a523a737 100644 (file)
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
 
 from fastapi import routing
 from fastapi.concurrency import AsyncExitStack
@@ -17,6 +17,7 @@ from fastapi.openapi.docs import (
 )
 from fastapi.openapi.utils import get_openapi
 from fastapi.params import Depends
+from fastapi.types import DecoratedCallable
 from starlette.applications import Starlette
 from starlette.datastructures import State
 from starlette.exceptions import HTTPException
@@ -24,7 +25,7 @@ from starlette.middleware import Middleware
 from starlette.requests import Request
 from starlette.responses import HTMLResponse, JSONResponse, Response
 from starlette.routing import BaseRoute
-from starlette.types import Receive, Scope, Send
+from starlette.types import ASGIApp, Receive, Scope, Send
 
 
 class FastAPI(Starlette):
@@ -44,24 +45,27 @@ class FastAPI(Starlette):
         docs_url: Optional[str] = "/docs",
         redoc_url: Optional[str] = "/redoc",
         swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
-        swagger_ui_init_oauth: Optional[dict] = None,
+        swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
         middleware: Optional[Sequence[Middleware]] = None,
         exception_handlers: Optional[
-            Dict[Union[int, Type[Exception]], Callable]
+            Dict[
+                Union[int, Type[Exception]],
+                Callable[[Request, Any], Coroutine[Any, Any, Response]],
+            ]
         ] = None,
-        on_startup: Optional[Sequence[Callable]] = None,
-        on_shutdown: Optional[Sequence[Callable]] = None,
+        on_startup: Optional[Sequence[Callable[[], Any]]] = None,
+        on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
         openapi_prefix: str = "",
         root_path: str = "",
         root_path_in_servers: bool = True,
         responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-        deprecated: bool = None,
+        callbacks: Optional[List[BaseRoute]] = None,
+        deprecated: Optional[bool] = None,
         include_in_schema: bool = True,
         **extra: Any,
     ) -> None:
-        self._debug = debug
-        self.state = State()
+        self._debug: bool = debug
+        self.state: State = State()
         self.router: routing.APIRouter = routing.APIRouter(
             routes=routes,
             dependency_overrides_provider=self,
@@ -74,7 +78,10 @@ class FastAPI(Starlette):
             include_in_schema=include_in_schema,
             responses=responses,
         )
-        self.exception_handlers = (
+        self.exception_handlers: Dict[
+            Union[int, Type[Exception]],
+            Callable[[Request, Any], Coroutine[Any, Any, Response]],
+        ] = (
             {} if exception_handlers is None else dict(exception_handlers)
         )
         self.exception_handlers.setdefault(HTTPException, http_exception_handler)
@@ -82,8 +89,10 @@ class FastAPI(Starlette):
             RequestValidationError, request_validation_exception_handler
         )
 
-        self.user_middleware = [] if middleware is None else list(middleware)
-        self.middleware_stack = self.build_middleware_stack()
+        self.user_middleware: List[Middleware] = (
+            [] if middleware is None else list(middleware)
+        )
+        self.middleware_stack: ASGIApp = self.build_middleware_stack()
 
         self.title = title
         self.description = description
@@ -106,7 +115,7 @@ class FastAPI(Starlette):
         self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
         self.swagger_ui_init_oauth = swagger_ui_init_oauth
         self.extra = extra
-        self.dependency_overrides: Dict[Callable, Callable] = {}
+        self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
 
         self.openapi_version = "3.0.2"
 
@@ -116,7 +125,7 @@ class FastAPI(Starlette):
         self.openapi_schema: Optional[Dict[str, Any]] = None
         self.setup()
 
-    def openapi(self) -> Dict:
+    def openapi(self) -> Dict[str, Any]:
         if not self.openapi_schema:
             self.openapi_schema = get_openapi(
                 title=self.title,
@@ -194,7 +203,7 @@ class FastAPI(Starlette):
     def add_api_route(
         self,
         path: str,
-        endpoint: Callable,
+        endpoint: Callable[..., Coroutine[Any, Any, Response]],
         *,
         response_model: Optional[Type[Any]] = None,
         status_code: int = 200,
@@ -268,8 +277,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-    ) -> Callable:
-        def decorator(func: Callable) -> Callable:
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
+        def decorator(func: DecoratedCallable) -> DecoratedCallable:
             self.router.add_api_route(
                 path,
                 func,
@@ -299,12 +308,14 @@ class FastAPI(Starlette):
         return decorator
 
     def add_api_websocket_route(
-        self, path: str, endpoint: Callable, name: Optional[str] = None
+        self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
     ) -> None:
         self.router.add_api_websocket_route(path, endpoint, name=name)
 
-    def websocket(self, path: str, name: Optional[str] = None) -> Callable:
-        def decorator(func: Callable) -> Callable:
+    def websocket(
+        self, path: str, name: Optional[str] = None
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
+        def decorator(func: DecoratedCallable) -> DecoratedCallable:
             self.add_api_websocket_route(path, func, name=name)
             return func
 
@@ -318,10 +329,10 @@ class FastAPI(Starlette):
         tags: Optional[List[str]] = None,
         dependencies: Optional[Sequence[Depends]] = None,
         responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
-        deprecated: bool = None,
+        deprecated: Optional[bool] = None,
         include_in_schema: bool = True,
         default_response_class: Type[Response] = Default(JSONResponse),
-        callbacks: Optional[List[routing.APIRoute]] = None,
+        callbacks: Optional[List[BaseRoute]] = None,
     ) -> None:
         self.router.include_router(
             router,
@@ -358,8 +369,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.get(
             path,
             response_model=response_model,
@@ -407,8 +418,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.put(
             path,
             response_model=response_model,
@@ -456,8 +467,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.post(
             path,
             response_model=response_model,
@@ -505,8 +516,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.delete(
             path,
             response_model=response_model,
@@ -554,8 +565,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.options(
             path,
             response_model=response_model,
@@ -603,8 +614,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.head(
             path,
             response_model=response_model,
@@ -652,8 +663,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.patch(
             path,
             response_model=response_model,
@@ -701,8 +712,8 @@ class FastAPI(Starlette):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[routing.APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.router.trace(
             path,
             response_model=response_model,
index 2d0d3d35e0c0ff9336fb139a519e673ee72171fe..dd3bbe249130348881331aea569ce3ec3f295128 100644 (file)
@@ -1 +1 @@
-from starlette.background import BackgroundTasks  # noqa
+from starlette.background import BackgroundTasks as BackgroundTasks  # noqa
index 451923c550c318f50dbd6586b4229f624f5412d8..d1fdfe5f606475879c7e8a264cba3a528a07f242 100644 (file)
@@ -1,8 +1,10 @@
 from typing import Any, Callable
 
-from starlette.concurrency import iterate_in_threadpool  # noqa
-from starlette.concurrency import run_in_threadpool  # noqa
-from starlette.concurrency import run_until_first_complete  # noqa
+from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool  # noqa
+from starlette.concurrency import run_in_threadpool as run_in_threadpool  # noqa
+from starlette.concurrency import (  # noqa
+    run_until_first_complete as run_until_first_complete,
+)
 
 asynccontextmanager_error_message = """
 FastAPI's contextmanager_in_threadpool require Python 3.7 or above,
@@ -11,7 +13,7 @@ or the backport for Python 3.6, installed with:
 """
 
 
-def _fake_asynccontextmanager(func: Callable) -> Callable:
+def _fake_asynccontextmanager(func: Callable[..., Any]) -> Callable[..., Any]:
     def raiser(*args: Any, **kwargs: Any) -> Any:
         raise RuntimeError(asynccontextmanager_error_message)
 
@@ -19,23 +21,25 @@ def _fake_asynccontextmanager(func: Callable) -> Callable:
 
 
 try:
-    from contextlib import asynccontextmanager  # type: ignore
+    from contextlib import asynccontextmanager as asynccontextmanager  # type: ignore
 except ImportError:
     try:
-        from async_generator import asynccontextmanager  # type: ignore
+        from async_generator import (  # type: ignore  # isort: skip
+            asynccontextmanager as asynccontextmanager,
+        )
     except ImportError:  # pragma: no cover
         asynccontextmanager = _fake_asynccontextmanager
 
 try:
-    from contextlib import AsyncExitStack  # type: ignore
+    from contextlib import AsyncExitStack as AsyncExitStack  # type: ignore
 except ImportError:
     try:
-        from async_exit_stack import AsyncExitStack  # type: ignore
+        from async_exit_stack import AsyncExitStack as AsyncExitStack  # type: ignore
     except ImportError:  # pragma: no cover
         AsyncExitStack = None  # type: ignore
 
 
-@asynccontextmanager
+@asynccontextmanager  # type: ignore
 async def contextmanager_in_threadpool(cm: Any) -> Any:
     try:
         yield await run_in_threadpool(cm.__enter__)
index 1fe8ebdadc271913146e98305de6a37ef3cfe453..f22409c5175b77757095d95b66c532dd86e340af 100644 (file)
@@ -1,11 +1,12 @@
 from typing import Any, Callable, Iterable, Type, TypeVar
 
+from starlette.datastructures import State as State  # noqa: F401
 from starlette.datastructures import UploadFile as StarletteUploadFile
 
 
 class UploadFile(StarletteUploadFile):
     @classmethod
-    def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable]:
+    def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
         yield cls.validate
 
     @classmethod
index 4e2294bd796563b9c7c2e708bb3876d98aac9b01..443590b9c82b70e66b699c30c7c9d755d6818de7 100644 (file)
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 
 from fastapi.security.base import SecurityBase
 from pydantic.fields import ModelField
@@ -24,7 +24,7 @@ class Dependant:
         dependencies: Optional[List["Dependant"]] = None,
         security_schemes: Optional[List[SecurityRequirement]] = None,
         name: Optional[str] = None,
-        call: Optional[Callable] = None,
+        call: Optional[Callable[..., Any]] = None,
         request_param_name: Optional[str] = None,
         websocket_param_name: Optional[str] = None,
         http_connection_param_name: Optional[str] = None,
index 35329a46a5b0c9718938aaa5080c460e10182cad..fcfaa2cb19246080159d5ff2c85ca93c5c1f7561 100644 (file)
@@ -90,12 +90,12 @@ def check_file_field(field: ModelField) -> None:
     if isinstance(field_info, params.Form):
         try:
             # __version__ is available in both multiparts, and can be mocked
-            from multipart import __version__
+            from multipart import __version__  # type: ignore
 
             assert __version__
             try:
                 # parse_options_header is only available in the right multipart
-                from multipart.multipart import parse_options_header
+                from multipart.multipart import parse_options_header  # type: ignore
 
                 assert parse_options_header
             except ImportError:
@@ -133,7 +133,7 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
 def get_sub_dependant(
     *,
     depends: params.Depends,
-    dependency: Callable,
+    dependency: Callable[..., Any],
     path: str,
     name: Optional[str] = None,
     security_scopes: Optional[List[str]] = None,
@@ -163,7 +163,7 @@ def get_sub_dependant(
     return sub_dependant
 
 
-CacheKey = Tuple[Optional[Callable], Tuple[str, ...]]
+CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
 
 
 def get_flat_dependant(
@@ -240,7 +240,7 @@ def is_scalar_sequence_field(field: ModelField) -> bool:
     return False
 
 
-def get_typed_signature(call: Callable) -> inspect.Signature:
+def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
     signature = inspect.signature(call)
     globalns = getattr(call, "__globals__", {})
     typed_params = [
@@ -259,9 +259,7 @@ def get_typed_signature(call: Callable) -> inspect.Signature:
 def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
     annotation = param.annotation
     if isinstance(annotation, str):
-        # Temporary ignore type
-        # Ref: https://github.com/samuelcolvin/pydantic/issues/1738
-        annotation = ForwardRef(annotation)  # type: ignore
+        annotation = ForwardRef(annotation)
         annotation = evaluate_forwardref(annotation, globalns, globalns)
     return annotation
 
@@ -281,7 +279,7 @@ def check_dependency_contextmanagers() -> None:
 def get_dependant(
     *,
     path: str,
-    call: Callable,
+    call: Callable[..., Any],
     name: Optional[str] = None,
     security_scopes: Optional[List[str]] = None,
     use_cache: bool = True,
@@ -423,7 +421,7 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
         dependant.cookie_params.append(field)
 
 
-def is_coroutine_callable(call: Callable) -> bool:
+def is_coroutine_callable(call: Callable[..., Any]) -> bool:
     if inspect.isroutine(call):
         return inspect.iscoroutinefunction(call)
     if inspect.isclass(call):
@@ -432,14 +430,14 @@ def is_coroutine_callable(call: Callable) -> bool:
     return inspect.iscoroutinefunction(call)
 
 
-def is_async_gen_callable(call: Callable) -> bool:
+def is_async_gen_callable(call: Callable[..., Any]) -> bool:
     if inspect.isasyncgenfunction(call):
         return True
     call = getattr(call, "__call__", None)
     return inspect.isasyncgenfunction(call)
 
 
-def is_gen_callable(call: Callable) -> bool:
+def is_gen_callable(call: Callable[..., Any]) -> bool:
     if inspect.isgeneratorfunction(call):
         return True
     call = getattr(call, "__call__", None)
@@ -447,7 +445,7 @@ def is_gen_callable(call: Callable) -> bool:
 
 
 async def solve_generator(
-    *, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
+    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
 ) -> Any:
     if is_gen_callable(call):
         cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
@@ -472,29 +470,29 @@ async def solve_dependencies(
     background_tasks: Optional[BackgroundTasks] = None,
     response: Optional[Response] = None,
     dependency_overrides_provider: Optional[Any] = None,
-    dependency_cache: Optional[Dict[Tuple[Callable, Tuple[str]], Any]] = None,
+    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
 ) -> Tuple[
     Dict[str, Any],
     List[ErrorWrapper],
     Optional[BackgroundTasks],
     Response,
-    Dict[Tuple[Callable, Tuple[str]], Any],
+    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
 ]:
     values: Dict[str, Any] = {}
     errors: List[ErrorWrapper] = []
     response = response or Response(
         content=None,
         status_code=None,  # type: ignore
-        headers=None,
-        media_type=None,
-        background=None,
+        headers=None,  # type: ignore # in Starlette
+        media_type=None,  # type: ignore # in Starlette
+        background=None,  # type: ignore # in Starlette
     )
     dependency_cache = dependency_cache or {}
     sub_dependant: Dependant
     for sub_dependant in dependant.dependencies:
-        sub_dependant.call = cast(Callable, sub_dependant.call)
+        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
         sub_dependant.cache_key = cast(
-            Tuple[Callable, Tuple[str]], sub_dependant.cache_key
+            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
         )
         call = sub_dependant.call
         use_sub_dependant = sub_dependant
index 1255b74977e3be1a184a381c2e63449428aad50d..6a2a75dda629013e2f13ccd0a8b38e1dced77ea7 100644 (file)
@@ -12,9 +12,11 @@ DictIntStrAny = Dict[Union[int, str], Any]
 
 
 def generate_encoders_by_class_tuples(
-    type_encoder_map: Dict[Any, Callable]
-) -> Dict[Callable, Tuple]:
-    encoders_by_class_tuples: Dict[Callable, Tuple] = defaultdict(tuple)
+    type_encoder_map: Dict[Any, Callable[[Any], Any]]
+) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
+    encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
+        tuple
+    )
     for type_, encoder in type_encoder_map.items():
         encoders_by_class_tuples[encoder] += (type_,)
     return encoders_by_class_tuples
@@ -31,7 +33,7 @@ def jsonable_encoder(
     exclude_unset: bool = False,
     exclude_defaults: bool = False,
     exclude_none: bool = False,
-    custom_encoder: dict = {},
+    custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
     sqlalchemy_safe: bool = True,
 ) -> Any:
     if include is not None and not isinstance(include, set):
@@ -43,8 +45,8 @@ def jsonable_encoder(
         if custom_encoder:
             encoder.update(custom_encoder)
         obj_dict = obj.dict(
-            include=include,
-            exclude=exclude,
+            include=include,  # type: ignore # in Pydantic
+            exclude=exclude,  # type: ignore # in Pydantic
             by_alias=by_alias,
             exclude_unset=exclude_unset,
             exclude_none=exclude_none,
index 6601b1783803d3444466cc3ec1d1e3039d1a186e..620296d5ad6ca2cc49eb5d0dc140bcbc3204e9b4 100644 (file)
@@ -1 +1 @@
-from starlette.middleware import Middleware
+from starlette.middleware import Middleware as Middleware
index 4c08a161ae8775f718ce71f15735da1aa1451dd3..8dfaad0dbb3ff5300cccb2023748cd30f54bc920 100644 (file)
@@ -1 +1 @@
-from starlette.middleware.cors import CORSMiddleware  # noqa
+from starlette.middleware.cors import CORSMiddleware as CORSMiddleware  # noqa
index 08460d07ee46c694b1ad4768835ae35ed5dca42a..bbeb2cc7861a735d6cd5c0e29aeb6dbf8457023a 100644 (file)
@@ -1 +1 @@
-from starlette.middleware.gzip import GZipMiddleware  # noqa
+from starlette.middleware.gzip import GZipMiddleware as GZipMiddleware  # noqa
index 674263af3ff4abb2756093d64ad8a7732ab3eae4..b7a3d8e078574e87dc6e345d621f5a596c3bdc1e 100644 (file)
@@ -1 +1,3 @@
-from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware  # noqa
+from starlette.middleware.httpsredirect import (  # noqa
+    HTTPSRedirectMiddleware as HTTPSRedirectMiddleware,
+)
index b16aee8728c678c86f0d16c27980c37ba93f157c..08d7e035315677856fd2cd0be2044689b57619bf 100644 (file)
@@ -1 +1,3 @@
-from starlette.middleware.trustedhost import TrustedHostMiddleware  # noqa
+from starlette.middleware.trustedhost import (  # noqa
+    TrustedHostMiddleware as TrustedHostMiddleware,
+)
index bf8d3e66ec99d7a2fdb334ff9ac42c9cfeb9181b..c4c6a797d2675e1c13b028be977c64a822fb649b 100644 (file)
@@ -1 +1 @@
-from starlette.middleware.wsgi import WSGIMiddleware  # noqa
+from starlette.middleware.wsgi import WSGIMiddleware as WSGIMiddleware  # noqa
index 44c4e69a34f750ebf82ac231f3229700693958ff..fd22e4e8c167ddc6aa65d46942ff41dc6a2a3af8 100644 (file)
@@ -1,5 +1,5 @@
 import json
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from fastapi.encoders import jsonable_encoder
 from starlette.responses import HTMLResponse
@@ -13,7 +13,7 @@ def get_swagger_ui_html(
     swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
     swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
     oauth2_redirect_url: Optional[str] = None,
-    init_oauth: Optional[dict] = None,
+    init_oauth: Optional[Dict[str, Any]] = None,
 ) -> HTMLResponse:
 
     html = f"""
index 3b716766d055c862456dccaba4e12480958c55c7..fd480946dcf1932e41f42e113532e43a82215bf5 100644 (file)
@@ -5,7 +5,7 @@ from fastapi.logger import logger
 from pydantic import AnyUrl, BaseModel, Field
 
 try:
-    import email_validator
+    import email_validator  # type: ignore
 
     assert email_validator  # make autoflake ignore the unused import
     from pydantic import EmailStr
@@ -13,7 +13,7 @@ except ImportError:  # pragma: no cover
 
     class EmailStr(str):  # type: ignore
         @classmethod
-        def __get_validators__(cls) -> Iterable[Callable]:
+        def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
             yield cls.validate
 
         @classmethod
index 5547cce4f78be3b3ce0d7b32f47db1f9dabccebf..410ba9389c5b673f596127f4a2730733d7864e1a 100644 (file)
@@ -14,6 +14,7 @@ from fastapi.openapi.constants import (
 )
 from fastapi.openapi.models import OpenAPI
 from fastapi.params import Body, Param
+from fastapi.responses import Response
 from fastapi.utils import (
     deep_dict_update,
     generate_operation_id_for_path,
@@ -64,7 +65,9 @@ status_code_ranges: Dict[str, str] = {
 }
 
 
-def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
+def get_openapi_security_definitions(
+    flat_dependant: Dependant,
+) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
     security_definitions = {}
     operation_security = []
     for security_requirement in flat_dependant.security_requirements:
@@ -88,13 +91,12 @@ def get_openapi_operation_parameters(
     for param in all_route_params:
         field_info = param.field_info
         field_info = cast(Param, field_info)
-        # ignore mypy error until enum schemas are released
         parameter = {
             "name": param.alias,
             "in": field_info.in_.value,
             "required": param.required,
             "schema": field_schema(
-                param, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
+                param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
             )[0],
         }
         if field_info.description:
@@ -109,13 +111,12 @@ def get_openapi_operation_request_body(
     *,
     body_field: Optional[ModelField],
     model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
-) -> Optional[Dict]:
+) -> Optional[Dict[str, Any]]:
     if not body_field:
         return None
     assert isinstance(body_field, ModelField)
-    # ignore mypy error until enum schemas are released
     body_schema, _, _ = field_schema(
-        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
+        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
     )
     field_info = cast(Body, body_field.field_info)
     request_media_type = field_info.media_type
@@ -140,7 +141,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
     return route.name.replace("_", " ").title()
 
 
-def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
+def get_openapi_operation_metadata(
+    *, route: routing.APIRoute, method: str
+) -> Dict[str, Any]:
     operation: Dict[str, Any] = {}
     if route.tags:
         operation["tags"] = route.tags
@@ -154,14 +157,14 @@ def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> D
 
 
 def get_openapi_path(
-    *, route: routing.APIRoute, model_name_map: Dict[Type, str]
-) -> Tuple[Dict, Dict, Dict]:
+    *, route: routing.APIRoute, model_name_map: Dict[type, str]
+) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     path = {}
     security_schemes: Dict[str, Any] = {}
     definitions: Dict[str, Any] = {}
     assert route.methods is not None, "Methods must be a list"
     if isinstance(route.response_class, DefaultPlaceholder):
-        current_response_class: Type[routing.Response] = route.response_class.value
+        current_response_class: Type[Response] = route.response_class.value
     else:
         current_response_class = route.response_class
     assert current_response_class, "A response class is needed to generate OpenAPI"
@@ -169,7 +172,7 @@ def get_openapi_path(
     if route.include_in_schema:
         for method in route.methods:
             operation = get_openapi_operation_metadata(route=route, method=method)
-            parameters: List[Dict] = []
+            parameters: List[Dict[str, Any]] = []
             flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
             security_definitions, operation_security = get_openapi_security_definitions(
                 flat_dependant=flat_dependant
@@ -196,10 +199,15 @@ def get_openapi_path(
             if route.callbacks:
                 callbacks = {}
                 for callback in route.callbacks:
-                    cb_path, cb_security_schemes, cb_definitions, = get_openapi_path(
-                        route=callback, model_name_map=model_name_map
-                    )
-                    callbacks[callback.name] = {callback.path: cb_path}
+                    if isinstance(callback, routing.APIRoute):
+                        (
+                            cb_path,
+                            cb_security_schemes,
+                            cb_definitions,
+                        ) = get_openapi_path(
+                            route=callback, model_name_map=model_name_map
+                        )
+                        callbacks[callback.name] = {callback.path: cb_path}
                 operation["callbacks"] = callbacks
             status_code = str(route.status_code)
             operation.setdefault("responses", {}).setdefault(status_code, {})[
@@ -332,21 +340,19 @@ def get_openapi(
     routes: Sequence[BaseRoute],
     tags: Optional[List[Dict[str, Any]]] = None,
     servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
-) -> Dict:
+) -> Dict[str, Any]:
     info = {"title": title, "version": version}
     if description:
         info["description"] = description
     output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
     if servers:
         output["servers"] = servers
-    components: Dict[str, Dict] = {}
-    paths: Dict[str, Dict] = {}
+    components: Dict[str, Dict[str, Any]] = {}
+    paths: Dict[str, Dict[str, Any]] = {}
     flat_models = get_flat_models_from_routes(routes)
-    # ignore mypy error until enum schemas are released
-    model_name_map = get_model_name_map(flat_models)  # type: ignore
-    # ignore mypy error until enum schemas are released
+    model_name_map = get_model_name_map(flat_models)
     definitions = get_model_definitions(
-        flat_models=flat_models, model_name_map=model_name_map  # type: ignore
+        flat_models=flat_models, model_name_map=model_name_map
     )
     for route in routes:
         if isinstance(route, routing.APIRoute):
@@ -368,4 +374,4 @@ def get_openapi(
     output["paths"] = paths
     if tags:
         output["tags"] = tags
-    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
+    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore
index 91620c7c0063fe73f69aa7741aef6e202d4e8d63..9ebb59100081ed2332c096a2e143bdad937369ca 100644 (file)
@@ -239,13 +239,13 @@ def File(  # noqa: N802
 
 
 def Depends(  # noqa: N802
-    dependency: Optional[Callable] = None, *, use_cache: bool = True
+    dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
 ) -> Any:
     return params.Depends(dependency=dependency, use_cache=use_cache)
 
 
 def Security(  # noqa: N802
-    dependency: Optional[Callable] = None,
+    dependency: Optional[Callable[..., Any]] = None,
     *,
     scopes: Optional[Sequence[str]] = None,
     use_cache: bool = True,
index f53e2dba982c5bf3b76bb8752a471db1ec22a103..aa3269a8054c1e85656204a36cbc6fcf79389e80 100644 (file)
@@ -315,7 +315,7 @@ class File(Form):
 
 class Depends:
     def __init__(
-        self, dependency: Optional[Callable] = None, *, use_cache: bool = True
+        self, dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True
     ):
         self.dependency = dependency
         self.use_cache = use_cache
@@ -329,7 +329,7 @@ class Depends:
 class Security(Depends):
     def __init__(
         self,
-        dependency: Optional[Callable] = None,
+        dependency: Optional[Callable[..., Any]] = None,
         *,
         scopes: Optional[Sequence[str]] = None,
         use_cache: bool = True,
index 0aeff61d0f8c8abe21e84509ba92ff7e68a8811d..8d9d62dfb00683b511540d1fa1b2b28a4b685c10 100644 (file)
@@ -1,13 +1,13 @@
 from typing import Any
 
-from starlette.responses import FileResponse  # noqa
-from starlette.responses import HTMLResponse  # noqa
-from starlette.responses import JSONResponse  # noqa
-from starlette.responses import PlainTextResponse  # noqa
-from starlette.responses import RedirectResponse  # noqa
-from starlette.responses import Response  # noqa
-from starlette.responses import StreamingResponse  # noqa
-from starlette.responses import UJSONResponse  # noqa
+from starlette.responses import FileResponse as FileResponse  # noqa
+from starlette.responses import HTMLResponse as HTMLResponse  # noqa
+from starlette.responses import JSONResponse as JSONResponse  # noqa
+from starlette.responses import PlainTextResponse as PlainTextResponse  # noqa
+from starlette.responses import RedirectResponse as RedirectResponse  # noqa
+from starlette.responses import Response as Response  # noqa
+from starlette.responses import StreamingResponse as StreamingResponse  # noqa
+from starlette.responses import UJSONResponse as UJSONResponse  # noqa
 
 try:
     import orjson
index 53f35a4a5c055bd59295bc0e89eb4661809fe531..ac5e19d99835a7b7e07db0a8d87ac02e9794fb51 100644 (file)
@@ -2,7 +2,18 @@ import asyncio
 import enum
 import inspect
 import json
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Type,
+    Union,
+)
 
 from fastapi import params
 from fastapi.datastructures import Default, DefaultPlaceholder
@@ -16,6 +27,7 @@ from fastapi.dependencies.utils import (
 from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
 from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
 from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
+from fastapi.types import DecoratedCallable
 from fastapi.utils import (
     create_cloned_field,
     create_response_field,
@@ -30,7 +42,8 @@ from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
-from starlette.routing import Mount  # noqa
+from starlette.routing import BaseRoute
+from starlette.routing import Mount as Mount  # noqa
 from starlette.routing import (
     compile_path,
     get_name,
@@ -150,7 +163,7 @@ def get_request_handler(
     response_model_exclude_defaults: bool = False,
     response_model_exclude_none: bool = False,
     dependency_overrides_provider: Optional[Any] = None,
-) -> Callable:
+) -> Callable[[Request], Coroutine[Any, Any, Response]]:
     assert dependant.call is not None, "dependant.call must be a function"
     is_coroutine = asyncio.iscoroutinefunction(dependant.call)
     is_body_form = body_field and isinstance(body_field.field_info, params.Form)
@@ -207,7 +220,7 @@ def get_request_handler(
             response = actual_response_class(
                 content=response_data,
                 status_code=status_code,
-                background=background_tasks,
+                background=background_tasks,  # type: ignore # in Starlette
             )
             response.headers.raw.extend(sub_response.headers.raw)
             if sub_response.status_code:
@@ -219,7 +232,7 @@ def get_request_handler(
 
 def get_websocket_app(
     dependant: Dependant, dependency_overrides_provider: Optional[Any] = None
-) -> Callable:
+) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
     async def app(websocket: WebSocket) -> None:
         solved_result = await solve_dependencies(
             request=websocket,
@@ -240,7 +253,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
     def __init__(
         self,
         path: str,
-        endpoint: Callable,
+        endpoint: Callable[..., Any],
         *,
         name: Optional[str] = None,
         dependency_overrides_provider: Optional[Any] = None,
@@ -262,7 +275,7 @@ class APIRoute(routing.Route):
     def __init__(
         self,
         path: str,
-        endpoint: Callable,
+        endpoint: Callable[..., Any],
         *,
         response_model: Optional[Type[Any]] = None,
         status_code: int = 200,
@@ -287,7 +300,7 @@ class APIRoute(routing.Route):
             JSONResponse
         ),
         dependency_overrides_provider: Optional[Any] = None,
-        callbacks: Optional[List["APIRoute"]] = None,
+        callbacks: Optional[List[BaseRoute]] = None,
     ) -> None:
         # normalise enums e.g. http.HTTPStatus
         if isinstance(status_code, enum.IntEnum):
@@ -298,7 +311,7 @@ class APIRoute(routing.Route):
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
         if methods is None:
             methods = ["GET"]
-        self.methods = set([method.upper() for method in methods])
+        self.methods: Set[str] = set([method.upper() for method in methods])
         self.unique_id = generate_operation_id_for_path(
             name=self.name, path=self.path_format, method=list(methods)[0]
         )
@@ -375,7 +388,7 @@ class APIRoute(routing.Route):
         self.callbacks = callbacks
         self.app = request_response(self.get_route_handler())
 
-    def get_route_handler(self) -> Callable:
+    def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
         return get_request_handler(
             dependant=self.dependant,
             body_field=self.body_field,
@@ -401,23 +414,23 @@ class APIRouter(routing.Router):
         dependencies: Optional[Sequence[params.Depends]] = None,
         default_response_class: Type[Response] = Default(JSONResponse),
         responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
-        callbacks: Optional[List[APIRoute]] = None,
+        callbacks: Optional[List[BaseRoute]] = None,
         routes: Optional[List[routing.BaseRoute]] = None,
         redirect_slashes: bool = True,
         default: Optional[ASGIApp] = None,
         dependency_overrides_provider: Optional[Any] = None,
         route_class: Type[APIRoute] = APIRoute,
-        on_startup: Optional[Sequence[Callable]] = None,
-        on_shutdown: Optional[Sequence[Callable]] = None,
-        deprecated: bool = None,
+        on_startup: Optional[Sequence[Callable[[], Any]]] = None,
+        on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
+        deprecated: Optional[bool] = None,
         include_in_schema: bool = True,
     ) -> None:
         super().__init__(
-            routes=routes,
+            routes=routes,  # type: ignore # in Starlette
             redirect_slashes=redirect_slashes,
-            default=default,
-            on_startup=on_startup,
-            on_shutdown=on_shutdown,
+            default=default,  # type: ignore # in Starlette
+            on_startup=on_startup,  # type: ignore # in Starlette
+            on_shutdown=on_shutdown,  # type: ignore # in Starlette
         )
         if prefix:
             assert prefix.startswith("/"), "A path prefix must start with '/'"
@@ -438,7 +451,7 @@ class APIRouter(routing.Router):
     def add_api_route(
         self,
         path: str,
-        endpoint: Callable,
+        endpoint: Callable[..., Any],
         *,
         response_model: Optional[Type[Any]] = None,
         status_code: int = 200,
@@ -463,7 +476,7 @@ class APIRouter(routing.Router):
         ),
         name: Optional[str] = None,
         route_class_override: Optional[Type[APIRoute]] = None,
-        callbacks: Optional[List[APIRoute]] = None,
+        callbacks: Optional[List[BaseRoute]] = None,
     ) -> None:
         route_class = route_class_override or self.route_class
         responses = responses or {}
@@ -532,9 +545,9 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
-        def decorator(func: Callable) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
+        def decorator(func: DecoratedCallable) -> DecoratedCallable:
             self.add_api_route(
                 path,
                 func,
@@ -565,7 +578,7 @@ class APIRouter(routing.Router):
         return decorator
 
     def add_api_websocket_route(
-        self, path: str, endpoint: Callable, name: Optional[str] = None
+        self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
     ) -> None:
         route = APIWebSocketRoute(
             path,
@@ -575,8 +588,10 @@ class APIRouter(routing.Router):
         )
         self.routes.append(route)
 
-    def websocket(self, path: str, name: Optional[str] = None) -> Callable:
-        def decorator(func: Callable) -> Callable:
+    def websocket(
+        self, path: str, name: Optional[str] = None
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
+        def decorator(func: DecoratedCallable) -> DecoratedCallable:
             self.add_api_websocket_route(path, func, name=name)
             return func
 
@@ -591,8 +606,8 @@ class APIRouter(routing.Router):
         dependencies: Optional[Sequence[params.Depends]] = None,
         default_response_class: Type[Response] = Default(JSONResponse),
         responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-        deprecated: bool = None,
+        callbacks: Optional[List[BaseRoute]] = None,
+        deprecated: Optional[bool] = None,
         include_in_schema: bool = True,
     ) -> None:
         if prefix:
@@ -663,10 +678,11 @@ class APIRouter(routing.Router):
                     callbacks=current_callbacks,
                 )
             elif isinstance(route, routing.Route):
+                methods = list(route.methods or [])  # type: ignore # in Starlette
                 self.add_route(
                     prefix + route.path,
                     route.endpoint,
-                    methods=list(route.methods or []),
+                    methods=methods,
                     include_in_schema=route.include_in_schema,
                     name=route.name,
                 )
@@ -706,8 +722,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -756,8 +772,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -806,8 +822,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -856,8 +872,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -906,8 +922,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -956,8 +972,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -1006,8 +1022,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         return self.api_route(
             path=path,
             response_model=response_model,
@@ -1056,8 +1072,8 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = Default(JSONResponse),
         name: Optional[str] = None,
-        callbacks: Optional[List[APIRoute]] = None,
-    ) -> Callable:
+        callbacks: Optional[List[BaseRoute]] = None,
+    ) -> Callable[[DecoratedCallable], DecoratedCallable]:
 
         return self.api_route(
             path=path,
index ad727742c3f4f4e843c4ced8b4da735f79aed9fe..3aa6bf21e44f3069adb94242fbba5c8160532a1c 100644 (file)
@@ -1,17 +1,15 @@
-from .api_key import APIKeyCookie, APIKeyHeader, APIKeyQuery
-from .http import (
-    HTTPAuthorizationCredentials,
-    HTTPBasic,
-    HTTPBasicCredentials,
-    HTTPBearer,
-    HTTPDigest,
-)
-from .oauth2 import (
-    OAuth2,
-    OAuth2AuthorizationCodeBearer,
-    OAuth2PasswordBearer,
-    OAuth2PasswordRequestForm,
-    OAuth2PasswordRequestFormStrict,
-    SecurityScopes,
-)
-from .open_id_connect_url import OpenIdConnect
+from .api_key import APIKeyCookie as APIKeyCookie
+from .api_key import APIKeyHeader as APIKeyHeader
+from .api_key import APIKeyQuery as APIKeyQuery
+from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
+from .http import HTTPBasic as HTTPBasic
+from .http import HTTPBasicCredentials as HTTPBasicCredentials
+from .http import HTTPBearer as HTTPBearer
+from .http import HTTPDigest as HTTPDigest
+from .oauth2 import OAuth2 as OAuth2
+from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer
+from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer
+from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm
+from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict
+from .oauth2 import SecurityScopes as SecurityScopes
+from .open_id_connect_url import OpenIdConnect as OpenIdConnect
index 0d1a5f12fcd4c6c57751f45e9a72f1a461882985..46571ad53762caca300a49f808907c56a2a741ea 100644 (file)
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from fastapi.exceptions import HTTPException
 from fastapi.openapi.models import OAuth2 as OAuth2Model
@@ -116,7 +116,7 @@ class OAuth2(SecurityBase):
     def __init__(
         self,
         *,
-        flows: OAuthFlowsModel = OAuthFlowsModel(),
+        flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
         scheme_name: Optional[str] = None,
         auto_error: Optional[bool] = True
     ):
@@ -141,7 +141,7 @@ class OAuth2PasswordBearer(OAuth2):
         self,
         tokenUrl: str,
         scheme_name: Optional[str] = None,
-        scopes: Optional[dict] = None,
+        scopes: Optional[Dict[str, str]] = None,
         auto_error: bool = True,
     ):
         if not scopes:
@@ -171,7 +171,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
         tokenUrl: str,
         refreshUrl: Optional[str] = None,
         scheme_name: Optional[str] = None,
-        scopes: Optional[dict] = None,
+        scopes: Optional[Dict[str, str]] = None,
         auto_error: bool = True,
     ):
         if not scopes:
index 78359dd1ea0830c2decd63f0bab7914057faf0b1..299015d4fef268cde91273790251f35192e1c8a6 100644 (file)
@@ -1 +1 @@
-from starlette.staticfiles import StaticFiles  # noqa
+from starlette.staticfiles import StaticFiles as StaticFiles  # noqa
index d4c035cf89c634deea9228c5fd5d201e0c6ab0e4..0cb868486edd9dda38f90c65f314597813128cf8 100644 (file)
@@ -1 +1 @@
-from starlette.templating import Jinja2Templates  # noqa
+from starlette.templating import Jinja2Templates as Jinja2Templates  # noqa
index 0288f694cd87cb106bc8694bd322d5af989a4ba8..4012406aa76f743c5c5d1ab8ff56d6d67cfb6653 100644 (file)
@@ -1 +1 @@
-from starlette.testclient import TestClient  # noqa
+from starlette.testclient import TestClient as TestClient  # noqa
diff --git a/fastapi/types.py b/fastapi/types.py
new file mode 100644 (file)
index 0000000..e0bca46
--- /dev/null
@@ -0,0 +1,3 @@
+from typing import Any, Callable, TypeVar
+
+DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
index 058956e329d4a979d745e747b992274bfb3272c3..8913d85b2dc40cec52ea02e502c7b085b0172cb6 100644 (file)
@@ -19,11 +19,10 @@ def get_model_definitions(
     flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
     model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
 ) -> Dict[str, Any]:
-    definitions: Dict[str, Dict] = {}
+    definitions: Dict[str, Dict[str, Any]] = {}
     for model in flat_models:
-        # ignore mypy error until enum schemas are released
         m_schema, m_definitions, m_nested_models = model_process_schema(
-            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
+            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
         )
         definitions.update(m_definitions)
         model_name = model_name_map[model]
@@ -80,7 +79,7 @@ def create_cloned_field(
         cloned_types = dict()
     original_type = field.type_
     if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
-        original_type = original_type.__pydantic_model__  # type: ignore
+        original_type = original_type.__pydantic_model__
     use_type = original_type
     if lenient_issubclass(original_type, BaseModel):
         original_type = cast(Type[BaseModel], original_type)
@@ -127,7 +126,7 @@ def generate_operation_id_for_path(*, name: str, path: str, method: str) -> str:
     return operation_id
 
 
-def deep_dict_update(main_dict: dict, update_dict: dict) -> None:
+def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
     for key in update_dict:
         if (
             key in main_dict
index 2edf97328943d6df756c70f0485dcabdb1acca96..bed672acff5f25b6694a04edc639f3d08adf3553 100644 (file)
@@ -1,2 +1,2 @@
-from starlette.websockets import WebSocket  # noqa
-from starlette.websockets import WebSocketDisconnect  # noqa
+from starlette.websockets import WebSocket as WebSocket  # noqa
+from starlette.websockets import WebSocketDisconnect as WebSocketDisconnect  # noqa
index 4ff4483ab48c7c6ee607b8350593ac14d2ebbd98..e6a33cffbf9984edee9c68521882fe0bf263dbd5 100644 (file)
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,3 +1,25 @@
 [mypy]
+
+# --strict
+disallow_any_generics = True
+disallow_subclassing_any = True 
+disallow_untyped_calls = True 
 disallow_untyped_defs = True
+disallow_incomplete_defs = True 
+check_untyped_defs = True 
+disallow_untyped_decorators = True 
+no_implicit_optional = True
+warn_redundant_casts = True 
+warn_unused_ignores = True
+warn_return_any = True 
+implicit_reexport = False
+strict_equality = True
+# --strict end
+
+[mypy-fastapi.concurrency]
+warn_unused_ignores = False
 ignore_missing_imports = True
+
+[mypy-fastapi.tests.*]
+ignore_missing_imports = True
+check_untyped_defs = True 
index c17f63e84add8fcfb12ce65fd189d4e5a26fb4a7..3dc6b6f83007ddfc267e8625bf0bef4e934e82b3 100644 (file)
@@ -46,9 +46,9 @@ test = [
     "pytest ==5.4.3",
     "pytest-cov ==2.10.0",
     "pytest-asyncio >=0.14.0,<0.15.0",
-    "mypy ==0.782",
+    "mypy ==0.790",
     "flake8 >=3.8.3,<4.0.0",
-    "black ==19.10b0",
+    "black ==20.8b1",
     "isort >=5.0.6,<6.0.0",
     "requests >=2.24.0,<3.0.0",
     "httpx >=0.14.0,<0.15.0",
index afca4732fff93efe078b66ec9729f45c7c79a659..1a9ea7199ad5431c44937900001e2691d4234aef 100644 (file)
@@ -2,6 +2,7 @@ import pytest
 from fastapi import APIRouter, FastAPI
 from fastapi.routing import APIRoute
 from fastapi.testclient import TestClient
+from starlette.routing import Route
 
 app = FastAPI()
 
@@ -106,9 +107,9 @@ def test_get_path(path, expected_status, expected_response):
 
 def test_route_classes():
     routes = {}
-    r: APIRoute
     for r in app.router.routes:
+        assert isinstance(r, Route)
         routes[r.path] = r
-    assert routes["/a/"].x_type == "A"
-    assert routes["/a/b/"].x_type == "B"
-    assert routes["/a/b/c/"].x_type == "C"
+    assert getattr(routes["/a/"], "x_type") == "A"
+    assert getattr(routes["/a/b/"], "x_type") == "B"
+    assert getattr(routes["/a/b/c/"], "x_type") == "C"
index 348aee5f9356e8a3bcf42fcc594609fcf2a325e9..b12f499ebcf6ac0d8837ac4486972ee5c11364ef 100644 (file)
@@ -7,7 +7,7 @@ app = FastAPI()
 
 class Product(BaseModel):
     name: str
-    description: str = None
+    description: str = None  # type: ignore
     price: float
 
 
index ecfa0b2fad3f3517995a04c762537a1dd943a476..c46cb6701453d9bd7a189f097949d9e7a7ff3085 100644 (file)
@@ -175,7 +175,7 @@ async def path3_override_router2_override(level3: str):
     return level3
 
 
-@router2_override.get("/default3",)
+@router2_override.get("/default3")
 async def path3_default_router2_override(level3: str):
     return level3
 
@@ -217,7 +217,9 @@ async def path5_override_router4_override(level5: str):
     return level5
 
 
-@router4_override.get("/default5",)
+@router4_override.get(
+    "/default5",
+)
 async def path5_default_router4_override(level5: str):
     return level5
 
@@ -238,7 +240,9 @@ async def path5_override_router4_default(level5: str):
     return level5
 
 
-@router4_default.get("/default5",)
+@router4_default.get(
+    "/default5",
+)
 async def path5_default_router4_default(level5: str):
     return level5
 
index 1ed5bf1b9d9f0f56f85761f28ea8b1301d5e05b5..bac7eec1b08869ed48dc5bf8147884e8aa0ee453 100644 (file)
@@ -15,7 +15,7 @@ class MyUuid:
     def __str__(self):
         return self.uuid
 
-    @property
+    @property  # type: ignore
     def __class__(self):
         return uuid.UUID
 
index 87b2466e89579d68e24437c5ff932b4c37eebe3b..e2aa8adf8448a6d5f331c027e4afd278d912b7c6 100644 (file)
@@ -71,7 +71,7 @@ class ModelWithAlias(BaseModel):
 
 
 class ModelWithDefault(BaseModel):
-    foo: str = ...
+    foo: str = ...  # type: ignore
     bar: str = "bar"
     bla: str = "bla"
 
@@ -88,7 +88,7 @@ def fixture_model_with_path(request):
         arbitrary_types_allowed = True
 
     ModelWithPath = create_model(
-        "ModelWithPath", path=(request.param, ...), __config__=Config
+        "ModelWithPath", path=(request.param, ...), __config__=Config  # type: ignore
     )
     return ModelWithPath(path=request.param("/foo", "bar"))
 
index 0ef7770309a19c05facef5d219e729d491ab38a4..5f102edf1ae8dbe874e143f40ecae858d01f5cd9 100644 (file)
@@ -5,9 +5,9 @@ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
 
 def test_strings_in_generated_swagger():
     sig = inspect.signature(get_swagger_ui_html)
-    swagger_js_url = sig.parameters.get("swagger_js_url").default
-    swagger_css_url = sig.parameters.get("swagger_css_url").default
-    swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default
+    swagger_js_url = sig.parameters.get("swagger_js_url").default  # type: ignore
+    swagger_css_url = sig.parameters.get("swagger_css_url").default  # type: ignore
+    swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default  # type: ignore
     html = get_swagger_ui_html(openapi_url="/docs", title="title")
     body_content = html.body.decode()
     assert swagger_js_url in body_content
@@ -34,8 +34,8 @@ def test_strings_in_custom_swagger():
 
 def test_strings_in_generated_redoc():
     sig = inspect.signature(get_redoc_html)
-    redoc_js_url = sig.parameters.get("redoc_js_url").default
-    redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default
+    redoc_js_url = sig.parameters.get("redoc_js_url").default  # type: ignore
+    redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default  # type: ignore
     html = get_redoc_html(openapi_url="/docs", title="title")
     body_content = html.body.decode()
     assert redoc_js_url in body_content
index 4719f0b27f3a3161c35d14dd03e0face72daa1fd..c1be82806ebd1329bdf6b5778232ad39b653b88a 100644 (file)
@@ -10,7 +10,7 @@ app = FastAPI()
 
 class Item(BaseModel):
     name: str
-    age: condecimal(gt=Decimal(0.0))
+    age: condecimal(gt=Decimal(0.0))  # type: ignore
 
 
 @app.post("/items/")
index c2a9096d47bff0ce27baa5a940a16ac222b5b997..f5767ec96cab892f9edcb95c38cf8d00f99b87cc 100644 (file)
@@ -8,7 +8,7 @@ app = FastAPI()
 
 
 @app.get("/items/")
-def read_items(q: Optional[str] = Param(None)):
+def read_items(q: Optional[str] = Param(None)):  # type: ignore
     return {"q": q}
 
 
index e21772acabfaa9df7dbabc536ceb4e663942df49..d721257d76251ec1c64be4d3e5778c2e8a9b818e 100644 (file)
@@ -1,7 +1,9 @@
+from typing import Any, List
+
 import pytest
 from fastapi.params import Body, Cookie, Depends, Header, Param, Path, Query
 
-test_data = ["teststr", None, ..., 1, []]
+test_data: List[Any] = ["teststr", None, ..., 1, []]
 
 
 def get_user():
index 1ea22116c8702b79004b698e298497dfcd4aec1f..2320c7005f2cb38de8e1a3da3da3c7da5ad15610 100644 (file)
@@ -27,7 +27,7 @@ def test_route_converters_int():
     response = client.get("/int/5")
     assert response.status_code == 200, response.text
     assert response.json() == {"int": 5}
-    assert app.url_path_for("int_convertor", param=5) == "/int/5"
+    assert app.url_path_for("int_convertor", param=5) == "/int/5"  # type: ignore
 
 
 def test_route_converters_float():
@@ -35,7 +35,7 @@ def test_route_converters_float():
     response = client.get("/float/25.5")
     assert response.status_code == 200, response.text
     assert response.json() == {"float": 25.5}
-    assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5"
+    assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5"  # type: ignore
 
 
 def test_route_converters_path():
index 40ca1475d15758469a9ff6e01247e3cd0300f0c2..16644b5569d1324e486658d9a892aff778cf2be1 100644 (file)
@@ -27,7 +27,7 @@ invoices_callback_router = APIRouter()
 
 
 @invoices_callback_router.post(
-    "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived,
+    "{$callback_url}/invoices/{$request.body.id}", response_model=InvoiceEventReceived
 )
 def invoice_notification(body: InvoiceEvent):
     pass  # pragma: nocover