]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Use mypy `strict` (#2180)
authorViicos <65306057+Viicos@users.noreply.github.com>
Sun, 23 Jul 2023 21:41:50 +0000 (23:41 +0200)
committerGitHub <noreply@github.com>
Sun, 23 Jul 2023 21:41:50 +0000 (15:41 -0600)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
31 files changed:
pyproject.toml
scripts/check
starlette/_exception_handler.py
starlette/_utils.py
starlette/applications.py
starlette/authentication.py
starlette/concurrency.py
starlette/config.py
starlette/convertors.py
starlette/datastructures.py
starlette/endpoints.py
starlette/exceptions.py
starlette/formparsers.py
starlette/middleware/__init__.py
starlette/middleware/errors.py
starlette/middleware/exceptions.py
starlette/middleware/wsgi.py
starlette/requests.py
starlette/responses.py
starlette/routing.py
starlette/schemas.py
starlette/staticfiles.py
starlette/templating.py
starlette/testclient.py
starlette/types.py
starlette/websockets.py
tests/test_convertors.py
tests/test_formparsers.py
tests/test_requests.py
tests/test_responses.py
tests/test_routing.py

index f17ffb09d2ad5ad9bb4d1cd316248d638782e93e..cb876b52e25b6d7355d84577a6a4812c64590d70 100644 (file)
@@ -58,18 +58,21 @@ select = ["E", "F", "I"]
 combine-as-imports = true
 
 [tool.mypy]
-disallow_untyped_defs = true
+strict = true
 ignore_missing_imports = true
-show_error_codes = true
+python_version = "3.8"
 
 [[tool.mypy.overrides]]
 module = "starlette.testclient.*"
-no_implicit_optional = false
+implicit_optional = true
 
-[[tool.mypy.overrides]]
-module = "tests.*"
-disallow_untyped_defs = false
-check_untyped_defs = true
+# TODO: Uncomment the following configuration when
+# https://github.com/python/mypy/issues/10045 is solved. In the meantime,
+# we are calling `mypy tests` directly. Check `scripts/check` for more info.
+# [[tool.mypy.overrides]]
+# module = "tests.*"
+# disallow_untyped_defs = false
+# check_untyped_defs = true
 
 [tool.pytest.ini_options]
 addopts = "-rxXs --strict-config --strict-markers"
index 076ede9eb8d20e7fc6b4a0a28bfca5415df68a6f..cc515ddaf6837e29502210fca553ed65d0e34fb6 100755 (executable)
@@ -10,5 +10,8 @@ set -x
 
 ./scripts/sync-version
 ${PREFIX}black --check --diff $SOURCE_FILES
-${PREFIX}mypy $SOURCE_FILES
+# TODO: Use `[[tool.mypy.overrides]]` on the `pyproject.toml` when the mypy issue is solved:
+# github.com/python/mypy/issues/10045. Check github.com/encode/starlette/pull/2180 for more info.
+${PREFIX}mypy starlette
+${PREFIX}mypy tests --disable-error-code no-untyped-def --disable-error-code no-untyped-call
 ${PREFIX}ruff check $SOURCE_FILES
index 8a9beb3b29c30eabd5c494e34f16a6ea2e7f2734..ea9ffbe9dae18fdbbb7a74aaf8518c25bd95b496 100644 (file)
@@ -4,18 +4,16 @@ from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
-from starlette.responses import Response
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
 from starlette.websockets import WebSocket
 
-Handler = typing.Callable[..., typing.Any]
-ExceptionHandlers = typing.Dict[typing.Any, Handler]
-StatusHandlers = typing.Dict[int, Handler]
+ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
+StatusHandlers = typing.Dict[int, ExceptionHandler]
 
 
 def _lookup_exception_handler(
     exc_handlers: ExceptionHandlers, exc: Exception
-) -> typing.Optional[Handler]:
+) -> typing.Optional[ExceptionHandler]:
     for cls in type(exc).__mro__:
         if cls in exc_handlers:
             return exc_handlers[cls]
@@ -61,7 +59,6 @@ def wrap_app_handling_exceptions(
                 raise RuntimeError(msg) from exc
 
             if scope["type"] == "http":
-                response: Response
                 if is_async_callable(handler):
                     response = await handler(conn, exc)
                 else:
index 5a6e6965b9a628e471c3c8aa1f0d3bf30cf2f4e6..f06dd557cedc61825c31056dcb8f3a3e7756b88c 100644 (file)
@@ -1,9 +1,28 @@
 import asyncio
 import functools
+import sys
 import typing
 
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from typing import TypeGuard
+else:  # pragma: no cover
+    from typing_extensions import TypeGuard
 
-def is_async_callable(obj: typing.Any) -> bool:
+T = typing.TypeVar("T")
+AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
+
+
+@typing.overload
+def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]:
+    ...
+
+
+@typing.overload
+def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]:
+    ...
+
+
+def is_async_callable(obj: typing.Any) -> typing.Any:
     while isinstance(obj, functools.partial):
         obj = obj.func
 
index 344a4a37fe66d2684c80b1c7fdceaa19e904318a..cef4ace712508cc828c049abdb6014be2df0385a 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import typing
 import warnings
 
@@ -9,7 +11,8 @@ from starlette.middleware.exceptions import ExceptionMiddleware
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.routing import BaseRoute, Router
-from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
+from starlette.websockets import WebSocket
 
 AppType = typing.TypeVar("AppType", bound="Starlette")
 
@@ -47,19 +50,11 @@ class Starlette:
     def __init__(
         self: "AppType",
         debug: bool = False,
-        routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
-        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
-        exception_handlers: typing.Optional[
-            typing.Mapping[
-                typing.Any,
-                typing.Callable[
-                    [Request, Exception],
-                    typing.Union[Response, typing.Awaitable[Response]],
-                ],
-            ]
-        ] = None,
-        on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
-        on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
+        routes: typing.Sequence[BaseRoute] | None = None,
+        middleware: typing.Sequence[Middleware] | None = None,
+        exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
+        on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
+        on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
         lifespan: typing.Optional[Lifespan["AppType"]] = None,
     ) -> None:
         # The lifespan context function is a newer style that replaces
@@ -120,18 +115,14 @@ class Starlette:
             self.middleware_stack = self.build_middleware_stack()
         await self.middleware_stack(scope, receive, send)
 
-    def on_event(self, event_type: str) -> typing.Callable:  # pragma: nocover
-        return self.router.on_event(event_type)
+    def on_event(self, event_type: str) -> typing.Callable:  # type: ignore[type-arg]
+        return self.router.on_event(event_type)  # pragma: nocover
 
-    def mount(
-        self, path: str, app: ASGIApp, name: typing.Optional[str] = None
-    ) -> None:  # pragma: nocover
-        self.router.mount(path, app=app, name=name)
+    def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
+        self.router.mount(path, app=app, name=name)  # pragma: no cover
 
-    def host(
-        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
-    ) -> None:  # pragma: no cover
-        self.router.host(host, app=app, name=name)
+    def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
+        self.router.host(host, app=app, name=name)  # pragma: no cover
 
     def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
         if self.middleware_stack is not None:  # pragma: no cover
@@ -140,20 +131,20 @@ class Starlette:
 
     def add_exception_handler(
         self,
-        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
-        handler: typing.Callable,
+        exc_class_or_status_code: int | typing.Type[Exception],
+        handler: ExceptionHandler,
     ) -> None:  # pragma: no cover
         self.exception_handlers[exc_class_or_status_code] = handler
 
     def add_event_handler(
-        self, event_type: str, func: typing.Callable
+        self, event_type: str, func: typing.Callable  # type: ignore[type-arg]
     ) -> None:  # pragma: no cover
         self.router.add_event_handler(event_type, func)
 
     def add_route(
         self,
         path: str,
-        route: typing.Callable,
+        route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
         methods: typing.Optional[typing.List[str]] = None,
         name: typing.Optional[str] = None,
         include_in_schema: bool = True,
@@ -163,20 +154,23 @@ class Starlette:
         )
 
     def add_websocket_route(
-        self, path: str, route: typing.Callable, name: typing.Optional[str] = None
+        self,
+        path: str,
+        route: typing.Callable[[WebSocket], typing.Awaitable[None]],
+        name: str | None = None,
     ) -> None:  # pragma: no cover
         self.router.add_websocket_route(path, route, name=name)
 
     def exception_handler(
-        self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
-    ) -> typing.Callable:
+        self, exc_class_or_status_code: int | typing.Type[Exception]
+    ) -> typing.Callable:  # type: ignore[type-arg]
         warnings.warn(
             "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
             "Refer to https://www.starlette.io/exceptions/ for the recommended approach.",  # noqa: E501
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.add_exception_handler(exc_class_or_status_code, func)
             return func
 
@@ -185,10 +179,10 @@ class Starlette:
     def route(
         self,
         path: str,
-        methods: typing.Optional[typing.List[str]] = None,
-        name: typing.Optional[str] = None,
+        methods: typing.List[str] | None = None,
+        name: str | None = None,
         include_in_schema: bool = True,
-    ) -> typing.Callable:
+    ) -> typing.Callable:  # type: ignore[type-arg]
         """
         We no longer document this decorator style API, and its usage is discouraged.
         Instead you should use the following approach:
@@ -202,7 +196,7 @@ class Starlette:
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.router.add_route(
                 path,
                 func,
@@ -215,8 +209,8 @@ class Starlette:
         return decorator
 
     def websocket_route(
-        self, path: str, name: typing.Optional[str] = None
-    ) -> typing.Callable:
+        self, path: str, name: str | None = None
+    ) -> typing.Callable:  # type: ignore[type-arg]
         """
         We no longer document this decorator style API, and its usage is discouraged.
         Instead you should use the following approach:
@@ -230,13 +224,13 @@ class Starlette:
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.router.add_websocket_route(path, func, name=name)
             return func
 
         return decorator
 
-    def middleware(self, middleware_type: str) -> typing.Callable:
+    def middleware(self, middleware_type: str) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
         """
         We no longer document this decorator style API, and its usage is discouraged.
         Instead you should use the following approach:
@@ -253,7 +247,7 @@ class Starlette:
             middleware_type == "http"
         ), 'Currently only middleware("http") is supported.'
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.add_middleware(BaseHTTPMiddleware, dispatch=func)
             return func
 
index 32713eb17477022b057040e78a218559b785661e..494c50a57cdfd668e3594ebbbef1f26c90b000f7 100644 (file)
@@ -1,15 +1,21 @@
 import functools
 import inspect
+import sys
 import typing
 from urllib.parse import urlencode
 
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from typing import ParamSpec
+else:  # pragma: no cover
+    from typing_extensions import ParamSpec
+
 from starlette._utils import is_async_callable
 from starlette.exceptions import HTTPException
 from starlette.requests import HTTPConnection, Request
-from starlette.responses import RedirectResponse, Response
+from starlette.responses import RedirectResponse
 from starlette.websockets import WebSocket
 
-_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
+_P = ParamSpec("_P")
 
 
 def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
@@ -23,10 +29,14 @@ def requires(
     scopes: typing.Union[str, typing.Sequence[str]],
     status_code: int = 403,
     redirect: typing.Optional[str] = None,
-) -> typing.Callable[[_CallableType], _CallableType]:
+) -> typing.Callable[
+    [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
+]:
     scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
 
-    def decorator(func: typing.Callable) -> typing.Callable:
+    def decorator(
+        func: typing.Callable[_P, typing.Any]
+    ) -> typing.Callable[_P, typing.Any]:
         sig = inspect.signature(func)
         for idx, parameter in enumerate(sig.parameters.values()):
             if parameter.name == "request" or parameter.name == "websocket":
@@ -40,9 +50,7 @@ def requires(
         if type_ == "websocket":
             # Handle websocket functions. (Always async)
             @functools.wraps(func)
-            async def websocket_wrapper(
-                *args: typing.Any, **kwargs: typing.Any
-            ) -> None:
+            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
                 websocket = kwargs.get(
                     "websocket", args[idx] if idx < len(args) else None
                 )
@@ -58,9 +66,7 @@ def requires(
         elif is_async_callable(func):
             # Handle async request/response functions.
             @functools.wraps(func)
-            async def async_wrapper(
-                *args: typing.Any, **kwargs: typing.Any
-            ) -> Response:
+            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                 request = kwargs.get("request", args[idx] if idx < len(args) else None)
                 assert isinstance(request, Request)
 
@@ -80,7 +86,7 @@ def requires(
         else:
             # Handle sync request/response functions.
             @functools.wraps(func)
-            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
+            def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
                 request = kwargs.get("request", args[idx] if idx < len(args) else None)
                 assert isinstance(request, Request)
 
@@ -97,7 +103,7 @@ def requires(
 
             return sync_wrapper
 
-    return decorator  # type: ignore[return-value]
+    return decorator
 
 
 class AuthenticationError(Exception):
index 5c76cb3df7c58d0ef5829bfd166bb7b29018e1c9..ca6033c0f47a5120e1675fee1ebdb0baafd7fa54 100644 (file)
@@ -1,21 +1,13 @@
 import functools
-import sys
 import typing
 import warnings
 
-import anyio
-
-if sys.version_info >= (3, 10):  # pragma: no cover
-    from typing import ParamSpec
-else:  # pragma: no cover
-    from typing_extensions import ParamSpec
-
+import anyio.to_thread
 
 T = typing.TypeVar("T")
-P = ParamSpec("P")
 
 
-async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
+async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:  # type: ignore[type-arg]  # noqa: E501
     warnings.warn(
         "run_until_first_complete is deprecated "
         "and will be removed in a future version.",
@@ -24,7 +16,7 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
 
     async with anyio.create_task_group() as task_group:
 
-        async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+        async def run(func: typing.Callable[[], typing.Coroutine]) -> None:  # type: ignore[type-arg]  # noqa: E501
             await func()
             task_group.cancel_scope.cancel()
 
@@ -32,8 +24,10 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
             task_group.start_soon(run, functools.partial(func, **kwargs))
 
 
+# TODO: We should use `ParamSpec` here, but mypy doesn't support it yet.
+# Check https://github.com/python/mypy/issues/12278 for more details.
 async def run_in_threadpool(
-    func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
+    func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
 ) -> T:
     if kwargs:  # pragma: no cover
         # run_sync doesn't accept 'kwargs', so bind them in here
index 795232cf642a3c7ca107fbeb47bb74f854ac82a4..173955006e572c909be6ef82e3ccc029c2e6701f 100644 (file)
@@ -1,6 +1,5 @@
 import os
 import typing
-from collections.abc import MutableMapping
 from pathlib import Path
 
 
@@ -12,16 +11,16 @@ class EnvironError(Exception):
     pass
 
 
-class Environ(MutableMapping):
-    def __init__(self, environ: typing.MutableMapping = os.environ):
+class Environ(typing.MutableMapping[str, str]):
+    def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
         self._environ = environ
-        self._has_been_read: typing.Set[typing.Any] = set()
+        self._has_been_read: typing.Set[str] = set()
 
-    def __getitem__(self, key: typing.Any) -> typing.Any:
+    def __getitem__(self, key: str) -> str:
         self._has_been_read.add(key)
         return self._environ.__getitem__(key)
 
-    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
+    def __setitem__(self, key: str, value: str) -> None:
         if key in self._has_been_read:
             raise EnvironError(
                 f"Attempting to set environ['{key}'], but the value has already been "
@@ -29,7 +28,7 @@ class Environ(MutableMapping):
             )
         self._environ.__setitem__(key, value)
 
-    def __delitem__(self, key: typing.Any) -> None:
+    def __delitem__(self, key: str) -> None:
         if key in self._has_been_read:
             raise EnvironError(
                 f"Attempting to delete environ['{key}'], but the value has already "
@@ -37,7 +36,7 @@ class Environ(MutableMapping):
             )
         self._environ.__delitem__(key)
 
-    def __iter__(self) -> typing.Iterator:
+    def __iter__(self) -> typing.Iterator[str]:
         return iter(self._environ)
 
     def __len__(self) -> int:
@@ -94,7 +93,7 @@ class Config:
     def __call__(
         self,
         key: str,
-        cast: typing.Optional[typing.Callable] = None,
+        cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
         default: typing.Any = undefined,
     ) -> typing.Any:
         return self.get(key, cast, default)
@@ -102,7 +101,7 @@ class Config:
     def get(
         self,
         key: str,
-        cast: typing.Optional[typing.Callable] = None,
+        cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
         default: typing.Any = undefined,
     ) -> typing.Any:
         key = self.env_prefix + key
@@ -129,7 +128,10 @@ class Config:
         return file_values
 
     def _perform_cast(
-        self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None
+        self,
+        key: str,
+        value: typing.Any,
+        cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
     ) -> typing.Any:
         if cast is None or value is None:
             return value
index 3ade9f7af2c79163be639871f2592c7eb5e523f4..3b12ac7a0ced2f5a4f6bffb8e88000dd3eac8e0b 100644 (file)
@@ -15,7 +15,7 @@ class Convertor(typing.Generic[T]):
         raise NotImplementedError()  # pragma: no cover
 
 
-class StringConvertor(Convertor):
+class StringConvertor(Convertor[str]):
     regex = "[^/]+"
 
     def convert(self, value: str) -> str:
@@ -28,7 +28,7 @@ class StringConvertor(Convertor):
         return value
 
 
-class PathConvertor(Convertor):
+class PathConvertor(Convertor[str]):
     regex = ".*"
 
     def convert(self, value: str) -> str:
@@ -38,7 +38,7 @@ class PathConvertor(Convertor):
         return str(value)
 
 
-class IntegerConvertor(Convertor):
+class IntegerConvertor(Convertor[int]):
     regex = "[0-9]+"
 
     def convert(self, value: str) -> int:
@@ -50,7 +50,7 @@ class IntegerConvertor(Convertor):
         return str(value)
 
 
-class FloatConvertor(Convertor):
+class FloatConvertor(Convertor[float]):
     regex = r"[0-9]+(\.[0-9]+)?"
 
     def convert(self, value: str) -> float:
@@ -64,7 +64,7 @@ class FloatConvertor(Convertor):
         return ("%0.20f" % value).rstrip("0").rstrip(".")
 
 
-class UUIDConvertor(Convertor):
+class UUIDConvertor(Convertor[uuid.UUID]):
     regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
 
     def convert(self, value: str) -> uuid.UUID:
@@ -74,7 +74,7 @@ class UUIDConvertor(Convertor):
         return str(value)
 
 
-CONVERTOR_TYPES = {
+CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = {
     "str": StringConvertor(),
     "path": PathConvertor(),
     "int": IntegerConvertor(),
@@ -83,5 +83,5 @@ CONVERTOR_TYPES = {
 }
 
 
-def register_url_convertor(key: str, convertor: Convertor) -> None:
+def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None:
     CONVERTOR_TYPES[key] = convertor
index 236f9fa433436424eb32c26c6ff08fc80e58923a..dc57c2e9f8f0abcf119c1be2cd81e38c446dfa2b 100644 (file)
@@ -1,5 +1,4 @@
 import typing
-from collections.abc import Sequence
 from shlex import shlex
 from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
 
@@ -223,7 +222,7 @@ class Secret:
         return bool(self._value)
 
 
-class CommaSeparatedStrings(Sequence):
+class CommaSeparatedStrings(typing.Sequence[str]):
     def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
         if isinstance(value, str):
             splitter = shlex(value, posix=True)
@@ -269,7 +268,7 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
         if kwargs:
             value = (
                 ImmutableMultiDict(value).multi_items()
-                + ImmutableMultiDict(kwargs).multi_items()  # type: ignore[operator]
+                + ImmutableMultiDict(kwargs).multi_items()
             )
 
         if not value:
@@ -341,12 +340,12 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
         self._list = [(k, v) for k, v in self._list if k != key]
         return self._dict.pop(key, default)
 
-    def popitem(self) -> typing.Tuple:
+    def popitem(self) -> typing.Tuple[typing.Any, typing.Any]:
         key, value = self._dict.popitem()
         self._list = [(k, v) for k, v in self._list if k != key]
         return key, value
 
-    def poplist(self, key: typing.Any) -> typing.List:
+    def poplist(self, key: typing.Any) -> typing.List[typing.Any]:
         values = [v for k, v in self._list if k == key]
         self.pop(key)
         return values
@@ -362,7 +361,7 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
 
         return self[key]
 
-    def setlist(self, key: typing.Any, values: typing.List) -> None:
+    def setlist(self, key: typing.Any, values: typing.List[typing.Any]) -> None:
         if not values:
             self.pop(key, None)
         else:
@@ -378,7 +377,7 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
         self,
         *args: typing.Union[
             "MultiDict",
-            typing.Mapping,
+            typing.Mapping[typing.Any, typing.Any],
             typing.List[typing.Tuple[typing.Any, typing.Any]],
         ],
         **kwargs: typing.Any,
@@ -397,8 +396,8 @@ class QueryParams(ImmutableMultiDict[str, str]):
     def __init__(
         self,
         *args: typing.Union[
-            "ImmutableMultiDict",
-            typing.Mapping,
+            "ImmutableMultiDict[typing.Any, typing.Any]",
+            typing.Mapping[typing.Any, typing.Any],
             typing.List[typing.Tuple[typing.Any, typing.Any]],
             str,
             bytes,
index 95cd7640dc06ac5d3ac029d3d528a89db541e750..c25dd9db2adc378dcae1aca0d51947fb6c03c3c9 100644 (file)
@@ -23,7 +23,7 @@ class HTTPEndpoint:
             if getattr(self, method.lower(), None) is not None
         ]
 
-    def __await__(self) -> typing.Generator:
+    def __await__(self) -> typing.Generator[typing.Any, None, None]:
         return self.dispatch().__await__()
 
     async def dispatch(self) -> None:
@@ -63,7 +63,7 @@ class WebSocketEndpoint:
         self.receive = receive
         self.send = send
 
-    def __await__(self) -> typing.Generator:
+    def __await__(self) -> typing.Generator[typing.Any, None, None]:
         return self.dispatch().__await__()
 
     async def dispatch(self) -> None:
index cc08ed909d5dd696c12f323b1800488050d21445..a583d93a06038d3e6beea2ac0c0eaf17a762e945 100644 (file)
@@ -10,7 +10,7 @@ class HTTPException(Exception):
         self,
         status_code: int,
         detail: typing.Optional[str] = None,
-        headers: typing.Optional[dict] = None,
+        headers: typing.Optional[typing.Dict[str, str]] = None,
     ) -> None:
         if detail is None:
             detail = http.HTTPStatus(status_code).phrase
index eb3cba5bed93d154441bc5e8b2f815d2ba39e2bd..5ac2bcc1b3aaddd461ac4ac6d43e9f947ead2721 100644 (file)
@@ -142,7 +142,7 @@ class MultiPartParser:
         self._charset = ""
         self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
         self._file_parts_to_finish: typing.List[MultipartPart] = []
-        self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []
+        self._files_to_close_on_error: typing.List[SpooledTemporaryFile[bytes]] = []
 
     def on_part_begin(self) -> None:
         self._current_part = MultipartPart()
index 5ac5b96c817db836da675bdd7677605a9167bc33..05bd57f04068a01838097d551b789efe8726374f 100644 (file)
@@ -6,7 +6,7 @@ class Middleware:
         self.cls = cls
         self.options = options
 
-    def __iter__(self) -> typing.Iterator:
+    def __iter__(self) -> typing.Iterator[typing.Any]:
         as_tuple = (self.cls, self.options)
         return iter(as_tuple)
 
index b9d9c691096fc9abcc13ab0e31c4aa7e28d574cb..f4c3d67461d74838d2288b165c65de5a1f45c9cf 100644 (file)
@@ -137,7 +137,9 @@ class ServerErrorMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        handler: typing.Optional[typing.Callable] = None,
+        handler: typing.Optional[
+            typing.Callable[[Request, Exception], typing.Any]
+        ] = None,
         debug: bool = False,
     ) -> None:
         self.app = app
index 59010c7e684f7a811c5f0582b8eb461dad3c5b72..0124f5c8f3b3452b4ef671c83f7dfe258e5aa802 100644 (file)
@@ -26,7 +26,7 @@ class ExceptionMiddleware:
         self._status_handlers: StatusHandlers = {}
         self._exception_handlers: ExceptionHandlers = {
             HTTPException: self.http_exception,
-            WebSocketException: self.websocket_exception,  # type: ignore[dict-item]
+            WebSocketException: self.websocket_exception,
         }
         if handlers is not None:
             for key, value in handlers.items():
index d4a117cacfa97f68a1d409a808dd5aaaa4ee2601..95578c9d26938217db704202ad9322164223a93e 100644 (file)
@@ -16,7 +16,7 @@ warnings.warn(
 )
 
 
-def build_environ(scope: Scope, body: bytes) -> dict:
+def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
     """
     Builds a scope and request body into a WSGI environ object.
     """
@@ -63,7 +63,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
 
 
 class WSGIMiddleware:
-    def __init__(self, app: typing.Callable) -> None:
+    def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
         self.app = app
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -76,7 +76,7 @@ class WSGIResponder:
     stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
     stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
 
-    def __init__(self, app: typing.Callable, scope: Scope) -> None:
+    def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
         self.app = app
         self.scope = scope
         self.status = None
@@ -132,7 +132,11 @@ class WSGIResponder:
                 },
             )
 
-    def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
+    def wsgi(
+        self,
+        environ: typing.Dict[str, typing.Any],
+        start_response: typing.Callable[..., typing.Any],
+    ) -> None:
         for chunk in self.app(environ, start_response):
             anyio.from_thread.run(
                 self.stream_send.send,
index fff451e2321280e29044b8546d0d065caa46bd3b..5c7a4296c0bbfc7feb1107cb95a01bc97c78c7f8 100644 (file)
@@ -147,7 +147,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
         assert (
             "session" in self.scope
         ), "SessionMiddleware must be installed to access request.session"
-        return self.scope["session"]
+        return self.scope["session"]  # type: ignore[no-any-return]
 
     @property
     def auth(self) -> typing.Any:
@@ -203,7 +203,7 @@ class Request(HTTPConnection):
 
     @property
     def method(self) -> str:
-        return self.scope["method"]
+        return typing.cast(str, self.scope["method"])
 
     @property
     def receive(self) -> Receive:
index 575caf655b115f1e8780d09bbfedec111557af11..16380db0601c82e103744051e8c0fa194e1cc3ef 100644 (file)
@@ -37,7 +37,7 @@ class Response:
         self.body = self.render(content)
         self.init_headers(headers)
 
-    def render(self, content: typing.Any) -> bytes:
+    def render(self, content: typing.Union[str, bytes, None]) -> bytes:
         if content is None:
             return b""
         if isinstance(content, bytes):
index b50d32a1fc6d6247e7791bef9019a7580b5b06f5..9da6730ca018dc0ee3f9268f4099362fe1a4c4b8 100644 (file)
@@ -17,7 +17,7 @@ from starlette.datastructures import URL, Headers, URLPath
 from starlette.exceptions import HTTPException
 from starlette.middleware import Middleware
 from starlette.requests import Request
-from starlette.responses import PlainTextResponse, RedirectResponse
+from starlette.responses import PlainTextResponse, RedirectResponse, Response
 from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketClose
 
@@ -54,18 +54,19 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
     return inspect.iscoroutinefunction(obj)
 
 
-def request_response(func: typing.Callable) -> ASGIApp:
+def request_response(
+    func: typing.Callable[[Request], typing.Union[typing.Awaitable[Response], Response]]
+) -> ASGIApp:
     """
     Takes a function or coroutine `func(request) -> response`,
     and returns an ASGI application.
     """
-    is_coroutine = is_async_callable(func)
 
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive, send)
 
         async def app(scope: Scope, receive: Receive, send: Send) -> None:
-            if is_coroutine:
+            if is_async_callable(func):
                 response = await func(request)
             else:
                 response = await run_in_threadpool(func, request)
@@ -76,7 +77,9 @@ def request_response(func: typing.Callable) -> ASGIApp:
     return app
 
 
-def websocket_session(func: typing.Callable) -> ASGIApp:
+def websocket_session(
+    func: typing.Callable[[WebSocket], typing.Awaitable[None]]
+) -> ASGIApp:
     """
     Takes a coroutine `func(session)`, and returns an ASGI application.
     """
@@ -93,7 +96,7 @@ def websocket_session(func: typing.Callable) -> ASGIApp:
     return app
 
 
-def get_name(endpoint: typing.Callable) -> str:
+def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
     if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
         return endpoint.__name__
     return endpoint.__class__.__name__
@@ -101,9 +104,9 @@ def get_name(endpoint: typing.Callable) -> str:
 
 def replace_params(
     path: str,
-    param_convertors: typing.Dict[str, Convertor],
+    param_convertors: typing.Dict[str, Convertor[typing.Any]],
     path_params: typing.Dict[str, str],
-) -> typing.Tuple[str, dict]:
+) -> typing.Tuple[str, typing.Dict[str, str]]:
     for key, value in list(path_params.items()):
         if "{" + key + "}" in path:
             convertor = param_convertors[key]
@@ -119,7 +122,7 @@ PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
 
 def compile_path(
     path: str,
-) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
+) -> typing.Tuple[typing.Pattern[str], str, typing.Dict[str, Convertor[typing.Any]]]:
     """
     Given a path string, like: "/{username:str}",
     or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
@@ -209,7 +212,7 @@ class Route(BaseRoute):
     def __init__(
         self,
         path: str,
-        endpoint: typing.Callable,
+        endpoint: typing.Callable[..., typing.Any],
         *,
         methods: typing.Optional[typing.List[str]] = None,
         name: typing.Optional[str] = None,
@@ -301,7 +304,11 @@ class Route(BaseRoute):
 
 class WebSocketRoute(BaseRoute):
     def __init__(
-        self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
+        self,
+        path: str,
+        endpoint: typing.Callable[..., typing.Any],
+        *,
+        name: typing.Optional[str] = None,
     ) -> None:
         assert path.startswith("/"), "Routed paths must start with '/'"
         self.path = path
@@ -556,12 +563,14 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
 
 
 def _wrap_gen_lifespan_context(
-    lifespan_context: typing.Callable[[typing.Any], typing.Generator]
-) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
+    lifespan_context: typing.Callable[
+        [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
+    ]
+) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
     cmgr = contextlib.contextmanager(lifespan_context)
 
     @functools.wraps(cmgr)
-    def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
+    def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
         return _AsyncLiftContextManager(cmgr(app))
 
     return wrapper
@@ -587,8 +596,12 @@ class Router:
         routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
         redirect_slashes: bool = True,
         default: typing.Optional[ASGIApp] = None,
-        on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
-        on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
+        on_startup: typing.Optional[
+            typing.Sequence[typing.Callable[[], typing.Any]]
+        ] = None,
+        on_shutdown: typing.Optional[
+            typing.Sequence[typing.Callable[[], typing.Any]]
+        ] = None,
         # the generic to Lifespan[AppType] is the type of the top level application
         # which the router cannot know statically, so we use typing.Any
         lifespan: typing.Optional[Lifespan[typing.Any]] = None,
@@ -614,7 +627,7 @@ class Router:
                 )
 
         if lifespan is None:
-            self.lifespan_context: Lifespan = _DefaultLifespan(self)
+            self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)
 
         elif inspect.isasyncgenfunction(lifespan):
             warnings.warn(
@@ -623,7 +636,7 @@ class Router:
                 DeprecationWarning,
             )
             self.lifespan_context = asynccontextmanager(
-                lifespan,  # type: ignore[arg-type]
+                lifespan,
             )
         elif inspect.isgeneratorfunction(lifespan):
             warnings.warn(
@@ -632,7 +645,7 @@ class Router:
                 DeprecationWarning,
             )
             self.lifespan_context = _wrap_gen_lifespan_context(
-                lifespan,  # type: ignore[arg-type]
+                lifespan,
             )
         else:
             self.lifespan_context = lifespan
@@ -779,7 +792,9 @@ class Router:
     def add_route(
         self,
         path: str,
-        endpoint: typing.Callable,
+        endpoint: typing.Callable[
+            [Request], typing.Union[typing.Awaitable[Response], Response]
+        ],
         methods: typing.Optional[typing.List[str]] = None,
         name: typing.Optional[str] = None,
         include_in_schema: bool = True,
@@ -794,7 +809,10 @@ class Router:
         self.routes.append(route)
 
     def add_websocket_route(
-        self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
+        self,
+        path: str,
+        endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
+        name: typing.Optional[str] = None,
     ) -> None:  # pragma: no cover
         route = WebSocketRoute(path, endpoint=endpoint, name=name)
         self.routes.append(route)
@@ -805,7 +823,7 @@ class Router:
         methods: typing.Optional[typing.List[str]] = None,
         name: typing.Optional[str] = None,
         include_in_schema: bool = True,
-    ) -> typing.Callable:
+    ) -> typing.Callable:  # type: ignore[type-arg]
         """
         We no longer document this decorator style API, and its usage is discouraged.
         Instead you should use the following approach:
@@ -819,7 +837,7 @@ class Router:
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.add_route(
                 path,
                 func,
@@ -833,7 +851,7 @@ class Router:
 
     def websocket_route(
         self, path: str, name: typing.Optional[str] = None
-    ) -> typing.Callable:
+    ) -> typing.Callable:  # type: ignore[type-arg]
         """
         We no longer document this decorator style API, and its usage is discouraged.
         Instead you should use the following approach:
@@ -847,14 +865,14 @@ class Router:
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.add_websocket_route(path, func, name=name)
             return func
 
         return decorator
 
     def add_event_handler(
-        self, event_type: str, func: typing.Callable
+        self, event_type: str, func: typing.Callable[[], typing.Any]
     ) -> None:  # pragma: no cover
         assert event_type in ("startup", "shutdown")
 
@@ -863,14 +881,14 @@ class Router:
         else:
             self.on_shutdown.append(func)
 
-    def on_event(self, event_type: str) -> typing.Callable:
+    def on_event(self, event_type: str) -> typing.Callable:  # type: ignore[type-arg]
         warnings.warn(
             "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
             "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
             self.add_event_handler(event_type, func)
             return func
 
index 72d93e7d704ecacabd3fa0bf2d459b4cb165f667..737f6b029de6c968e40f14189cd03dbd3a1b6113 100644 (file)
@@ -26,11 +26,13 @@ class OpenAPIResponse(Response):
 class EndpointInfo(typing.NamedTuple):
     path: str
     http_method: str
-    func: typing.Callable
+    func: typing.Callable[..., typing.Any]
 
 
 class BaseSchemaGenerator:
-    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
+    def get_schema(
+        self, routes: typing.List[BaseRoute]
+    ) -> typing.Dict[str, typing.Any]:
         raise NotImplementedError()  # pragma: no cover
 
     def get_endpoints(
@@ -46,7 +48,7 @@ class BaseSchemaGenerator:
         - func
             method ready to extract the docstring
         """
-        endpoints_info: list = []
+        endpoints_info: typing.List[EndpointInfo] = []
 
         for route in routes:
             if isinstance(route, (Mount, Host)):
@@ -95,7 +97,9 @@ class BaseSchemaGenerator:
         """
         return re.sub(r":\w+}", "}", path)
 
-    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
+    def parse_docstring(
+        self, func_or_method: typing.Callable[..., typing.Any]
+    ) -> typing.Dict[str, typing.Any]:
         """
         Given a function, parse the docstring as YAML and return a dictionary of info.
         """
@@ -126,10 +130,12 @@ class BaseSchemaGenerator:
 
 
 class SchemaGenerator(BaseSchemaGenerator):
-    def __init__(self, base_schema: dict) -> None:
+    def __init__(self, base_schema: typing.Dict[str, typing.Any]) -> None:
         self.base_schema = base_schema
 
-    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
+    def get_schema(
+        self, routes: typing.List[BaseRoute]
+    ) -> typing.Dict[str, typing.Any]:
         schema = dict(self.base_schema)
         schema.setdefault("paths", {})
         endpoints_info = self.get_endpoints(routes)
index 4c856063c2945d93f1b79505f9d8de7919b947aa..2f1f1ddab8ca8af8ee99ef5fb3bc0a3fa3bb26d9 100644 (file)
@@ -108,7 +108,7 @@ class StaticFiles:
         Given the ASGI scope, return the `path` string to serve up,
         with OS specific path separators, and any '..', '.' components removed.
         """
-        return os.path.normpath(os.path.join(*scope["path"].split("/")))
+        return os.path.normpath(os.path.join(*scope["path"].split("/")))  # type: ignore[no-any-return]  # noqa: E501
 
     async def get_response(self, path: str, scope: Scope) -> Response:
         """
index ffa4133b81464a7a02896fae23d7a209d12fb3f9..071e8a4bb05c8b8bef988746bc411a4d66a56424 100644 (file)
@@ -29,7 +29,7 @@ class _TemplateResponse(Response):
     def __init__(
         self,
         template: typing.Any,
-        context: dict,
+        context: typing.Dict[str, typing.Any],
         status_code: int = 200,
         headers: typing.Optional[typing.Mapping[str, str]] = None,
         media_type: typing.Optional[str] = None,
@@ -66,11 +66,7 @@ class Jinja2Templates:
     @typing.overload
     def __init__(
         self,
-        directory: typing.Union[
-            str,
-            PathLike,
-            typing.Sequence[typing.Union[str, PathLike]],
-        ],
+        directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]",  # noqa: E501
         *,
         context_processors: typing.Optional[
             typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
@@ -92,9 +88,7 @@ class Jinja2Templates:
 
     def __init__(
         self,
-        directory: typing.Union[
-            str, PathLike, typing.Sequence[typing.Union[str, PathLike]], None
-        ] = None,
+        directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]], None]" = None,  # noqa: E501
         *,
         context_processors: typing.Optional[
             typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
@@ -117,14 +111,17 @@ class Jinja2Templates:
 
     def _create_env(
         self,
-        directory: typing.Union[
-            str, PathLike, typing.Sequence[typing.Union[str, PathLike]]
-        ],
+        directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]",  # noqa: E501
         **env_options: typing.Any,
     ) -> "jinja2.Environment":
         @pass_context
-        def url_for(context: dict, name: str, /, **path_params: typing.Any) -> URL:
-            request = context["request"]
+        def url_for(
+            context: typing.Dict[str, typing.Any],
+            name: str,
+            /,
+            **path_params: typing.Any,
+        ) -> URL:
+            request: Request = context["request"]
             return request.url_for(name, **path_params)
 
         loader = jinja2.FileSystemLoader(directory)
@@ -143,7 +140,7 @@ class Jinja2Templates:
         self,
         request: Request,
         name: str,
-        context: typing.Optional[dict] = None,
+        context: typing.Optional[typing.Dict[str, typing.Any]] = None,
         status_code: int = 200,
         headers: typing.Optional[typing.Mapping[str, str]] = None,
         media_type: typing.Optional[str] = None,
@@ -155,7 +152,7 @@ class Jinja2Templates:
     def TemplateResponse(
         self,
         name: str,
-        context: typing.Optional[dict] = None,
+        context: typing.Optional[typing.Dict[str, typing.Any]] = None,
         status_code: int = 200,
         headers: typing.Optional[typing.Mapping[str, str]] = None,
         media_type: typing.Optional[str] = None,
index c9ae97a0824aaf1c5fb821aa4bd75c3a62103851..bfac4bb9fca828a02e1792bde491a453a1a18a95 100644 (file)
@@ -79,8 +79,8 @@ class WebSocketTestSession:
         self.scope = scope
         self.accepted_subprotocol = None
         self.portal_factory = portal_factory
-        self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
-        self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
+        self._receive_queue: "queue.Queue[Message]" = queue.Queue()
+        self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue()
         self.extra_headers = None
 
     def __enter__(self) -> "WebSocketTestSession":
@@ -165,12 +165,12 @@ class WebSocketTestSession:
     def receive_text(self) -> str:
         message = self.receive()
         self._raise_on_close(message)
-        return message["text"]
+        return typing.cast(str, message["text"])
 
     def receive_bytes(self) -> bytes:
         message = self.receive()
         self._raise_on_close(message)
-        return message["bytes"]
+        return typing.cast(bytes, message["bytes"])
 
     def receive_json(self, mode: str = "text") -> typing.Any:
         assert mode in ["text", "binary"]
@@ -374,7 +374,7 @@ class TestClient(httpx.Client):
         root_path: str = "",
         backend: str = "asyncio",
         backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
-        cookies: httpx._client.CookieTypes = None,
+        cookies: httpx._types.CookieTypes = None,
         headers: typing.Dict[str, str] = None,
         follow_redirects: bool = True,
     ) -> None:
@@ -459,7 +459,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -469,7 +469,7 @@ class TestClient(httpx.Client):
             method,
             url,
             content=content,
-            data=data,  # type: ignore[arg-type]
+            data=data,
             files=files,
             json=json,
             params=params,
@@ -494,7 +494,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -523,7 +523,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -552,7 +552,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -585,7 +585,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -593,7 +593,7 @@ class TestClient(httpx.Client):
         return super().post(
             url,
             content=content,
-            data=data,  # type: ignore[arg-type]
+            data=data,
             files=files,
             json=json,
             params=params,
@@ -622,7 +622,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -630,7 +630,7 @@ class TestClient(httpx.Client):
         return super().put(
             url,
             content=content,
-            data=data,  # type: ignore[arg-type]
+            data=data,
             files=files,
             json=json,
             params=params,
@@ -659,7 +659,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
@@ -667,7 +667,7 @@ class TestClient(httpx.Client):
         return super().patch(
             url,
             content=content,
-            data=data,  # type: ignore[arg-type]
+            data=data,
             files=files,
             json=json,
             params=params,
@@ -692,7 +692,7 @@ class TestClient(httpx.Client):
         follow_redirects: typing.Optional[bool] = None,
         allow_redirects: typing.Optional[bool] = None,
         timeout: typing.Union[
-            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+            httpx._types.TimeoutTypes, httpx._client.UseClientDefault
         ] = httpx._client.USE_CLIENT_DEFAULT,
         extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
     ) -> httpx.Response:
index 713d18a80ff2adaf0a83fd8c7501921be8b9c937..19484301e4d857b9dacc01ca88f073cf2a6b8a0a 100644 (file)
@@ -1,5 +1,10 @@
 import typing
 
+if typing.TYPE_CHECKING:
+    from starlette.requests import Request
+    from starlette.responses import Response
+    from starlette.websockets import WebSocket
+
 AppType = typing.TypeVar("AppType")
 
 Scope = typing.MutableMapping[str, typing.Any]
@@ -15,3 +20,11 @@ StatefulLifespan = typing.Callable[
     [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
 ]
 Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
+
+HTTPExceptionHandler = typing.Callable[
+    ["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]]
+]
+WebSocketExceptionHandler = typing.Callable[
+    ["WebSocket", Exception], typing.Awaitable[None]
+]
+ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
index 5aa411824269809fc427390b6a0cb5f44769c080..4704dff728d99a6c13cfd6212bd3ffbc41e1c693 100644 (file)
@@ -111,7 +111,7 @@ class WebSocket(HTTPConnection):
             )
         message = await self.receive()
         self._raise_on_disconnect(message)
-        return message["text"]
+        return typing.cast(str, message["text"])
 
     async def receive_bytes(self) -> bytes:
         if self.application_state != WebSocketState.CONNECTED:
@@ -120,7 +120,7 @@ class WebSocket(HTTPConnection):
             )
         message = await self.receive()
         self._raise_on_disconnect(message)
-        return message["bytes"]
+        return typing.cast(bytes, message["bytes"])
 
     async def receive_json(self, mode: str = "text") -> typing.Any:
         if mode not in {"text", "binary"}:
index 72ca9ba1209adf892cfe7eadd314648b7cc03c57..2a866309ff17ded610aa7b16137ff2beef4acd8e 100644 (file)
@@ -15,7 +15,7 @@ def refresh_convertor_types():
     convertors.CONVERTOR_TYPES = convert_types
 
 
-class DateTimeConvertor(Convertor):
+class DateTimeConvertor(Convertor[datetime]):
     regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(.[0-9]+)?"
 
     def convert(self, value: str) -> datetime:
index 502f7809f97b3d95ce137ebabc668e139bb9af6c..77ed776eaaa28d6fae7737801b217ff375629a39 100644 (file)
@@ -5,13 +5,14 @@ from contextlib import nullcontext as does_not_raise
 import pytest
 
 from starlette.applications import Starlette
-from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode
+from starlette.datastructures import UploadFile
+from starlette.formparsers import MultiPartException, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
 
 
-class ForceMultipartDict(dict):
+class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]):
     def __bool__(self):
         return True
 
@@ -43,7 +44,7 @@ async def app(scope, receive, send):
 async def multi_items_app(scope, receive, send):
     request = Request(scope, receive)
     data = await request.form()
-    output: typing.Dict[str, list] = {}
+    output: typing.Dict[str, typing.List[typing.Any]] = {}
     for key, value in data.multi_items():
         if key not in output:
             output[key] = []
index a8f62b39e070a8b6dc4d35c5eaacb7ea6653dca9..caf110efe37a30c9476ea5ad8b885feaf7638003 100644 (file)
@@ -4,8 +4,8 @@ from typing import List, Optional
 import anyio
 import pytest
 
-from starlette.datastructures import Address
-from starlette.requests import ClientDisconnect, Request, State
+from starlette.datastructures import Address, State
+from starlette.requests import ClientDisconnect, Request
 from starlette.responses import JSONResponse, PlainTextResponse, Response
 from starlette.types import Message, Scope
 
index 284bda1efd11e4546e7b2169f87a2ec8567f46e7..7535fa64128a0eea8aad18e94aa56260379436b5 100644 (file)
@@ -1,6 +1,7 @@
 import datetime as dt
 import os
 import time
+import typing
 from http.cookies import SimpleCookie
 
 import anyio
@@ -343,7 +344,9 @@ def test_expires_on_set_cookie(test_client_factory, monkeypatch, expires):
 
     client = test_client_factory(app)
     response = client.get("/")
-    cookie: SimpleCookie = SimpleCookie(response.headers.get("set-cookie"))
+    cookie: "SimpleCookie[typing.Any]" = SimpleCookie(
+        response.headers.get("set-cookie")
+    )
     assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT"
 
 
index 24f2bf7d734d4c8a806b52e7f8eeb1b2f11ab7da..7159a4bfcc74165faa8b30eac9d7d0da6476a657 100644 (file)
@@ -895,7 +895,7 @@ class Endpoint:
         pytest.param(lambda request: ..., "<lambda>", id="lambda"),
     ],
 )
-def test_route_name(endpoint: typing.Callable, expected_name: str):
+def test_route_name(endpoint: typing.Callable[..., typing.Any], expected_name: str):
     assert Route(path="/", endpoint=endpoint).name == expected_name