From 2ff76532710cdc49d6f9b3e5067533aa481cecd0 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 27 May 2025 14:55:52 +0800 Subject: [PATCH] chore: improve type hints (#2867) --- pyproject.toml | 4 +- starlette/_exception_handler.py | 4 +- starlette/_utils.py | 35 +++---- starlette/applications.py | 43 ++++---- starlette/authentication.py | 19 ++-- starlette/background.py | 9 +- starlette/concurrency.py | 17 ++-- starlette/config.py | 45 ++++---- starlette/convertors.py | 12 +-- starlette/datastructures.py | 136 +++++++++++++------------ starlette/endpoints.py | 13 +-- starlette/formparsers.py | 9 +- starlette/middleware/authentication.py | 6 +- starlette/middleware/base.py | 17 ++-- starlette/middleware/cors.py | 10 +- starlette/middleware/errors.py | 4 +- starlette/middleware/exceptions.py | 7 +- starlette/middleware/gzip.py | 4 +- starlette/middleware/sessions.py | 4 +- starlette/middleware/trustedhost.py | 4 +- starlette/middleware/wsgi.py | 21 ++-- starlette/requests.py | 33 +++--- starlette/responses.py | 43 ++++---- starlette/routing.py | 108 ++++++++++---------- starlette/schemas.py | 16 +-- starlette/staticfiles.py | 4 +- starlette/templating.py | 51 +++++----- starlette/testclient.py | 93 +++++++++-------- starlette/types.py | 30 +++--- starlette/websockets.py | 19 ++-- tests/test_config.py | 3 +- tests/test_exceptions.py | 4 +- tests/test_formparsers.py | 34 +++---- tests/test_routing.py | 21 ++-- tests/test_staticfiles.py | 4 +- 35 files changed, 459 insertions(+), 427 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 02a3820f..e86f49d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ source_pkgs = ["starlette", "tests"] [tool.coverage.report] exclude_lines = [ "pragma: no cover", - "if typing.TYPE_CHECKING:", - "@typing.overload", + "if TYPE_CHECKING:", + "@overload", "raise NotImplementedError", ] diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 72bc89d9..bcb96c9f 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing +from typing import Any from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send from starlette.websockets import WebSocket -ExceptionHandlers = dict[typing.Any, ExceptionHandler] +ExceptionHandlers = dict[Any, ExceptionHandler] StatusHandlers = dict[int, ExceptionHandler] diff --git a/starlette/_utils.py b/starlette/_utils.py index 8001c472..a35ca82d 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -3,8 +3,9 @@ from __future__ import annotations import functools import inspect import sys -import typing -from contextlib import contextmanager +from collections.abc import Awaitable, Generator +from contextlib import AbstractAsyncContextManager, contextmanager +from typing import Any, Callable, Generic, Protocol, TypeVar, overload from starlette.types import Scope @@ -20,58 +21,58 @@ if sys.version_info < (3, 11): # pragma: no cover except ImportError: has_exceptiongroups = False -T = typing.TypeVar("T") -AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] +T = TypeVar("T") +AwaitableCallable = Callable[..., Awaitable[T]] -@typing.overload +@overload def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... -@typing.overload -def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ... +@overload +def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... -def is_async_callable(obj: typing.Any) -> typing.Any: +def is_async_callable(obj: Any) -> Any: while isinstance(obj, functools.partial): obj = obj.func return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) -T_co = typing.TypeVar("T_co", covariant=True) +T_co = TypeVar("T_co", covariant=True) -class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ... +class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ... -class SupportsAsyncClose(typing.Protocol): +class SupportsAsyncClose(Protocol): async def close(self) -> None: ... # pragma: no cover -SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False) +SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False) -class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]): +class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]): __slots__ = ("aw", "entered") - def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None: + def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None: self.aw = aw - def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]: + def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]: return self.aw.__await__() async def __aenter__(self) -> SupportsAsyncCloseType: self.entered = await self.aw return self.entered - async def __aexit__(self, *args: typing.Any) -> None | bool: + async def __aexit__(self, *args: Any) -> None | bool: await self.entered.close() return None @contextmanager -def collapse_excgroups() -> typing.Generator[None, None, None]: +def collapse_excgroups() -> Generator[None, None, None]: try: yield except BaseException as exc: diff --git a/starlette/applications.py b/starlette/applications.py index 6df5a707..32f6e560 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,8 +1,9 @@ from __future__ import annotations import sys -import typing import warnings +from collections.abc import Awaitable, Mapping, Sequence +from typing import Any, Callable, TypeVar if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec @@ -20,7 +21,7 @@ from starlette.routing import BaseRoute, Router from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket -AppType = typing.TypeVar("AppType", bound="Starlette") +AppType = TypeVar("AppType", bound="Starlette") P = ParamSpec("P") @@ -30,11 +31,11 @@ class Starlette: def __init__( self: AppType, debug: bool = False, - routes: typing.Sequence[BaseRoute] | None = None, - middleware: typing.Sequence[Middleware] | None = None, - exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None, - on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, - on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, lifespan: Lifespan[AppType] | None = None, ) -> None: """Initializes the application. @@ -79,7 +80,7 @@ class Starlette: def build_middleware_stack(self) -> ASGIApp: debug = self.debug error_handler = None - exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {} + exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {} for key, value in self.exception_handlers.items(): if key in (500, Exception): @@ -102,7 +103,7 @@ class Starlette: def routes(self) -> list[BaseRoute]: return self.router.routes - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: return self.router.url_path_for(name, **path_params) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -111,7 +112,7 @@ class Starlette: self.middleware_stack = self.build_middleware_stack() await self.middleware_stack(scope, receive, send) - def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg] return self.router.on_event(event_type) # pragma: no cover def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: @@ -140,14 +141,14 @@ class Starlette: def add_event_handler( self, event_type: str, - func: typing.Callable, # type: ignore[type-arg] + func: Callable, # type: ignore[type-arg] ) -> None: # pragma: no cover self.router.add_event_handler(event_type, func) def add_route( self, path: str, - route: typing.Callable[[Request], typing.Awaitable[Response] | Response], + route: Callable[[Request], Awaitable[Response] | Response], methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, @@ -157,19 +158,19 @@ class Starlette: def add_websocket_route( self, path: str, - route: typing.Callable[[WebSocket], typing.Awaitable[None]], + route: Callable[[WebSocket], Awaitable[None]], name: str | None = None, ) -> None: # pragma: no cover self.router.add_websocket_route(path, route, name=name) - def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg] + def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> Callable: # type: ignore[type-arg] warnings.warn( "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.add_exception_handler(exc_class_or_status_code, func) return func @@ -181,7 +182,7 @@ class Starlette: methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, - ) -> typing.Callable: # type: ignore[type-arg] + ) -> Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -195,7 +196,7 @@ class Starlette: DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.router.add_route( path, func, @@ -207,7 +208,7 @@ class Starlette: return decorator - def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] + def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -221,13 +222,13 @@ class Starlette: DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.router.add_websocket_route(path, func, name=name) return func return decorator - def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] + def middleware(self, middleware_type: str) -> Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -242,7 +243,7 @@ class Starlette: ) assert middleware_type == "http", 'Currently only middleware("http") is supported.' - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.add_middleware(BaseHTTPMiddleware, dispatch=func) return func diff --git a/starlette/authentication.py b/starlette/authentication.py index 4fd86641..a7138949 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -3,7 +3,8 @@ from __future__ import annotations import functools import inspect import sys -import typing +from collections.abc import Sequence +from typing import Any, Callable from urllib.parse import urlencode if sys.version_info >= (3, 10): # pragma: no cover @@ -20,7 +21,7 @@ from starlette.websockets import WebSocket _P = ParamSpec("_P") -def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: +def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool: for scope in scopes: if scope not in conn.auth.scopes: return False @@ -28,15 +29,15 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo def requires( - scopes: str | typing.Sequence[str], + scopes: str | Sequence[str], status_code: int = 403, redirect: str | None = None, -) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]: +) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator( - func: typing.Callable[_P, typing.Any], - ) -> typing.Callable[_P, typing.Any]: + func: Callable[_P, Any], + ) -> Callable[_P, Any]: sig = inspect.signature(func) for idx, parameter in enumerate(sig.parameters.values()): if parameter.name == "request" or parameter.name == "websocket": @@ -62,7 +63,7 @@ def requires( elif is_async_callable(func): # Handle async request/response functions. @functools.wraps(func) - async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: request = kwargs.get("request", args[idx] if idx < len(args) else None) assert isinstance(request, Request) @@ -79,7 +80,7 @@ def requires( else: # Handle sync request/response functions. @functools.wraps(func) - def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: request = kwargs.get("request", args[idx] if idx < len(args) else None) assert isinstance(request, Request) @@ -106,7 +107,7 @@ class AuthenticationBackend: class AuthCredentials: - def __init__(self, scopes: typing.Sequence[str] | None = None): + def __init__(self, scopes: Sequence[str] | None = None): self.scopes = [] if scopes is None else list(scopes) diff --git a/starlette/background.py b/starlette/background.py index 0430fc08..8a4562cd 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,7 +1,8 @@ from __future__ import annotations import sys -import typing +from collections.abc import Sequence +from typing import Any, Callable if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec @@ -15,7 +16,7 @@ P = ParamSpec("P") class BackgroundTask: - def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: self.func = func self.args = args self.kwargs = kwargs @@ -29,10 +30,10 @@ class BackgroundTask: class BackgroundTasks(BackgroundTask): - def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None): + def __init__(self, tasks: Sequence[BackgroundTask] | None = None): self.tasks = list(tasks) if tasks else [] - def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: + def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: task = BackgroundTask(func, *args, **kwargs) self.tasks.append(task) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 494f3420..79864753 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -2,8 +2,9 @@ from __future__ import annotations import functools import sys -import typing import warnings +from collections.abc import AsyncIterator, Coroutine, Iterable, Iterator +from typing import Callable, TypeVar import anyio.to_thread @@ -13,10 +14,10 @@ else: # pragma: no cover from typing_extensions import ParamSpec P = ParamSpec("P") -T = typing.TypeVar("T") +T = TypeVar("T") -async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] +async def run_until_first_complete(*args: tuple[Callable, dict]) -> None: # type: ignore[type-arg] warnings.warn( "run_until_first_complete is deprecated and will be removed in a future version.", DeprecationWarning, @@ -24,7 +25,7 @@ async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: async with anyio.create_task_group() as task_group: - async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] + async def run(func: Callable[[], Coroutine]) -> None: # type: ignore[type-arg] await func() task_group.cancel_scope.cancel() @@ -32,7 +33,7 @@ async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: task_group.start_soon(run, functools.partial(func, **kwargs)) -async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: +async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: func = functools.partial(func, *args, **kwargs) return await anyio.to_thread.run_sync(func) @@ -41,7 +42,7 @@ class _StopIteration(Exception): pass -def _next(iterator: typing.Iterator[T]) -> T: +def _next(iterator: Iterator[T]) -> T: # We can't raise `StopIteration` from within the threadpool iterator # and catch it outside that context, so we coerce them into a different # exception type. @@ -52,8 +53,8 @@ def _next(iterator: typing.Iterator[T]) -> T: async def iterate_in_threadpool( - iterator: typing.Iterable[T], -) -> typing.AsyncIterator[T]: + iterator: Iterable[T], +) -> AsyncIterator[T]: as_iterator = iter(iterator) while True: try: diff --git a/starlette/config.py b/starlette/config.py index ca15c564..091f857f 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -1,9 +1,10 @@ from __future__ import annotations import os -import typing import warnings +from collections.abc import Iterator, Mapping, MutableMapping from pathlib import Path +from typing import Any, Callable, TypeVar, overload class undefined: @@ -14,8 +15,8 @@ class EnvironError(Exception): pass -class Environ(typing.MutableMapping[str, str]): - def __init__(self, environ: typing.MutableMapping[str, str] = os.environ): +class Environ(MutableMapping[str, str]): + def __init__(self, environ: MutableMapping[str, str] = os.environ): self._environ = environ self._has_been_read: set[str] = set() @@ -33,7 +34,7 @@ class Environ(typing.MutableMapping[str, str]): raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.") self._environ.__delitem__(key) - def __iter__(self) -> typing.Iterator[str]: + def __iter__(self) -> Iterator[str]: return iter(self._environ) def __len__(self) -> int: @@ -42,14 +43,14 @@ class Environ(typing.MutableMapping[str, str]): environ = Environ() -T = typing.TypeVar("T") +T = TypeVar("T") class Config: def __init__( self, env_file: str | Path | None = None, - environ: typing.Mapping[str, str] = environ, + environ: Mapping[str, str] = environ, env_prefix: str = "", ) -> None: self.environ = environ @@ -61,40 +62,40 @@ class Config: else: self.file_values = self._read_file(env_file) - @typing.overload + @overload def __call__(self, key: str, *, default: None) -> str | None: ... - @typing.overload + @overload def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ... - @typing.overload + @overload def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ... - @typing.overload + @overload def __call__( self, key: str, - cast: typing.Callable[[typing.Any], T] = ..., - default: typing.Any = ..., + cast: Callable[[Any], T] = ..., + default: Any = ..., ) -> T: ... - @typing.overload + @overload def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ... def __call__( self, key: str, - cast: typing.Callable[[typing.Any], typing.Any] | None = None, - default: typing.Any = undefined, - ) -> typing.Any: + cast: Callable[[Any], Any] | None = None, + default: Any = undefined, + ) -> Any: return self.get(key, cast, default) def get( self, key: str, - cast: typing.Callable[[typing.Any], typing.Any] | None = None, - default: typing.Any = undefined, - ) -> typing.Any: + cast: Callable[[Any], Any] | None = None, + default: Any = undefined, + ) -> Any: key = self.env_prefix + key if key in self.environ: value = self.environ[key] @@ -121,9 +122,9 @@ class Config: def _perform_cast( self, key: str, - value: typing.Any, - cast: typing.Callable[[typing.Any], typing.Any] | None = None, - ) -> typing.Any: + value: Any, + cast: Callable[[Any], Any] | None = None, + ) -> Any: if cast is None or value is None: return value elif cast is bool and isinstance(value, str): diff --git a/starlette/convertors.py b/starlette/convertors.py index 84df87a5..72b1cf9f 100644 --- a/starlette/convertors.py +++ b/starlette/convertors.py @@ -1,14 +1,14 @@ from __future__ import annotations import math -import typing import uuid +from typing import Any, ClassVar, Generic, TypeVar -T = typing.TypeVar("T") +T = TypeVar("T") -class Convertor(typing.Generic[T]): - regex: typing.ClassVar[str] = "" +class Convertor(Generic[T]): + regex: ClassVar[str] = "" def convert(self, value: str) -> T: raise NotImplementedError() # pragma: no cover @@ -76,7 +76,7 @@ class UUIDConvertor(Convertor[uuid.UUID]): return str(value) -CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = { +CONVERTOR_TYPES: dict[str, Convertor[Any]] = { "str": StringConvertor(), "path": PathConvertor(), "int": IntegerConvertor(), @@ -85,5 +85,5 @@ CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = { } -def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None: +def register_url_convertor(key: str, convertor: Convertor[Any]) -> None: CONVERTOR_TYPES[key] = convertor diff --git a/starlette/datastructures.py b/starlette/datastructures.py index f5d74d25..70dacd02 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -1,23 +1,31 @@ from __future__ import annotations -import typing +from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView from shlex import shlex +from typing import ( + Any, + BinaryIO, + NamedTuple, + TypeVar, + Union, + cast, +) from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit from starlette.concurrency import run_in_threadpool from starlette.types import Scope -class Address(typing.NamedTuple): +class Address(NamedTuple): host: str port: int -_KeyType = typing.TypeVar("_KeyType") +_KeyType = TypeVar("_KeyType") # Mapping keys are invariant but their values are covariant since # you can only read them # that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` -_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) +_CovariantValueType = TypeVar("_CovariantValueType", covariant=True) class URL: @@ -25,7 +33,7 @@ class URL: self, url: str = "", scope: Scope | None = None, - **components: typing.Any, + **components: Any, ) -> None: if scope is not None: assert not url, 'Cannot set both "url" and "scope".' @@ -107,7 +115,7 @@ class URL: def is_secure(self) -> bool: return self.scheme in ("https", "wss") - def replace(self, **kwargs: typing.Any) -> URL: + def replace(self, **kwargs: Any) -> URL: if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs: hostname = kwargs.pop("hostname", None) port = kwargs.pop("port", self.port) @@ -135,17 +143,17 @@ class URL: components = self.components._replace(**kwargs) return self.__class__(components.geturl()) - def include_query_params(self, **kwargs: typing.Any) -> URL: + def include_query_params(self, **kwargs: Any) -> URL: params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) params.update({str(key): str(value) for key, value in kwargs.items()}) query = urlencode(params.multi_items()) return self.replace(query=query) - def replace_query_params(self, **kwargs: typing.Any) -> URL: + def replace_query_params(self, **kwargs: Any) -> URL: query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) return self.replace(query=query) - def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL: + def remove_query_params(self, keys: str | Sequence[str]) -> URL: if isinstance(keys, str): keys = [keys] params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) @@ -154,7 +162,7 @@ class URL: query = urlencode(params.multi_items()) return self.replace(query=query) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: return str(self) == str(other) def __str__(self) -> str: @@ -217,8 +225,8 @@ class Secret: return bool(self._value) -class CommaSeparatedStrings(typing.Sequence[str]): - def __init__(self, value: str | typing.Sequence[str]): +class CommaSeparatedStrings(Sequence[str]): + def __init__(self, value: str | Sequence[str]): if isinstance(value, str): splitter = shlex(value, posix=True) splitter.whitespace = "," @@ -230,10 +238,10 @@ class CommaSeparatedStrings(typing.Sequence[str]): def __len__(self) -> int: return len(self._items) - def __getitem__(self, index: int | slice) -> typing.Any: + def __getitem__(self, index: int | slice) -> Any: return self._items[index] - def __iter__(self) -> typing.Iterator[str]: + def __iter__(self) -> Iterator[str]: return iter(self._items) def __repr__(self) -> str: @@ -245,47 +253,47 @@ class CommaSeparatedStrings(typing.Sequence[str]): return ", ".join(repr(item) for item in self) -class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): +class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]): _dict: dict[_KeyType, _CovariantValueType] def __init__( self, *args: ImmutableMultiDict[_KeyType, _CovariantValueType] - | typing.Mapping[_KeyType, _CovariantValueType] - | typing.Iterable[tuple[_KeyType, _CovariantValueType]], - **kwargs: typing.Any, + | Mapping[_KeyType, _CovariantValueType] + | Iterable[tuple[_KeyType, _CovariantValueType]], + **kwargs: Any, ) -> None: assert len(args) < 2, "Too many arguments." - value: typing.Any = args[0] if args else [] + value: Any = args[0] if args else [] if kwargs: value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items() if not value: - _items: list[tuple[typing.Any, typing.Any]] = [] + _items: list[tuple[Any, Any]] = [] elif hasattr(value, "multi_items"): - value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) + value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) _items = list(value.multi_items()) elif hasattr(value, "items"): - value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) + value = cast(Mapping[_KeyType, _CovariantValueType], value) _items = list(value.items()) else: - value = typing.cast("list[tuple[typing.Any, typing.Any]]", value) + value = cast("list[tuple[Any, Any]]", value) _items = list(value) self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: typing.Any) -> list[_CovariantValueType]: + def getlist(self, key: Any) -> list[_CovariantValueType]: return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.KeysView[_KeyType]: + def keys(self) -> KeysView[_KeyType]: return self._dict.keys() - def values(self) -> typing.ValuesView[_CovariantValueType]: + def values(self) -> ValuesView[_CovariantValueType]: return self._dict.values() - def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: + def items(self) -> ItemsView[_KeyType, _CovariantValueType]: return self._dict.items() def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]: @@ -294,16 +302,16 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): def __getitem__(self, key: _KeyType) -> _CovariantValueType: return self._dict[key] - def __contains__(self, key: typing.Any) -> bool: + def __contains__(self, key: Any) -> bool: return key in self._dict - def __iter__(self) -> typing.Iterator[_KeyType]: + def __iter__(self) -> Iterator[_KeyType]: return iter(self.keys()) def __len__(self) -> int: return len(self._dict) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False return sorted(self._list) == sorted(other._list) @@ -314,24 +322,24 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): return f"{class_name}({items!r})" -class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): - def __setitem__(self, key: typing.Any, value: typing.Any) -> None: +class MultiDict(ImmutableMultiDict[Any, Any]): + def __setitem__(self, key: Any, value: Any) -> None: self.setlist(key, [value]) - def __delitem__(self, key: typing.Any) -> None: + def __delitem__(self, key: Any) -> None: self._list = [(k, v) for k, v in self._list if k != key] del self._dict[key] - def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + def pop(self, key: Any, default: Any = None) -> Any: self._list = [(k, v) for k, v in self._list if k != key] return self._dict.pop(key, default) - def popitem(self) -> tuple[typing.Any, typing.Any]: + def popitem(self) -> tuple[Any, Any]: key, value = self._dict.popitem() self._list = [(k, v) for k, v in self._list if k != key] return key, value - def poplist(self, key: typing.Any) -> list[typing.Any]: + def poplist(self, key: Any) -> list[Any]: values = [v for k, v in self._list if k == key] self.pop(key) return values @@ -340,14 +348,14 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): self._dict.clear() self._list.clear() - def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + def setdefault(self, key: Any, default: Any = None) -> Any: if key not in self: self._dict[key] = default self._list.append((key, default)) return self[key] - def setlist(self, key: typing.Any, values: list[typing.Any]) -> None: + def setlist(self, key: Any, values: list[Any]) -> None: if not values: self.pop(key, None) else: @@ -355,14 +363,14 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): self._list = existing_items + [(key, value) for value in values] self._dict[key] = values[-1] - def append(self, key: typing.Any, value: typing.Any) -> None: + def append(self, key: Any, value: Any) -> None: self._list.append((key, value)) self._dict[key] = value def update( self, - *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]], - **kwargs: typing.Any, + *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]], + **kwargs: Any, ) -> None: value = MultiDict(*args, **kwargs) existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()] @@ -377,12 +385,8 @@ class QueryParams(ImmutableMultiDict[str, str]): def __init__( self, - *args: ImmutableMultiDict[typing.Any, typing.Any] - | typing.Mapping[typing.Any, typing.Any] - | list[tuple[typing.Any, typing.Any]] - | str - | bytes, - **kwargs: typing.Any, + *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes, + **kwargs: Any, ) -> None: assert len(args) < 2, "Too many arguments." @@ -413,7 +417,7 @@ class UploadFile: def __init__( self, - file: typing.BinaryIO, + file: BinaryIO, *, size: int | None = None, filename: str | None = None, @@ -464,14 +468,14 @@ class UploadFile: return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})" -class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): +class FormData(ImmutableMultiDict[str, Union[UploadFile, str]]): """ An immutable multidict, containing both file uploads and text input. """ def __init__( self, - *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], + *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], **kwargs: str | UploadFile, ) -> None: super().__init__(*args, **kwargs) @@ -482,16 +486,16 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): await value.close() -class Headers(typing.Mapping[str, str]): +class Headers(Mapping[str, str]): """ An immutable, case-insensitive multidict. """ def __init__( self, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, raw: list[tuple[bytes, bytes]] | None = None, - scope: typing.MutableMapping[str, typing.Any] | None = None, + scope: MutableMapping[str, Any] | None = None, ) -> None: self._list: list[tuple[bytes, bytes]] = [] if headers is not None: @@ -533,20 +537,20 @@ class Headers(typing.Mapping[str, str]): return header_value.decode("latin-1") raise KeyError(key) - def __contains__(self, key: typing.Any) -> bool: + def __contains__(self, key: Any) -> bool: get_header_key = key.lower().encode("latin-1") for header_key, header_value in self._list: if header_key == get_header_key: return True return False - def __iter__(self) -> typing.Iterator[typing.Any]: + def __iter__(self) -> Iterator[Any]: return iter(self.keys()) def __len__(self) -> int: return len(self._list) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, Headers): return False return sorted(self._list) == sorted(other._list) @@ -596,14 +600,14 @@ class MutableHeaders(Headers): for idx in reversed(pop_indexes): del self._list[idx] - def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders: - if not isinstance(other, typing.Mapping): + def __ior__(self, other: Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, Mapping): raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") self.update(other) return self - def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders: - if not isinstance(other, typing.Mapping): + def __or__(self, other: Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, Mapping): raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") new = self.mutablecopy() new.update(other) @@ -627,7 +631,7 @@ class MutableHeaders(Headers): self._list.append((set_key, set_value)) return value - def update(self, other: typing.Mapping[str, str]) -> None: + def update(self, other: Mapping[str, str]) -> None: for key, val in other.items(): self[key] = val @@ -653,22 +657,22 @@ class State: Used for `request.state` and `app.state`. """ - _state: dict[str, typing.Any] + _state: dict[str, Any] - def __init__(self, state: dict[str, typing.Any] | None = None): + def __init__(self, state: dict[str, Any] | None = None): if state is None: state = {} super().__setattr__("_state", state) - def __setattr__(self, key: typing.Any, value: typing.Any) -> None: + def __setattr__(self, key: Any, value: Any) -> None: self._state[key] = value - def __getattr__(self, key: typing.Any) -> typing.Any: + def __getattr__(self, key: Any) -> Any: try: return self._state[key] except KeyError: message = "'{}' object has no attribute '{}'" raise AttributeError(message.format(self.__class__.__name__, key)) - def __delattr__(self, key: typing.Any) -> None: + def __delattr__(self, key: Any) -> None: del self._state[key] diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 10769026..2cdbeb11 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -1,7 +1,8 @@ from __future__ import annotations import json -import typing +from collections.abc import Generator +from typing import Any, Callable from starlette import status from starlette._utils import is_async_callable @@ -25,14 +26,14 @@ class HTTPEndpoint: if getattr(self, method.lower(), None) is not None ] - def __await__(self) -> typing.Generator[typing.Any, None, None]: + def __await__(self) -> Generator[Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: request = Request(self.scope, receive=self.receive) handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower() - handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed) + handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed) is_async = is_async_callable(handler) if is_async: response = await handler(request) @@ -59,7 +60,7 @@ class WebSocketEndpoint: self.receive = receive self.send = send - def __await__(self) -> typing.Generator[typing.Any, None, None]: + def __await__(self) -> Generator[Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: @@ -83,7 +84,7 @@ class WebSocketEndpoint: finally: await self.on_disconnect(websocket, close_code) - async def decode(self, websocket: WebSocket, message: Message) -> typing.Any: + async def decode(self, websocket: WebSocket, message: Message) -> Any: if self.encoding == "text": if "text" not in message: await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) @@ -115,7 +116,7 @@ class WebSocketEndpoint: """Override to handle an incoming websocket connection""" await websocket.accept() - async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: + async def on_receive(self, websocket: WebSocket, data: Any) -> None: """Override to handle an incoming websocket message""" async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 4551d688..8e389dec 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,14 +1,15 @@ from __future__ import annotations -import typing +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from enum import Enum from tempfile import SpooledTemporaryFile +from typing import TYPE_CHECKING from urllib.parse import unquote_plus from starlette.datastructures import FormData, Headers, UploadFile -if typing.TYPE_CHECKING: +if TYPE_CHECKING: import python_multipart as multipart from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header else: @@ -54,7 +55,7 @@ class MultiPartException(Exception): class FormParser: - def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None: + def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None: assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." self.headers = headers self.stream = stream @@ -130,7 +131,7 @@ class MultiPartParser: def __init__( self, headers: Headers, - stream: typing.AsyncGenerator[bytes, None], + stream: AsyncGenerator[bytes, None], *, max_files: int | float = 1000, max_fields: int | float = 1000, diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 8555ee07..77fc742d 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing +from typing import Callable from starlette.authentication import ( AuthCredentials, @@ -18,11 +18,11 @@ class AuthenticationMiddleware: self, app: ASGIApp, backend: AuthenticationBackend, - on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None, + on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None, ) -> None: self.app = app self.backend = backend - self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = ( + self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = ( on_error if on_error is not None else self.default_on_error ) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b49ab611..4d139c25 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,6 +1,7 @@ from __future__ import annotations -import typing +from collections.abc import AsyncGenerator, Awaitable, Mapping +from typing import Any, Callable, TypeVar import anyio @@ -9,9 +10,9 @@ from starlette.requests import ClientDisconnect, Request from starlette.responses import AsyncContentStream, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] -DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] -T = typing.TypeVar("T") +RequestResponseEndpoint = Callable[[Request], Awaitable[Response]] +DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]] +T = TypeVar("T") class _CachedRequest(Request): @@ -113,7 +114,7 @@ class BaseHTTPMiddleware: async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: + async def wrap(func: Callable[[], Awaitable[T]]) -> T: result = await func() task_group.cancel_scope.cancel() return result @@ -158,7 +159,7 @@ class BaseHTTPMiddleware: assert message["type"] == "http.response.start" - async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async def body_stream() -> AsyncGenerator[bytes, None]: async for message in recv_stream: assert message["type"] == "http.response.body" body = message.get("body", b"") @@ -191,9 +192,9 @@ class _StreamingResponse(Response): self, content: AsyncContentStream, status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, - info: typing.Mapping[str, typing.Any] | None = None, + info: Mapping[str, Any] | None = None, ) -> None: self.info = info self.body_iterator = content diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 61502691..ffd8aefc 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -2,7 +2,7 @@ from __future__ import annotations import functools import re -import typing +from collections.abc import Sequence from starlette.datastructures import Headers, MutableHeaders from starlette.responses import PlainTextResponse, Response @@ -16,12 +16,12 @@ class CORSMiddleware: def __init__( self, app: ASGIApp, - allow_origins: typing.Sequence[str] = (), - allow_methods: typing.Sequence[str] = ("GET",), - allow_headers: typing.Sequence[str] = (), + allow_origins: Sequence[str] = (), + allow_methods: Sequence[str] = ("GET",), + allow_headers: Sequence[str] = (), allow_credentials: bool = False, allow_origin_regex: str | None = None, - expose_headers: typing.Sequence[str] = (), + expose_headers: Sequence[str] = (), max_age: int = 600, ) -> None: if "*" in allow_methods: diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 76ad776b..60b96a5c 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -4,7 +4,7 @@ import html import inspect import sys import traceback -import typing +from typing import Any, Callable from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool @@ -140,7 +140,7 @@ class ServerErrorMiddleware: def __init__( self, app: ASGIApp, - handler: typing.Callable[[Request, Exception], typing.Any] | None = None, + handler: Callable[[Request, Exception], Any] | None = None, debug: bool = False, ) -> None: self.app = app diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index a99b44de..864c2238 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -1,6 +1,7 @@ from __future__ import annotations -import typing +from collections.abc import Mapping +from typing import Any, Callable from starlette._exception_handler import ( ExceptionHandlers, @@ -18,7 +19,7 @@ class ExceptionMiddleware: def __init__( self, app: ASGIApp, - handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None, + handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None, debug: bool = False, ) -> None: self.app = app @@ -35,7 +36,7 @@ class ExceptionMiddleware: def add_exception_handler( self, exc_class_or_status_code: int | type[Exception], - handler: typing.Callable[[Request, Exception], Response], + handler: Callable[[Request, Exception], Response], ) -> None: if isinstance(exc_class_or_status_code, int): self._status_handlers[exc_class_or_status_code] = handler diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index c7fd5b77..502f0552 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -1,6 +1,6 @@ import gzip import io -import typing +from typing import NoReturn from starlette.datastructures import Headers, MutableHeaders from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -137,5 +137,5 @@ class GZipResponder(IdentityResponder): return body -async def unattached_send(message: Message) -> typing.NoReturn: +async def unattached_send(message: Message) -> NoReturn: raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 5f9fcd88..1b95db4b 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -1,8 +1,8 @@ from __future__ import annotations import json -import typing from base64 import b64decode, b64encode +from typing import Literal import itsdangerous from itsdangerous.exc import BadSignature @@ -20,7 +20,7 @@ class SessionMiddleware: session_cookie: str = "session", max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds path: str = "/", - same_site: typing.Literal["lax", "strict", "none"] = "lax", + same_site: Literal["lax", "strict", "none"] = "lax", https_only: bool = False, domain: str | None = None, ) -> None: diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 2d1c999e..98451e29 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing +from collections.abc import Sequence from starlette.datastructures import URL, Headers from starlette.responses import PlainTextResponse, RedirectResponse, Response @@ -13,7 +13,7 @@ class TrustedHostMiddleware: def __init__( self, app: ASGIApp, - allowed_hosts: typing.Sequence[str] | None = None, + allowed_hosts: Sequence[str] | None = None, www_redirect: bool = True, ) -> None: if allowed_hosts is None: diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 6e0a3fae..7d7fd0d1 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -3,8 +3,9 @@ from __future__ import annotations import io import math import sys -import typing import warnings +from collections.abc import MutableMapping +from typing import Any, Callable import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream @@ -18,7 +19,7 @@ warnings.warn( ) -def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: +def build_environ(scope: Scope, body: bytes) -> dict[str, Any]: """ Builds a scope and request body into a WSGI environ object. """ @@ -71,7 +72,7 @@ def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: class WSGIMiddleware: - def __init__(self, app: typing.Callable[..., typing.Any]) -> None: + def __init__(self, app: Callable[..., Any]) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -81,17 +82,17 @@ class WSGIMiddleware: class WSGIResponder: - stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] - stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + stream_send: ObjectSendStream[MutableMapping[str, Any]] + stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] - def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: + def __init__(self, app: Callable[..., Any], scope: Scope) -> None: self.app = app self.scope = scope self.status = None self.response_headers = None self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) self.response_started = False - self.exc_info: typing.Any = None + self.exc_info: Any = None async def __call__(self, receive: Receive, send: Send) -> None: body = b"" @@ -118,7 +119,7 @@ class WSGIResponder: self, status: str, response_headers: list[tuple[str, str]], - exc_info: typing.Any = None, + exc_info: Any = None, ) -> None: self.exc_info = exc_info if not self.response_started: # pragma: no branch @@ -140,8 +141,8 @@ class WSGIResponder: def wsgi( self, - environ: dict[str, typing.Any], - start_response: typing.Callable[..., typing.Any], + environ: dict[str, Any], + start_response: Callable[..., Any], ) -> None: for chunk in self.app(environ, start_response): anyio.from_thread.run( diff --git a/starlette/requests.py b/starlette/requests.py index 7dc04a74..628358d1 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,8 +1,9 @@ from __future__ import annotations import json -import typing +from collections.abc import AsyncGenerator, Iterator, Mapping from http import cookies as http_cookies +from typing import TYPE_CHECKING, Any, NoReturn, cast import anyio @@ -12,7 +13,7 @@ from starlette.exceptions import HTTPException from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from python_multipart.multipart import parse_options_header from starlette.applications import Starlette @@ -67,7 +68,7 @@ class ClientDisconnect(Exception): pass -class HTTPConnection(typing.Mapping[str, typing.Any]): +class HTTPConnection(Mapping[str, Any]): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. @@ -77,10 +78,10 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): assert scope["type"] in ("http", "websocket") self.scope = scope - def __getitem__(self, key: str) -> typing.Any: + def __getitem__(self, key: str) -> Any: return self.scope[key] - def __iter__(self) -> typing.Iterator[str]: + def __iter__(self) -> Iterator[str]: return iter(self.scope) def __len__(self) -> int: @@ -93,7 +94,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): __hash__ = object.__hash__ @property - def app(self) -> typing.Any: + def app(self) -> Any: return self.scope["app"] @property @@ -132,7 +133,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): return self._query_params @property - def path_params(self) -> dict[str, typing.Any]: + def path_params(self) -> dict[str, Any]: return self.scope.get("path_params", {}) @property @@ -155,17 +156,17 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): return None @property - def session(self) -> dict[str, typing.Any]: + def session(self) -> dict[str, Any]: assert "session" in self.scope, "SessionMiddleware must be installed to access request.session" return self.scope["session"] # type: ignore[no-any-return] @property - def auth(self) -> typing.Any: + def auth(self) -> Any: assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth" return self.scope["auth"] @property - def user(self) -> typing.Any: + def user(self) -> Any: assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user" return self.scope["user"] @@ -179,7 +180,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): self._state = State(self.scope["state"]) return self._state - def url_for(self, name: str, /, **path_params: typing.Any) -> URL: + def url_for(self, name: str, /, **path_params: Any) -> URL: url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app") if url_path_provider is None: raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.") @@ -187,11 +188,11 @@ class HTTPConnection(typing.Mapping[str, typing.Any]): return url_path.make_absolute_url(base_url=self.base_url) -async def empty_receive() -> typing.NoReturn: +async def empty_receive() -> NoReturn: raise RuntimeError("Receive channel has not been made available") -async def empty_send(message: Message) -> typing.NoReturn: +async def empty_send(message: Message) -> NoReturn: raise RuntimeError("Send channel has not been made available") @@ -209,13 +210,13 @@ class Request(HTTPConnection): @property def method(self) -> str: - return typing.cast(str, self.scope["method"]) + return cast(str, self.scope["method"]) @property def receive(self) -> Receive: return self._receive - async def stream(self) -> typing.AsyncGenerator[bytes, None]: + async def stream(self) -> AsyncGenerator[bytes, None]: if hasattr(self, "_body"): yield self._body yield b"" @@ -243,7 +244,7 @@ class Request(HTTPConnection): self._body = b"".join(chunks) return self._body - async def json(self) -> typing.Any: + async def json(self) -> Any: if not hasattr(self, "_json"): # pragma: no branch body = await self.body() self._json = json.loads(body) diff --git a/starlette/responses.py b/starlette/responses.py index c956ff6a..b9d334ea 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -7,13 +7,14 @@ import os import re import stat import sys -import typing import warnings +from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence from datetime import datetime from email.utils import format_datetime, formatdate from functools import partial from mimetypes import guess_type from secrets import token_hex +from typing import Any, Callable, Literal, Union from urllib.parse import quote import anyio @@ -33,9 +34,9 @@ class Response: def __init__( self, - content: typing.Any = None, + content: Any = None, status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ) -> None: @@ -46,14 +47,14 @@ class Response: self.body = self.render(content) self.init_headers(headers) - def render(self, content: typing.Any) -> bytes | memoryview: + def render(self, content: Any) -> bytes | memoryview: if content is None: return b"" if isinstance(content, (bytes, memoryview)): return content return content.encode(self.charset) # type: ignore - def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None: + def init_headers(self, headers: Mapping[str, str] | None = None) -> None: if headers is None: raw_headers: list[tuple[bytes, bytes]] = [] populate_content_length = True @@ -97,7 +98,7 @@ class Response: domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: typing.Literal["lax", "strict", "none"] | None = "lax", + samesite: Literal["lax", "strict", "none"] | None = "lax", partitioned: bool = False, ) -> None: cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() @@ -139,7 +140,7 @@ class Response: domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: typing.Literal["lax", "strict", "none"] | None = "lax", + samesite: Literal["lax", "strict", "none"] | None = "lax", ) -> None: self.set_cookie( key, @@ -180,15 +181,15 @@ class JSONResponse(Response): def __init__( self, - content: typing.Any, + content: Any, status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ) -> None: super().__init__(content, status_code, headers, media_type, background) - def render(self, content: typing.Any) -> bytes: + def render(self, content: Any) -> bytes: return json.dumps( content, ensure_ascii=False, @@ -203,17 +204,17 @@ class RedirectResponse(Response): self, url: str | URL, status_code: int = 307, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, background: BackgroundTask | None = None, ) -> None: super().__init__(content=b"", status_code=status_code, headers=headers, background=background) self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") -Content = typing.Union[str, bytes, memoryview] -SyncContentStream = typing.Iterable[Content] -AsyncContentStream = typing.AsyncIterable[Content] -ContentStream = typing.Union[AsyncContentStream, SyncContentStream] +Content = Union[str, bytes, memoryview] +SyncContentStream = Iterable[Content] +AsyncContentStream = AsyncIterable[Content] +ContentStream = Union[AsyncContentStream, SyncContentStream] class StreamingResponse(Response): @@ -223,11 +224,11 @@ class StreamingResponse(Response): self, content: ContentStream, status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ) -> None: - if isinstance(content, typing.AsyncIterable): + if isinstance(content, AsyncIterable): self.body_iterator = content else: self.body_iterator = iterate_in_threadpool(content) @@ -269,7 +270,7 @@ class StreamingResponse(Response): with collapse_excgroups(): async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + async def wrap(func: Callable[[], Awaitable[None]]) -> None: await func() task_group.cancel_scope.cancel() @@ -300,7 +301,7 @@ class FileResponse(Response): self, path: str | os.PathLike[str], status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, filename: str | None = None, @@ -504,11 +505,11 @@ class FileResponse(Response): def generate_multipart( self, - ranges: typing.Sequence[tuple[int, int]], + ranges: Sequence[tuple[int, int]], boundary: str, max_size: int, content_type: str, - ) -> tuple[int, typing.Callable[[int, int], bytes]]: + ) -> tuple[int, Callable[[int, int], bytes]]: r""" Multipart response headers generator. diff --git a/starlette/routing.py b/starlette/routing.py index add7df0c..6eb57f48 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -6,10 +6,12 @@ import inspect import re import traceback import types -import typing import warnings -from contextlib import asynccontextmanager +from collections.abc import Awaitable, Generator, Sequence +from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager from enum import Enum +from re import Pattern +from typing import Any, Callable, TypeVar from starlette._exception_handler import wrap_app_handling_exceptions from starlette._utils import get_route_path, is_async_callable @@ -30,7 +32,7 @@ class NoMatchFound(Exception): if no matching route exists. """ - def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None: + def __init__(self, name: str, path_params: dict[str, Any]) -> None: params = ", ".join(list(path_params.keys())) super().__init__(f'No route exists for name "{name}" and params "{params}".') @@ -41,7 +43,7 @@ class Match(Enum): FULL = 2 -def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover +def iscoroutinefunction_or_partial(obj: Any) -> bool: # pragma: no cover """ Correctly determines if an object is a coroutine function, including those wrapped in functools.partial objects. @@ -56,13 +58,13 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover def request_response( - func: typing.Callable[[Request], typing.Awaitable[Response] | Response], + func: Callable[[Request], Awaitable[Response] | Response], ) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - f: typing.Callable[[Request], typing.Awaitable[Response]] = ( + f: Callable[[Request], Awaitable[Response]] = ( func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore ) @@ -79,7 +81,7 @@ def request_response( def websocket_session( - func: typing.Callable[[WebSocket], typing.Awaitable[None]], + func: Callable[[WebSocket], Awaitable[None]], ) -> ASGIApp: """ Takes a coroutine `func(session)`, and returns an ASGI application. @@ -97,13 +99,13 @@ def websocket_session( return app -def get_name(endpoint: typing.Callable[..., typing.Any]) -> str: +def get_name(endpoint: Callable[..., Any]) -> str: return getattr(endpoint, "__name__", endpoint.__class__.__name__) def replace_params( path: str, - param_convertors: dict[str, Convertor[typing.Any]], + param_convertors: dict[str, Convertor[Any]], path_params: dict[str, str], ) -> tuple[str, dict[str, str]]: for key, value in list(path_params.items()): @@ -121,7 +123,7 @@ PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") def compile_path( path: str, -) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]: +) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]: """ Given a path string, like: "/{username:str}", or a host string, like: "{subdomain}.mydomain.org", return a three-tuple @@ -179,7 +181,7 @@ class BaseRoute: def matches(self, scope: Scope) -> tuple[Match, Scope]: raise NotImplementedError() # pragma: no cover - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: raise NotImplementedError() # pragma: no cover async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -209,12 +211,12 @@ class Route(BaseRoute): def __init__( self, path: str, - endpoint: typing.Callable[..., typing.Any], + endpoint: Callable[..., Any], *, methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, - middleware: typing.Sequence[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -248,7 +250,7 @@ class Route(BaseRoute): self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> tuple[Match, Scope]: - path_params: dict[str, typing.Any] + path_params: dict[str, Any] if scope["type"] == "http": route_path = get_route_path(scope) match = self.path_regex.match(route_path) @@ -265,7 +267,7 @@ class Route(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: seen_params = set(path_params.keys()) expected_params = set(self.param_convertors.keys()) @@ -287,7 +289,7 @@ class Route(BaseRoute): else: await self.app(scope, receive, send) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: return ( isinstance(other, Route) and self.path == other.path @@ -306,10 +308,10 @@ class WebSocketRoute(BaseRoute): def __init__( self, path: str, - endpoint: typing.Callable[..., typing.Any], + endpoint: Callable[..., Any], *, name: str | None = None, - middleware: typing.Sequence[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -333,7 +335,7 @@ class WebSocketRoute(BaseRoute): self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> tuple[Match, Scope]: - path_params: dict[str, typing.Any] + path_params: dict[str, Any] if scope["type"] == "websocket": route_path = get_route_path(scope) match = self.path_regex.match(route_path) @@ -347,7 +349,7 @@ class WebSocketRoute(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: seen_params = set(path_params.keys()) expected_params = set(self.param_convertors.keys()) @@ -361,7 +363,7 @@ class WebSocketRoute(BaseRoute): async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint def __repr__(self) -> str: @@ -373,10 +375,10 @@ class Mount(BaseRoute): self, path: str, app: ASGIApp | None = None, - routes: typing.Sequence[BaseRoute] | None = None, + routes: Sequence[BaseRoute] | None = None, name: str | None = None, *, - middleware: typing.Sequence[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified" @@ -397,7 +399,7 @@ class Mount(BaseRoute): return getattr(self._base_app, "routes", []) def matches(self, scope: Scope) -> tuple[Match, Scope]: - path_params: dict[str, typing.Any] + path_params: dict[str, Any] if scope["type"] in ("http", "websocket"): # pragma: no branch root_path = scope.get("root_path", "") route_path = get_route_path(scope) @@ -429,7 +431,7 @@ class Mount(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: if self.name is not None and name == self.name and "path" in path_params: # 'name' matches "". path_params["path"] = path_params["path"].lstrip("/") @@ -459,7 +461,7 @@ class Mount(BaseRoute): async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, Mount) and self.path == other.path and self.app == other.app def __repr__(self) -> str: @@ -495,7 +497,7 @@ class Host(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: if self.name is not None and name == self.name and "path" in path_params: # 'name' matches "". path = path_params.pop("path") @@ -521,7 +523,7 @@ class Host(BaseRoute): async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, Host) and self.host == other.host and self.app == other.app def __repr__(self) -> str: @@ -530,11 +532,11 @@ class Host(BaseRoute): return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})" -_T = typing.TypeVar("_T") +_T = TypeVar("_T") -class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): - def __init__(self, cm: typing.ContextManager[_T]): +class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]): + def __init__(self, cm: AbstractContextManager[_T]): self._cm = cm async def __aenter__(self) -> _T: @@ -550,12 +552,12 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): def _wrap_gen_lifespan_context( - lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]], -) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]: + lifespan_context: Callable[[Any], Generator[Any, Any, Any]], +) -> Callable[[Any], AbstractAsyncContextManager[Any]]: cmgr = contextlib.contextmanager(lifespan_context) @functools.wraps(cmgr) - def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]: + def wrapper(app: Any) -> _AsyncLiftContextManager[Any]: return _AsyncLiftContextManager(cmgr(app)) return wrapper @@ -578,16 +580,16 @@ class _DefaultLifespan: class Router: def __init__( self, - routes: typing.Sequence[BaseRoute] | None = None, + routes: Sequence[BaseRoute] | None = None, redirect_slashes: bool = True, default: ASGIApp | None = None, - on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, - on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, # the generic to Lifespan[AppType] is the type of the top level application - # which the router cannot know statically, so we use typing.Any - lifespan: Lifespan[typing.Any] | None = None, + # which the router cannot know statically, so we use Any + lifespan: Lifespan[Any] | None = None, *, - middleware: typing.Sequence[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -610,7 +612,7 @@ class Router: ) if lifespan is None: - self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self) + self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self) elif inspect.isasyncgenfunction(lifespan): warnings.warn( @@ -652,7 +654,7 @@ class Router: response = PlainTextResponse("Not Found", status_code=404) await response(scope, receive, send) - def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: for route in self.routes: try: return route.url_path_for(name, **path_params) @@ -686,7 +688,7 @@ class Router: startup and shutdown events. """ started = False - app: typing.Any = scope.get("app") + app: Any = scope.get("app") await receive() try: async with self.lifespan_context(app) as maybe_state: @@ -763,7 +765,7 @@ class Router: await self.default(scope, receive, send) - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, Router) and self.routes == other.routes def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover @@ -777,7 +779,7 @@ class Router: def add_route( self, path: str, - endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response], + endpoint: Callable[[Request], Awaitable[Response] | Response], methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, @@ -794,7 +796,7 @@ class Router: def add_websocket_route( self, path: str, - endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]], + endpoint: Callable[[WebSocket], Awaitable[None]], name: str | None = None, ) -> None: # pragma: no cover route = WebSocketRoute(path, endpoint=endpoint, name=name) @@ -806,7 +808,7 @@ class Router: methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, - ) -> typing.Callable: # type: ignore[type-arg] + ) -> Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -820,7 +822,7 @@ class Router: DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.add_route( path, func, @@ -832,7 +834,7 @@ class Router: return decorator - def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] + def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg] """ We no longer document this decorator style API, and its usage is discouraged. Instead you should use the following approach: @@ -846,13 +848,13 @@ class Router: DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.add_websocket_route(path, func, name=name) return func return decorator - def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover + def add_event_handler(self, event_type: str, func: Callable[[], Any]) -> None: # pragma: no cover assert event_type in ("startup", "shutdown") if event_type == "startup": @@ -860,14 +862,14 @@ class Router: else: self.on_shutdown.append(func) - def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg] warnings.warn( "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " "Refer to https://www.starlette.io/lifespan/ for recommended approach.", DeprecationWarning, ) - def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] self.add_event_handler(event_type, func) return func diff --git a/starlette/schemas.py b/starlette/schemas.py index bfc40e2a..e97f83ec 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -2,7 +2,7 @@ from __future__ import annotations import inspect import re -import typing +from typing import Any, Callable, NamedTuple from starlette.requests import Request from starlette.responses import Response @@ -17,23 +17,23 @@ except ModuleNotFoundError: # pragma: no cover class OpenAPIResponse(Response): media_type = "application/vnd.oai.openapi" - def render(self, content: typing.Any) -> bytes: + def render(self, content: Any) -> bytes: assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse." assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary." return yaml.dump(content, default_flow_style=False).encode("utf-8") -class EndpointInfo(typing.NamedTuple): +class EndpointInfo(NamedTuple): path: str http_method: str - func: typing.Callable[..., typing.Any] + func: Callable[..., Any] _remove_converter_pattern = re.compile(r":\w+}") class BaseSchemaGenerator: - def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: + def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]: raise NotImplementedError() # pragma: no cover def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: @@ -94,7 +94,7 @@ class BaseSchemaGenerator: """ return _remove_converter_pattern.sub("}", path) - def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]: + def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]: """ Given a function, parse the docstring as YAML and return a dictionary of info. """ @@ -125,10 +125,10 @@ class BaseSchemaGenerator: class SchemaGenerator(BaseSchemaGenerator): - def __init__(self, base_schema: dict[str, typing.Any]) -> None: + def __init__(self, base_schema: dict[str, Any]) -> None: self.base_schema = base_schema - def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: + def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]: schema = dict(self.base_schema) schema.setdefault("paths", {}) endpoints_info = self.get_endpoints(routes) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 637da648..7fba9aa9 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -4,8 +4,8 @@ import errno import importlib.util import os import stat -import typing from email.utils import parsedate +from typing import Union import anyio import anyio.to_thread @@ -16,7 +16,7 @@ from starlette.exceptions import HTTPException from starlette.responses import FileResponse, RedirectResponse, Response from starlette.types import Receive, Scope, Send -PathLike = typing.Union[str, "os.PathLike[str]"] +PathLike = Union[str, "os.PathLike[str]"] class NotModifiedResponse(Response): diff --git a/starlette/templating.py b/starlette/templating.py index f764858b..10fa0271 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -1,8 +1,9 @@ from __future__ import annotations -import typing import warnings +from collections.abc import Mapping, Sequence from os import PathLike +from typing import Any, Callable, cast, overload from starlette.background import BackgroundTask from starlette.datastructures import URL @@ -28,10 +29,10 @@ except ModuleNotFoundError: # pragma: no cover class _TemplateResponse(HTMLResponse): def __init__( self, - template: typing.Any, - context: dict[str, typing.Any], + template: Any, + context: dict[str, Any], status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ): @@ -63,30 +64,30 @@ class Jinja2Templates: return templates.TemplateResponse("index.html", {"request": request}) """ - @typing.overload + @overload def __init__( self, - directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], + directory: str | PathLike[str] | Sequence[str | PathLike[str]], *, - context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, - **env_options: typing.Any, + context_processors: list[Callable[[Request], dict[str, Any]]] | None = None, + **env_options: Any, ) -> None: ... - @typing.overload + @overload def __init__( self, *, env: jinja2.Environment, - context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + context_processors: list[Callable[[Request], dict[str, Any]]] | None = None, ) -> None: ... def __init__( self, - directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None, + directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None, *, - context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + context_processors: list[Callable[[Request], dict[str, Any]]] | None = None, env: jinja2.Environment | None = None, - **env_options: typing.Any, + **env_options: Any, ) -> None: if env_options: warnings.warn( @@ -105,8 +106,8 @@ class Jinja2Templates: def _create_env( self, - directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], - **env_options: typing.Any, + directory: str | PathLike[str] | Sequence[str | PathLike[str]], + **env_options: Any, ) -> jinja2.Environment: loader = jinja2.FileSystemLoader(directory) env_options.setdefault("loader", loader) @@ -117,10 +118,10 @@ class Jinja2Templates: def _setup_env_defaults(self, env: jinja2.Environment) -> None: @pass_context def url_for( - context: dict[str, typing.Any], + context: dict[str, Any], name: str, /, - **path_params: typing.Any, + **path_params: Any, ) -> URL: request: Request = context["request"] return request.url_for(name, **path_params) @@ -130,32 +131,32 @@ class Jinja2Templates: def get_template(self, name: str) -> jinja2.Template: return self.env.get_template(name) - @typing.overload + @overload def TemplateResponse( self, request: Request, name: str, - context: dict[str, typing.Any] | None = None, + context: dict[str, Any] | None = None, status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ) -> _TemplateResponse: ... - @typing.overload + @overload def TemplateResponse( self, name: str, - context: dict[str, typing.Any] | None = None, + context: dict[str, Any] | None = None, status_code: int = 200, - headers: typing.Mapping[str, str] | None = None, + headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ) -> _TemplateResponse: # Deprecated usage ... - def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse: + def TemplateResponse(self, *args: Any, **kwargs: Any) -> _TemplateResponse: if args: if isinstance(args[0], str): # the first argument is template name (old style) warnings.warn( @@ -195,7 +196,7 @@ class Jinja2Templates: context = kwargs.get("context", {}) request = kwargs.get("request", context.get("request")) - name = typing.cast(str, kwargs["name"]) + name = cast(str, kwargs["name"]) status_code = kwargs.get("status_code", 200) headers = kwargs.get("headers") media_type = kwargs.get("media_type") diff --git a/starlette/testclient.py b/starlette/testclient.py index d54025e5..df8e1138 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -6,10 +6,19 @@ import io import json import math import sys -import typing import warnings +from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence from concurrent.futures import Future +from contextlib import AbstractContextManager from types import GeneratorType +from typing import ( + Any, + Callable, + Literal, + TypedDict, + Union, + cast, +) from urllib.parse import unquote, urljoin import anyio @@ -34,14 +43,14 @@ except ModuleNotFoundError: # pragma: no cover "You can install this with:\n" " $ pip install httpx\n" ) -_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]] +_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]] -ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] -ASGI2App = typing.Callable[[Scope], ASGIInstance] -ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] +ASGIInstance = Callable[[Receive, Send], Awaitable[None]] +ASGI2App = Callable[[Scope], ASGIInstance] +ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]] -_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]] +_RequestData = Mapping[str, Union[str, Iterable[str], bytes]] def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: @@ -63,9 +72,9 @@ class _WrapASGI2: await instance(receive, send) -class _AsyncBackend(typing.TypedDict): +class _AsyncBackend(TypedDict): backend: str - backend_options: dict[str, typing.Any] + backend_options: dict[str, Any] class _Upgrade(Exception): @@ -111,7 +120,7 @@ class WebSocketTestSession: self.exit_stack = stack.pop_all() return self - def __exit__(self, *args: typing.Any) -> bool | None: + def __exit__(self, *args: Any) -> bool | None: return self.exit_stack.__exit__(*args) async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: @@ -155,7 +164,7 @@ class WebSocketTestSession: def send_bytes(self, data: bytes) -> None: self.send({"type": "websocket.receive", "bytes": data}) - def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: + def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None: text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) if mode == "text": self.send({"type": "websocket.receive", "text": text}) @@ -171,14 +180,14 @@ class WebSocketTestSession: def receive_text(self) -> str: message = self.receive() self._raise_on_close(message) - return typing.cast(str, message["text"]) + return cast(str, message["text"]) def receive_bytes(self) -> bytes: message = self.receive() self._raise_on_close(message) - return typing.cast(bytes, message["bytes"]) + return cast(bytes, message["bytes"]) - def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: + def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any: message = self.receive() self._raise_on_close(message) if mode == "text": @@ -197,7 +206,7 @@ class _TestClientTransport(httpx.BaseTransport): root_path: str = "", *, client: tuple[str, int], - app_state: dict[str, typing.Any], + app_state: dict[str, Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions @@ -233,12 +242,12 @@ class _TestClientTransport(httpx.BaseTransport): # Include other request headers. headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] - scope: dict[str, typing.Any] + scope: dict[str, Any] if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: - subprotocols: typing.Sequence[str] = [] + subprotocols: Sequence[str] = [] else: subprotocols = [value.strip() for value in subprotocol.split(",")] scope = { @@ -277,7 +286,7 @@ class _TestClientTransport(httpx.BaseTransport): request_complete = False response_started = False response_complete: anyio.Event - raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} + raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()} template = None context = None @@ -368,8 +377,8 @@ class TestClient(httpx.Client): base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - backend: typing.Literal["asyncio", "trio"] = "asyncio", - backend_options: dict[str, typing.Any] | None = None, + backend: Literal["asyncio", "trio"] = "asyncio", + backend_options: dict[str, Any] | None = None, cookies: httpx._types.CookieTypes | None = None, headers: dict[str, str] | None = None, follow_redirects: bool = True, @@ -379,10 +388,10 @@ class TestClient(httpx.Client): if _is_asgi3(app): asgi_app = app else: - app = typing.cast(ASGI2App, app) # type: ignore[assignment] + app = cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app - self.app_state: dict[str, typing.Any] = {} + self.app_state: dict[str, Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, @@ -403,7 +412,7 @@ class TestClient(httpx.Client): ) @contextlib.contextmanager - def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]: if self.portal is not None: yield self.portal else: @@ -418,14 +427,14 @@ class TestClient(httpx.Client): content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, - json: typing.Any = None, + json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: if timeout is not httpx.USE_CLIENT_DEFAULT: warnings.warn( @@ -460,7 +469,7 @@ class TestClient(httpx.Client): auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().get( url, @@ -483,7 +492,7 @@ class TestClient(httpx.Client): auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().options( url, @@ -506,7 +515,7 @@ class TestClient(httpx.Client): auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().head( url, @@ -526,14 +535,14 @@ class TestClient(httpx.Client): content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, - json: typing.Any = None, + json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().post( url, @@ -557,14 +566,14 @@ class TestClient(httpx.Client): content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, - json: typing.Any = None, + json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().put( url, @@ -588,14 +597,14 @@ class TestClient(httpx.Client): content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, - json: typing.Any = None, + json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().patch( url, @@ -622,7 +631,7 @@ class TestClient(httpx.Client): auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, - extensions: dict[str, typing.Any] | None = None, + extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().delete( url, @@ -638,8 +647,8 @@ class TestClient(httpx.Client): def websocket_connect( self, url: str, - subprotocols: typing.Sequence[str] | None = None, - **kwargs: typing.Any, + subprotocols: Sequence[str] | None = None, + **kwargs: Any, ) -> WebSocketTestSession: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) @@ -666,11 +675,11 @@ class TestClient(httpx.Client): def reset_portal() -> None: self.portal = None - send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = ( + send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = ( anyio.create_memory_object_stream(math.inf) ) - receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = ( - anyio.create_memory_object_stream(math.inf) + receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream( + math.inf ) for channel in (*send, *receive): stack.callback(channel.close) @@ -687,7 +696,7 @@ class TestClient(httpx.Client): return self - def __exit__(self, *args: typing.Any) -> None: + def __exit__(self, *args: Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: @@ -700,7 +709,7 @@ class TestClient(httpx.Client): async def wait_startup(self) -> None: await self.stream_receive.send({"type": "lifespan.startup"}) - async def receive() -> typing.Any: + async def receive() -> Any: message = await self.stream_send.receive() if message is None: self.task.result() @@ -715,7 +724,7 @@ class TestClient(httpx.Client): await receive() async def wait_shutdown(self) -> None: - async def receive() -> typing.Any: + async def receive() -> Any: message = await self.stream_send.receive() if message is None: self.task.result() diff --git a/starlette/types.py b/starlette/types.py index 893f8729..e1f478d7 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,24 +1,26 @@ -import typing +from collections.abc import Awaitable, Mapping, MutableMapping +from contextlib import AbstractAsyncContextManager +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket -AppType = typing.TypeVar("AppType") +AppType = TypeVar("AppType") -Scope = typing.MutableMapping[str, typing.Any] -Message = typing.MutableMapping[str, typing.Any] +Scope = MutableMapping[str, Any] +Message = MutableMapping[str, Any] -Receive = typing.Callable[[], typing.Awaitable[Message]] -Send = typing.Callable[[Message], typing.Awaitable[None]] +Receive = Callable[[], Awaitable[Message]] +Send = Callable[[Message], Awaitable[None]] -ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] +ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] -StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]] -StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]] -Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] +StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]] +StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]] +Lifespan = Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] -HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"] -WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]] -ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler] +HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"] +WebSocketExceptionHandler = Callable[["WebSocket", Exception], Awaitable[None]] +ExceptionHandler = Union[HTTPExceptionHandler, WebSocketExceptionHandler] diff --git a/starlette/websockets.py b/starlette/websockets.py index 6b46f4ea..fb76361c 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -2,7 +2,8 @@ from __future__ import annotations import enum import json -import typing +from collections.abc import AsyncIterator, Iterable +from typing import Any, cast from starlette.requests import HTTPConnection from starlette.responses import Response @@ -99,7 +100,7 @@ class WebSocket(HTTPConnection): async def accept( self, subprotocol: str | None = None, - headers: typing.Iterable[tuple[bytes, bytes]] | None = None, + headers: Iterable[tuple[bytes, bytes]] | None = None, ) -> None: headers = headers or [] @@ -117,16 +118,16 @@ class WebSocket(HTTPConnection): raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) - return typing.cast(str, message["text"]) + return cast(str, message["text"]) async def receive_bytes(self) -> bytes: if self.application_state != WebSocketState.CONNECTED: raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) - return typing.cast(bytes, message["bytes"]) + return cast(bytes, message["bytes"]) - async def receive_json(self, mode: str = "text") -> typing.Any: + async def receive_json(self, mode: str = "text") -> Any: if mode not in {"text", "binary"}: raise RuntimeError('The "mode" argument should be "text" or "binary".') if self.application_state != WebSocketState.CONNECTED: @@ -140,21 +141,21 @@ class WebSocket(HTTPConnection): text = message["bytes"].decode("utf-8") return json.loads(text) - async def iter_text(self) -> typing.AsyncIterator[str]: + async def iter_text(self) -> AsyncIterator[str]: try: while True: yield await self.receive_text() except WebSocketDisconnect: pass - async def iter_bytes(self) -> typing.AsyncIterator[bytes]: + async def iter_bytes(self) -> AsyncIterator[bytes]: try: while True: yield await self.receive_bytes() except WebSocketDisconnect: pass - async def iter_json(self) -> typing.AsyncIterator[typing.Any]: + async def iter_json(self) -> AsyncIterator[Any]: try: while True: yield await self.receive_json() @@ -167,7 +168,7 @@ class WebSocket(HTTPConnection): async def send_bytes(self, data: bytes) -> None: await self.send({"type": "websocket.send", "bytes": data}) - async def send_json(self, data: typing.Any, mode: str = "text") -> None: + async def send_json(self, data: Any, mode: str = "text") -> None: if mode not in {"text", "binary"}: raise RuntimeError('The "mode" argument should be "text" or "binary".') text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) diff --git a/tests/test_config.py b/tests/test_config.py index 7d2cd1f9..c256ffc6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,4 @@ import os -import typing from pathlib import Path from typing import Any, Optional @@ -51,7 +50,7 @@ def test_config(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None: config = Config(path, environ={"DEBUG": "true"}) - def cast_to_int(v: typing.Any) -> int: + def cast_to_int(v: Any) -> int: return int(v) DEBUG = config("DEBUG", cast=bool) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 86824565..3b48f814 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,5 @@ -import typing from collections.abc import Generator +from typing import Any import pytest from pytest import MonkeyPatch @@ -196,7 +196,7 @@ def test_http_exception_does_not_use_threadpool(client: TestClient, monkeypatch: from starlette import _exception_handler # Replace run_in_threadpool with a function that raises an error - def mock_run_in_threadpool(*args: typing.Any, **kwargs: typing.Any) -> None: + def mock_run_in_threadpool(*args: Any, **kwargs: Any) -> None: pytest.fail("run_in_threadpool should not be called for HTTP exceptions") # pragma: no cover # Apply the monkeypatch only during this test diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index b18fd6c4..35681e78 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,9 +1,9 @@ from __future__ import annotations import os -import typing -from contextlib import nullcontext as does_not_raise +from contextlib import AbstractContextManager, nullcontext as does_not_raise from pathlib import Path +from typing import Any import pytest @@ -17,7 +17,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send from tests.types import TestClientFactory -class ForceMultipartDict(dict[typing.Any, typing.Any]): +class ForceMultipartDict(dict[Any, Any]): def __bool__(self) -> bool: return True @@ -29,7 +29,7 @@ FORCE_MULTIPART = ForceMultipartDict() async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form() - output: dict[str, typing.Any] = {} + output: dict[str, Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -49,7 +49,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form() - output: dict[str, list[typing.Any]] = {} + output: dict[str, list[Any]] = {} for key, value in data.multi_items(): if key not in output: output[key] = [] @@ -73,7 +73,7 @@ async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None: async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form() - output: dict[str, typing.Any] = {} + output: dict[str, Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -108,7 +108,7 @@ def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_s async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size) - output: dict[str, typing.Any] = {} + output: dict[str, Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -422,7 +422,7 @@ def test_user_safe_decode_ignores_wrong_charset() -> None: ) def test_missing_boundary_parameter( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -450,7 +450,7 @@ def test_missing_boundary_parameter( ) def test_missing_name_parameter_on_content_disposition( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -478,7 +478,7 @@ def test_missing_name_parameter_on_content_disposition( ) def test_too_many_fields_raise( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -505,7 +505,7 @@ def test_too_many_fields_raise( ) def test_too_many_files_raise( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -532,7 +532,7 @@ def test_too_many_files_raise( ) def test_too_many_files_single_field_raise( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -561,7 +561,7 @@ def test_too_many_files_single_field_raise( ) def test_too_many_files_and_fields_raise( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -592,7 +592,7 @@ def test_too_many_files_and_fields_raise( ) def test_max_fields_is_customizable_low_raises( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -622,7 +622,7 @@ def test_max_fields_is_customizable_low_raises( ) def test_max_files_is_customizable_low_raises( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -673,7 +673,7 @@ def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) ) def test_max_part_size_exceeds_limit( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) @@ -713,7 +713,7 @@ def test_max_part_size_exceeds_limit( ) def test_max_part_size_exceeds_custom_limit( app: ASGIApp, - expectation: typing.ContextManager[Exception], + expectation: AbstractContextManager[Exception], test_client_factory: TestClientFactory, ) -> None: client = test_client_factory(app) diff --git a/tests/test_routing.py b/tests/test_routing.py index 933fe7c3..041aab10 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -3,8 +3,9 @@ from __future__ import annotations import contextlib import functools import json -import typing import uuid +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from typing import Callable, TypedDict import pytest @@ -165,7 +166,7 @@ app = Router( @pytest.fixture def client( test_client_factory: TestClientFactory, -) -> typing.Generator[TestClient, None, None]: +) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client @@ -586,7 +587,7 @@ def test_standalone_route_matches( def test_standalone_route_does_not_match( - test_client_factory: typing.Callable[..., TestClient], + test_client_factory: Callable[..., TestClient], ) -> None: app = Route("/", PlainTextResponse("Hello, World!")) client = test_client_factory(app) @@ -659,7 +660,7 @@ def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None shutdown_called = False @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> typing.AsyncGenerator[None, None]: + async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: nonlocal lifespan_called lifespan_called = True yield @@ -731,7 +732,7 @@ def test_lifespan_state_unsupported( @contextlib.asynccontextmanager async def lifespan( app: ASGIApp, - ) -> typing.AsyncGenerator[dict[str, str], None]: + ) -> AsyncGenerator[dict[str, str], None]: yield {"foo": "bar"} app = Router( @@ -752,7 +753,7 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None startup_complete = False shutdown_complete = False - class State(typing.TypedDict): + class State(TypedDict): count: int items: list[int] @@ -767,7 +768,7 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None return PlainTextResponse("hello, world") @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> typing.AsyncIterator[State]: + async def lifespan(app: Starlette) -> AsyncIterator[State]: nonlocal startup_complete, shutdown_complete startup_complete = True state = State(count=0, items=[]) @@ -896,7 +897,7 @@ class Endpoint: pytest.param(lambda request: ..., "", id="lambda"), # pragma: no branch ], ) -def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None: +def test_route_name(endpoint: Callable[..., Response], expected_name: str) -> None: assert Route(path="/", endpoint=endpoint).name == expected_name @@ -1005,7 +1006,7 @@ def test_mount_asgi_app_with_middleware_url_path_for() -> None: def test_add_route_to_app_after_mount( - test_client_factory: typing.Callable[..., TestClient], + test_client_factory: Callable[..., TestClient], ) -> None: """Checks that Mount will pick up routes added to the underlying app after it is mounted @@ -1038,7 +1039,7 @@ def test_exception_on_mounted_apps( def test_mounted_middleware_does_not_catch_exception( - test_client_factory: typing.Callable[..., TestClient], + test_client_factory: Callable[..., TestClient], ) -> None: # https://github.com/encode/starlette/pull/1649#discussion_r960236107 def exc(request: Request) -> Response: diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index e11e3cb3..4f4e07ea 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -2,8 +2,8 @@ import os import stat import tempfile import time -import typing from pathlib import Path +from typing import Any import anyio import pytest @@ -458,7 +458,7 @@ def test_staticfiles_unhandled_os_error_returns_500( test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch, ) -> None: - def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None: + def mock_timeout(*args: Any, **kwargs: Any) -> None: raise TimeoutError path = os.path.join(tmpdir, "example.txt") -- 2.47.2