[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
- "if typing.TYPE_CHECKING:",
- "@typing.overload",
+ "if TYPE_CHECKING:",
+ "@overload",
"raise NotImplementedError",
]
from __future__ import annotations
-import typing
+from typing import Any
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
-ExceptionHandlers = dict[typing.Any, ExceptionHandler]
+ExceptionHandlers = dict[Any, ExceptionHandler]
StatusHandlers = dict[int, ExceptionHandler]
import functools
import inspect
import sys
-import typing
-from contextlib import contextmanager
+from collections.abc import Awaitable, Generator
+from contextlib import AbstractAsyncContextManager, contextmanager
+from typing import Any, Callable, Generic, Protocol, TypeVar, overload
from starlette.types import Scope
except ImportError:
has_exceptiongroups = False
-T = typing.TypeVar("T")
-AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
+T = TypeVar("T")
+AwaitableCallable = Callable[..., Awaitable[T]]
-@typing.overload
+@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
-@typing.overload
-def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
+@overload
+def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
-def is_async_callable(obj: typing.Any) -> typing.Any:
+def is_async_callable(obj: Any) -> Any:
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))
-T_co = typing.TypeVar("T_co", covariant=True)
+T_co = TypeVar("T_co", covariant=True)
-class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
+class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ...
-class SupportsAsyncClose(typing.Protocol):
+class SupportsAsyncClose(Protocol):
async def close(self) -> None: ... # pragma: no cover
-SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
+SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
-class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
+class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
__slots__ = ("aw", "entered")
- def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
+ def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
self.aw = aw
- def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
+ def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
return self.aw.__await__()
async def __aenter__(self) -> SupportsAsyncCloseType:
self.entered = await self.aw
return self.entered
- async def __aexit__(self, *args: typing.Any) -> None | bool:
+ async def __aexit__(self, *args: Any) -> None | bool:
await self.entered.close()
return None
@contextmanager
-def collapse_excgroups() -> typing.Generator[None, None, None]:
+def collapse_excgroups() -> Generator[None, None, None]:
try:
yield
except BaseException as exc:
from __future__ import annotations
import sys
-import typing
import warnings
+from collections.abc import Awaitable, Mapping, Sequence
+from typing import Any, Callable, TypeVar
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
-AppType = typing.TypeVar("AppType", bound="Starlette")
+AppType = TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")
def __init__(
self: AppType,
debug: bool = False,
- routes: typing.Sequence[BaseRoute] | None = None,
- middleware: typing.Sequence[Middleware] | None = None,
- exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
- on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
- on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
+ routes: Sequence[BaseRoute] | None = None,
+ middleware: Sequence[Middleware] | None = None,
+ exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
+ on_startup: Sequence[Callable[[], Any]] | None = None,
+ on_shutdown: Sequence[Callable[[], Any]] | None = None,
lifespan: Lifespan[AppType] | None = None,
) -> None:
"""Initializes the application.
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
- exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
+ exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
def routes(self) -> list[BaseRoute]:
return self.router.routes
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
return self.router.url_path_for(name, **path_params)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)
- def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
+ def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
return self.router.on_event(event_type) # pragma: no cover
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
def add_event_handler(
self,
event_type: str,
- func: typing.Callable, # type: ignore[type-arg]
+ func: Callable, # type: ignore[type-arg]
) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)
def add_route(
self,
path: str,
- route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
+ route: Callable[[Request], Awaitable[Response] | Response],
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
def add_websocket_route(
self,
path: str,
- route: typing.Callable[[WebSocket], typing.Awaitable[None]],
+ route: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
- def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
+ def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> Callable: # type: ignore[type-arg]
warnings.warn(
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_exception_handler(exc_class_or_status_code, func)
return func
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
- ) -> typing.Callable: # type: ignore[type-arg]
+ ) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.router.add_route(
path,
func,
return decorator
- def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
+ def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
- def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
+ def middleware(self, middleware_type: str) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
)
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func
import functools
import inspect
import sys
-import typing
+from collections.abc import Sequence
+from typing import Any, Callable
from urllib.parse import urlencode
if sys.version_info >= (3, 10): # pragma: no cover
_P = ParamSpec("_P")
-def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
+def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
for scope in scopes:
if scope not in conn.auth.scopes:
return False
def requires(
- scopes: str | typing.Sequence[str],
+ scopes: str | Sequence[str],
status_code: int = 403,
redirect: str | None = None,
-) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
+) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(
- func: typing.Callable[_P, typing.Any],
- ) -> typing.Callable[_P, typing.Any]:
+ func: Callable[_P, Any],
+ ) -> Callable[_P, Any]:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
- async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
+ async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
else:
# Handle sync request/response functions.
@functools.wraps(func)
- def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
+ def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
class AuthCredentials:
- def __init__(self, scopes: typing.Sequence[str] | None = None):
+ def __init__(self, scopes: Sequence[str] | None = None):
self.scopes = [] if scopes is None else list(scopes)
from __future__ import annotations
import sys
-import typing
+from collections.abc import Sequence
+from typing import Any, Callable
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
class BackgroundTask:
- def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
+ def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
class BackgroundTasks(BackgroundTask):
- def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
+ def __init__(self, tasks: Sequence[BackgroundTask] | None = None):
self.tasks = list(tasks) if tasks else []
- def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
+ def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)
import functools
import sys
-import typing
import warnings
+from collections.abc import AsyncIterator, Coroutine, Iterable, Iterator
+from typing import Callable, TypeVar
import anyio.to_thread
from typing_extensions import ParamSpec
P = ParamSpec("P")
-T = typing.TypeVar("T")
+T = TypeVar("T")
-async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
+async def run_until_first_complete(*args: tuple[Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
"run_until_first_complete is deprecated and will be removed in a future version.",
DeprecationWarning,
async with anyio.create_task_group() as task_group:
- async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
+ async def run(func: Callable[[], Coroutine]) -> None: # type: ignore[type-arg]
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(run, functools.partial(func, **kwargs))
-async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
+async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func = functools.partial(func, *args, **kwargs)
return await anyio.to_thread.run_sync(func)
pass
-def _next(iterator: typing.Iterator[T]) -> T:
+def _next(iterator: Iterator[T]) -> T:
# We can't raise `StopIteration` from within the threadpool iterator
# and catch it outside that context, so we coerce them into a different
# exception type.
async def iterate_in_threadpool(
- iterator: typing.Iterable[T],
-) -> typing.AsyncIterator[T]:
+ iterator: Iterable[T],
+) -> AsyncIterator[T]:
as_iterator = iter(iterator)
while True:
try:
from __future__ import annotations
import os
-import typing
import warnings
+from collections.abc import Iterator, Mapping, MutableMapping
from pathlib import Path
+from typing import Any, Callable, TypeVar, overload
class undefined:
pass
-class Environ(typing.MutableMapping[str, str]):
- def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
+class Environ(MutableMapping[str, str]):
+ def __init__(self, environ: MutableMapping[str, str] = os.environ):
self._environ = environ
self._has_been_read: set[str] = set()
raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
self._environ.__delitem__(key)
- def __iter__(self) -> typing.Iterator[str]:
+ def __iter__(self) -> Iterator[str]:
return iter(self._environ)
def __len__(self) -> int:
environ = Environ()
-T = typing.TypeVar("T")
+T = TypeVar("T")
class Config:
def __init__(
self,
env_file: str | Path | None = None,
- environ: typing.Mapping[str, str] = environ,
+ environ: Mapping[str, str] = environ,
env_prefix: str = "",
) -> None:
self.environ = environ
else:
self.file_values = self._read_file(env_file)
- @typing.overload
+ @overload
def __call__(self, key: str, *, default: None) -> str | None: ...
- @typing.overload
+ @overload
def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...
- @typing.overload
+ @overload
def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...
- @typing.overload
+ @overload
def __call__(
self,
key: str,
- cast: typing.Callable[[typing.Any], T] = ...,
- default: typing.Any = ...,
+ cast: Callable[[Any], T] = ...,
+ default: Any = ...,
) -> T: ...
- @typing.overload
+ @overload
def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
def __call__(
self,
key: str,
- cast: typing.Callable[[typing.Any], typing.Any] | None = None,
- default: typing.Any = undefined,
- ) -> typing.Any:
+ cast: Callable[[Any], Any] | None = None,
+ default: Any = undefined,
+ ) -> Any:
return self.get(key, cast, default)
def get(
self,
key: str,
- cast: typing.Callable[[typing.Any], typing.Any] | None = None,
- default: typing.Any = undefined,
- ) -> typing.Any:
+ cast: Callable[[Any], Any] | None = None,
+ default: Any = undefined,
+ ) -> Any:
key = self.env_prefix + key
if key in self.environ:
value = self.environ[key]
def _perform_cast(
self,
key: str,
- value: typing.Any,
- cast: typing.Callable[[typing.Any], typing.Any] | None = None,
- ) -> typing.Any:
+ value: Any,
+ cast: Callable[[Any], Any] | None = None,
+ ) -> Any:
if cast is None or value is None:
return value
elif cast is bool and isinstance(value, str):
from __future__ import annotations
import math
-import typing
import uuid
+from typing import Any, ClassVar, Generic, TypeVar
-T = typing.TypeVar("T")
+T = TypeVar("T")
-class Convertor(typing.Generic[T]):
- regex: typing.ClassVar[str] = ""
+class Convertor(Generic[T]):
+ regex: ClassVar[str] = ""
def convert(self, value: str) -> T:
raise NotImplementedError() # pragma: no cover
return str(value)
-CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = {
+CONVERTOR_TYPES: dict[str, Convertor[Any]] = {
"str": StringConvertor(),
"path": PathConvertor(),
"int": IntegerConvertor(),
}
-def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None:
+def register_url_convertor(key: str, convertor: Convertor[Any]) -> None:
CONVERTOR_TYPES[key] = convertor
from __future__ import annotations
-import typing
+from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
from shlex import shlex
+from typing import (
+ Any,
+ BinaryIO,
+ NamedTuple,
+ TypeVar,
+ Union,
+ cast,
+)
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
from starlette.concurrency import run_in_threadpool
from starlette.types import Scope
-class Address(typing.NamedTuple):
+class Address(NamedTuple):
host: str
port: int
-_KeyType = typing.TypeVar("_KeyType")
+_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
-_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
+_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
class URL:
self,
url: str = "",
scope: Scope | None = None,
- **components: typing.Any,
+ **components: Any,
) -> None:
if scope is not None:
assert not url, 'Cannot set both "url" and "scope".'
def is_secure(self) -> bool:
return self.scheme in ("https", "wss")
- def replace(self, **kwargs: typing.Any) -> URL:
+ def replace(self, **kwargs: Any) -> URL:
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
components = self.components._replace(**kwargs)
return self.__class__(components.geturl())
- def include_query_params(self, **kwargs: typing.Any) -> URL:
+ def include_query_params(self, **kwargs: Any) -> URL:
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
params.update({str(key): str(value) for key, value in kwargs.items()})
query = urlencode(params.multi_items())
return self.replace(query=query)
- def replace_query_params(self, **kwargs: typing.Any) -> URL:
+ def replace_query_params(self, **kwargs: Any) -> URL:
query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
return self.replace(query=query)
- def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
+ def remove_query_params(self, keys: str | Sequence[str]) -> URL:
if isinstance(keys, str):
keys = [keys]
params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
query = urlencode(params.multi_items())
return self.replace(query=query)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
return str(self) == str(other)
def __str__(self) -> str:
return bool(self._value)
-class CommaSeparatedStrings(typing.Sequence[str]):
- def __init__(self, value: str | typing.Sequence[str]):
+class CommaSeparatedStrings(Sequence[str]):
+ def __init__(self, value: str | Sequence[str]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
splitter.whitespace = ","
def __len__(self) -> int:
return len(self._items)
- def __getitem__(self, index: int | slice) -> typing.Any:
+ def __getitem__(self, index: int | slice) -> Any:
return self._items[index]
- def __iter__(self) -> typing.Iterator[str]:
+ def __iter__(self) -> Iterator[str]:
return iter(self._items)
def __repr__(self) -> str:
return ", ".join(repr(item) for item in self)
-class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
+class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
_dict: dict[_KeyType, _CovariantValueType]
def __init__(
self,
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
- | typing.Mapping[_KeyType, _CovariantValueType]
- | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
- **kwargs: typing.Any,
+ | Mapping[_KeyType, _CovariantValueType]
+ | Iterable[tuple[_KeyType, _CovariantValueType]],
+ **kwargs: Any,
) -> None:
assert len(args) < 2, "Too many arguments."
- value: typing.Any = args[0] if args else []
+ value: Any = args[0] if args else []
if kwargs:
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
- _items: list[tuple[typing.Any, typing.Any]] = []
+ _items: list[tuple[Any, Any]] = []
elif hasattr(value, "multi_items"):
- value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
+ value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
- value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
+ value = cast(Mapping[_KeyType, _CovariantValueType], value)
_items = list(value.items())
else:
- value = typing.cast("list[tuple[typing.Any, typing.Any]]", value)
+ value = cast("list[tuple[Any, Any]]", value)
_items = list(value)
self._dict = {k: v for k, v in _items}
self._list = _items
- def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
+ def getlist(self, key: Any) -> list[_CovariantValueType]:
return [item_value for item_key, item_value in self._list if item_key == key]
- def keys(self) -> typing.KeysView[_KeyType]:
+ def keys(self) -> KeysView[_KeyType]:
return self._dict.keys()
- def values(self) -> typing.ValuesView[_CovariantValueType]:
+ def values(self) -> ValuesView[_CovariantValueType]:
return self._dict.values()
- def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
+ def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
return self._dict.items()
def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
return self._dict[key]
- def __contains__(self, key: typing.Any) -> bool:
+ def __contains__(self, key: Any) -> bool:
return key in self._dict
- def __iter__(self) -> typing.Iterator[_KeyType]:
+ def __iter__(self) -> Iterator[_KeyType]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._dict)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
return sorted(self._list) == sorted(other._list)
return f"{class_name}({items!r})"
-class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
- def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
+class MultiDict(ImmutableMultiDict[Any, Any]):
+ def __setitem__(self, key: Any, value: Any) -> None:
self.setlist(key, [value])
- def __delitem__(self, key: typing.Any) -> None:
+ def __delitem__(self, key: Any) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]
- def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
+ def pop(self, key: Any, default: Any = None) -> Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
- def popitem(self) -> tuple[typing.Any, typing.Any]:
+ def popitem(self) -> tuple[Any, Any]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value
- def poplist(self, key: typing.Any) -> list[typing.Any]:
+ def poplist(self, key: Any) -> list[Any]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
self._dict.clear()
self._list.clear()
- def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
+ def setdefault(self, key: Any, default: Any = None) -> Any:
if key not in self:
self._dict[key] = default
self._list.append((key, default))
return self[key]
- def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
+ def setlist(self, key: Any, values: list[Any]) -> None:
if not values:
self.pop(key, None)
else:
self._list = existing_items + [(key, value) for value in values]
self._dict[key] = values[-1]
- def append(self, key: typing.Any, value: typing.Any) -> None:
+ def append(self, key: Any, value: Any) -> None:
self._list.append((key, value))
self._dict[key] = value
def update(
self,
- *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
- **kwargs: typing.Any,
+ *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
+ **kwargs: Any,
) -> None:
value = MultiDict(*args, **kwargs)
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
def __init__(
self,
- *args: ImmutableMultiDict[typing.Any, typing.Any]
- | typing.Mapping[typing.Any, typing.Any]
- | list[tuple[typing.Any, typing.Any]]
- | str
- | bytes,
- **kwargs: typing.Any,
+ *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
+ **kwargs: Any,
) -> None:
assert len(args) < 2, "Too many arguments."
def __init__(
self,
- file: typing.BinaryIO,
+ file: BinaryIO,
*,
size: int | None = None,
filename: str | None = None,
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
-class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
+class FormData(ImmutableMultiDict[str, Union[UploadFile, str]]):
"""
An immutable multidict, containing both file uploads and text input.
"""
def __init__(
self,
- *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
+ *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
await value.close()
-class Headers(typing.Mapping[str, str]):
+class Headers(Mapping[str, str]):
"""
An immutable, case-insensitive multidict.
"""
def __init__(
self,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
raw: list[tuple[bytes, bytes]] | None = None,
- scope: typing.MutableMapping[str, typing.Any] | None = None,
+ scope: MutableMapping[str, Any] | None = None,
) -> None:
self._list: list[tuple[bytes, bytes]] = []
if headers is not None:
return header_value.decode("latin-1")
raise KeyError(key)
- def __contains__(self, key: typing.Any) -> bool:
+ def __contains__(self, key: Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
- def __iter__(self) -> typing.Iterator[typing.Any]:
+ def __iter__(self) -> Iterator[Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._list)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
for idx in reversed(pop_indexes):
del self._list[idx]
- def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
- if not isinstance(other, typing.Mapping):
+ def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
+ if not isinstance(other, Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
self.update(other)
return self
- def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
- if not isinstance(other, typing.Mapping):
+ def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
+ if not isinstance(other, Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
new = self.mutablecopy()
new.update(other)
self._list.append((set_key, set_value))
return value
- def update(self, other: typing.Mapping[str, str]) -> None:
+ def update(self, other: Mapping[str, str]) -> None:
for key, val in other.items():
self[key] = val
Used for `request.state` and `app.state`.
"""
- _state: dict[str, typing.Any]
+ _state: dict[str, Any]
- def __init__(self, state: dict[str, typing.Any] | None = None):
+ def __init__(self, state: dict[str, Any] | None = None):
if state is None:
state = {}
super().__setattr__("_state", state)
- def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
+ def __setattr__(self, key: Any, value: Any) -> None:
self._state[key] = value
- def __getattr__(self, key: typing.Any) -> typing.Any:
+ def __getattr__(self, key: Any) -> Any:
try:
return self._state[key]
except KeyError:
message = "'{}' object has no attribute '{}'"
raise AttributeError(message.format(self.__class__.__name__, key))
- def __delattr__(self, key: typing.Any) -> None:
+ def __delattr__(self, key: Any) -> None:
del self._state[key]
from __future__ import annotations
import json
-import typing
+from collections.abc import Generator
+from typing import Any, Callable
from starlette import status
from starlette._utils import is_async_callable
if getattr(self, method.lower(), None) is not None
]
- def __await__(self) -> typing.Generator[typing.Any, None, None]:
+ def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
- handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
+ handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
self.receive = receive
self.send = send
- def __await__(self) -> typing.Generator[typing.Any, None, None]:
+ def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
finally:
await self.on_disconnect(websocket, close_code)
- async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
+ async def decode(self, websocket: WebSocket, message: Message) -> Any:
if self.encoding == "text":
if "text" not in message:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
"""Override to handle an incoming websocket connection"""
await websocket.accept()
- async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
+ async def on_receive(self, websocket: WebSocket, data: Any) -> None:
"""Override to handle an incoming websocket message"""
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
from __future__ import annotations
-import typing
+from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
+from typing import TYPE_CHECKING
from urllib.parse import unquote_plus
from starlette.datastructures import FormData, Headers, UploadFile
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
import python_multipart as multipart
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
else:
class FormParser:
- def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
+ def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
def __init__(
self,
headers: Headers,
- stream: typing.AsyncGenerator[bytes, None],
+ stream: AsyncGenerator[bytes, None],
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
from __future__ import annotations
-import typing
+from typing import Callable
from starlette.authentication import (
AuthCredentials,
self,
app: ASGIApp,
backend: AuthenticationBackend,
- on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
+ on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
- self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
+ self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
on_error if on_error is not None else self.default_on_error
)
from __future__ import annotations
-import typing
+from collections.abc import AsyncGenerator, Awaitable, Mapping
+from typing import Any, Callable, TypeVar
import anyio
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
-RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
-DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
-T = typing.TypeVar("T")
+RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
+DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
+T = TypeVar("T")
class _CachedRequest(Request):
async with anyio.create_task_group() as task_group:
- async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
+ async def wrap(func: Callable[[], Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
assert message["type"] == "http.response.start"
- async def body_stream() -> typing.AsyncGenerator[bytes, None]:
+ async def body_stream() -> AsyncGenerator[bytes, None]:
async for message in recv_stream:
assert message["type"] == "http.response.body"
body = message.get("body", b"")
self,
content: AsyncContentStream,
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
- info: typing.Mapping[str, typing.Any] | None = None,
+ info: Mapping[str, Any] | None = None,
) -> None:
self.info = info
self.body_iterator = content
import functools
import re
-import typing
+from collections.abc import Sequence
from starlette.datastructures import Headers, MutableHeaders
from starlette.responses import PlainTextResponse, Response
def __init__(
self,
app: ASGIApp,
- allow_origins: typing.Sequence[str] = (),
- allow_methods: typing.Sequence[str] = ("GET",),
- allow_headers: typing.Sequence[str] = (),
+ allow_origins: Sequence[str] = (),
+ allow_methods: Sequence[str] = ("GET",),
+ allow_headers: Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: str | None = None,
- expose_headers: typing.Sequence[str] = (),
+ expose_headers: Sequence[str] = (),
max_age: int = 600,
) -> None:
if "*" in allow_methods:
import inspect
import sys
import traceback
-import typing
+from typing import Any, Callable
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
def __init__(
self,
app: ASGIApp,
- handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
+ handler: Callable[[Request, Exception], Any] | None = None,
debug: bool = False,
) -> None:
self.app = app
from __future__ import annotations
-import typing
+from collections.abc import Mapping
+from typing import Any, Callable
from starlette._exception_handler import (
ExceptionHandlers,
def __init__(
self,
app: ASGIApp,
- handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
+ handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None,
debug: bool = False,
) -> None:
self.app = app
def add_exception_handler(
self,
exc_class_or_status_code: int | type[Exception],
- handler: typing.Callable[[Request, Exception], Response],
+ handler: Callable[[Request, Exception], Response],
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
import gzip
import io
-import typing
+from typing import NoReturn
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
return body
-async def unattached_send(message: Message) -> typing.NoReturn:
+async def unattached_send(message: Message) -> NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover
from __future__ import annotations
import json
-import typing
from base64 import b64decode, b64encode
+from typing import Literal
import itsdangerous
from itsdangerous.exc import BadSignature
session_cookie: str = "session",
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
path: str = "/",
- same_site: typing.Literal["lax", "strict", "none"] = "lax",
+ same_site: Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: str | None = None,
) -> None:
from __future__ import annotations
-import typing
+from collections.abc import Sequence
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
def __init__(
self,
app: ASGIApp,
- allowed_hosts: typing.Sequence[str] | None = None,
+ allowed_hosts: Sequence[str] | None = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
import io
import math
import sys
-import typing
import warnings
+from collections.abc import MutableMapping
+from typing import Any, Callable
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
)
-def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
+def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
class WSGIMiddleware:
- def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
+ def __init__(self, app: Callable[..., Any]) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
class WSGIResponder:
- stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
- stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
+ stream_send: ObjectSendStream[MutableMapping[str, Any]]
+ stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
- def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
+ def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
self.app = app
self.scope = scope
self.status = None
self.response_headers = None
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
- self.exc_info: typing.Any = None
+ self.exc_info: Any = None
async def __call__(self, receive: Receive, send: Send) -> None:
body = b""
self,
status: str,
response_headers: list[tuple[str, str]],
- exc_info: typing.Any = None,
+ exc_info: Any = None,
) -> None:
self.exc_info = exc_info
if not self.response_started: # pragma: no branch
def wsgi(
self,
- environ: dict[str, typing.Any],
- start_response: typing.Callable[..., typing.Any],
+ environ: dict[str, Any],
+ start_response: Callable[..., Any],
) -> None:
for chunk in self.app(environ, start_response):
anyio.from_thread.run(
from __future__ import annotations
import json
-import typing
+from collections.abc import AsyncGenerator, Iterator, Mapping
from http import cookies as http_cookies
+from typing import TYPE_CHECKING, Any, NoReturn, cast
import anyio
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
from python_multipart.multipart import parse_options_header
from starlette.applications import Starlette
pass
-class HTTPConnection(typing.Mapping[str, typing.Any]):
+class HTTPConnection(Mapping[str, Any]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
assert scope["type"] in ("http", "websocket")
self.scope = scope
- def __getitem__(self, key: str) -> typing.Any:
+ def __getitem__(self, key: str) -> Any:
return self.scope[key]
- def __iter__(self) -> typing.Iterator[str]:
+ def __iter__(self) -> Iterator[str]:
return iter(self.scope)
def __len__(self) -> int:
__hash__ = object.__hash__
@property
- def app(self) -> typing.Any:
+ def app(self) -> Any:
return self.scope["app"]
@property
return self._query_params
@property
- def path_params(self) -> dict[str, typing.Any]:
+ def path_params(self) -> dict[str, Any]:
return self.scope.get("path_params", {})
@property
return None
@property
- def session(self) -> dict[str, typing.Any]:
+ def session(self) -> dict[str, Any]:
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
- def auth(self) -> typing.Any:
+ def auth(self) -> Any:
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
- def user(self) -> typing.Any:
+ def user(self) -> Any:
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
self._state = State(self.scope["state"])
return self._state
- def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
+ def url_for(self, name: str, /, **path_params: Any) -> URL:
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
if url_path_provider is None:
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
return url_path.make_absolute_url(base_url=self.base_url)
-async def empty_receive() -> typing.NoReturn:
+async def empty_receive() -> NoReturn:
raise RuntimeError("Receive channel has not been made available")
-async def empty_send(message: Message) -> typing.NoReturn:
+async def empty_send(message: Message) -> NoReturn:
raise RuntimeError("Send channel has not been made available")
@property
def method(self) -> str:
- return typing.cast(str, self.scope["method"])
+ return cast(str, self.scope["method"])
@property
def receive(self) -> Receive:
return self._receive
- async def stream(self) -> typing.AsyncGenerator[bytes, None]:
+ async def stream(self) -> AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
yield b""
self._body = b"".join(chunks)
return self._body
- async def json(self) -> typing.Any:
+ async def json(self) -> Any:
if not hasattr(self, "_json"): # pragma: no branch
body = await self.body()
self._json = json.loads(body)
import re
import stat
import sys
-import typing
import warnings
+from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence
from datetime import datetime
from email.utils import format_datetime, formatdate
from functools import partial
from mimetypes import guess_type
from secrets import token_hex
+from typing import Any, Callable, Literal, Union
from urllib.parse import quote
import anyio
def __init__(
self,
- content: typing.Any = None,
+ content: Any = None,
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
self.body = self.render(content)
self.init_headers(headers)
- def render(self, content: typing.Any) -> bytes | memoryview:
+ def render(self, content: Any) -> bytes | memoryview:
if content is None:
return b""
if isinstance(content, (bytes, memoryview)):
return content
return content.encode(self.charset) # type: ignore
- def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
+ def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
if headers is None:
raw_headers: list[tuple[bytes, bytes]] = []
populate_content_length = True
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
- samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
+ samesite: Literal["lax", "strict", "none"] | None = "lax",
partitioned: bool = False,
) -> None:
cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
- samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
+ samesite: Literal["lax", "strict", "none"] | None = "lax",
) -> None:
self.set_cookie(
key,
def __init__(
self,
- content: typing.Any,
+ content: Any,
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content, status_code, headers, media_type, background)
- def render(self, content: typing.Any) -> bytes:
+ def render(self, content: Any) -> bytes:
return json.dumps(
content,
ensure_ascii=False,
self,
url: str | URL,
status_code: int = 307,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
-Content = typing.Union[str, bytes, memoryview]
-SyncContentStream = typing.Iterable[Content]
-AsyncContentStream = typing.AsyncIterable[Content]
-ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
+Content = Union[str, bytes, memoryview]
+SyncContentStream = Iterable[Content]
+AsyncContentStream = AsyncIterable[Content]
+ContentStream = Union[AsyncContentStream, SyncContentStream]
class StreamingResponse(Response):
self,
content: ContentStream,
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
- if isinstance(content, typing.AsyncIterable):
+ if isinstance(content, AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
with collapse_excgroups():
async with anyio.create_task_group() as task_group:
- async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
+ async def wrap(func: Callable[[], Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
self,
path: str | os.PathLike[str],
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
filename: str | None = None,
def generate_multipart(
self,
- ranges: typing.Sequence[tuple[int, int]],
+ ranges: Sequence[tuple[int, int]],
boundary: str,
max_size: int,
content_type: str,
- ) -> tuple[int, typing.Callable[[int, int], bytes]]:
+ ) -> tuple[int, Callable[[int, int], bytes]]:
r"""
Multipart response headers generator.
import re
import traceback
import types
-import typing
import warnings
-from contextlib import asynccontextmanager
+from collections.abc import Awaitable, Generator, Sequence
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
from enum import Enum
+from re import Pattern
+from typing import Any, Callable, TypeVar
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import get_route_path, is_async_callable
if no matching route exists.
"""
- def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None:
+ def __init__(self, name: str, path_params: dict[str, Any]) -> None:
params = ", ".join(list(path_params.keys()))
super().__init__(f'No route exists for name "{name}" and params "{params}".')
FULL = 2
-def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
+def iscoroutinefunction_or_partial(obj: Any) -> bool: # pragma: no cover
"""
Correctly determines if an object is a coroutine function,
including those wrapped in functools.partial objects.
def request_response(
- func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
+ func: Callable[[Request], Awaitable[Response] | Response],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
- f: typing.Callable[[Request], typing.Awaitable[Response]] = (
+ f: Callable[[Request], Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
)
def websocket_session(
- func: typing.Callable[[WebSocket], typing.Awaitable[None]],
+ func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
return app
-def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
+def get_name(endpoint: Callable[..., Any]) -> str:
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
def replace_params(
path: str,
- param_convertors: dict[str, Convertor[typing.Any]],
+ param_convertors: dict[str, Convertor[Any]],
path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
for key, value in list(path_params.items()):
def compile_path(
path: str,
-) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]:
+) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
"""
Given a path string, like: "/{username:str}",
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
def matches(self, scope: Scope) -> tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
raise NotImplementedError() # pragma: no cover
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
def __init__(
self,
path: str,
- endpoint: typing.Callable[..., typing.Any],
+ endpoint: Callable[..., Any],
*,
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
- middleware: typing.Sequence[Middleware] | None = None,
+ middleware: Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
- path_params: dict[str, typing.Any]
+ path_params: dict[str, Any]
if scope["type"] == "http":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
return Match.FULL, child_scope
return Match.NONE, {}
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
else:
await self.app(scope, receive, send)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, Route)
and self.path == other.path
def __init__(
self,
path: str,
- endpoint: typing.Callable[..., typing.Any],
+ endpoint: Callable[..., Any],
*,
name: str | None = None,
- middleware: typing.Sequence[Middleware] | None = None,
+ middleware: Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
- path_params: dict[str, typing.Any]
+ path_params: dict[str, Any]
if scope["type"] == "websocket":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
return Match.FULL, child_scope
return Match.NONE, {}
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
self,
path: str,
app: ASGIApp | None = None,
- routes: typing.Sequence[BaseRoute] | None = None,
+ routes: Sequence[BaseRoute] | None = None,
name: str | None = None,
*,
- middleware: typing.Sequence[Middleware] | None = None,
+ middleware: Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
return getattr(self._base_app, "routes", [])
def matches(self, scope: Scope) -> tuple[Match, Scope]:
- path_params: dict[str, typing.Any]
+ path_params: dict[str, Any]
if scope["type"] in ("http", "websocket"): # pragma: no branch
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
return Match.FULL, child_scope
return Match.NONE, {}
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
return Match.FULL, child_scope
return Match.NONE, {}
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
-_T = typing.TypeVar("_T")
+_T = TypeVar("_T")
-class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
- def __init__(self, cm: typing.ContextManager[_T]):
+class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
+ def __init__(self, cm: AbstractContextManager[_T]):
self._cm = cm
async def __aenter__(self) -> _T:
def _wrap_gen_lifespan_context(
- lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
-) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
+ lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
+) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
- def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
+ def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
class Router:
def __init__(
self,
- routes: typing.Sequence[BaseRoute] | None = None,
+ routes: Sequence[BaseRoute] | None = None,
redirect_slashes: bool = True,
default: ASGIApp | None = None,
- on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
- on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
+ on_startup: Sequence[Callable[[], Any]] | None = None,
+ on_shutdown: Sequence[Callable[[], Any]] | None = None,
# the generic to Lifespan[AppType] is the type of the top level application
- # which the router cannot know statically, so we use typing.Any
- lifespan: Lifespan[typing.Any] | None = None,
+ # which the router cannot know statically, so we use Any
+ lifespan: Lifespan[Any] | None = None,
*,
- middleware: typing.Sequence[Middleware] | None = None,
+ middleware: Sequence[Middleware] | None = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
)
if lifespan is None:
- self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)
+ self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
- def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
+ def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
startup and shutdown events.
"""
started = False
- app: typing.Any = scope.get("app")
+ app: Any = scope.get("app")
await receive()
try:
async with self.lifespan_context(app) as maybe_state:
await self.default(scope, receive, send)
- def __eq__(self, other: typing.Any) -> bool:
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
def add_route(
self,
path: str,
- endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response],
+ endpoint: Callable[[Request], Awaitable[Response] | Response],
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
def add_websocket_route(
self,
path: str,
- endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
+ endpoint: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
route = WebSocketRoute(path, endpoint=endpoint, name=name)
methods: list[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
- ) -> typing.Callable: # type: ignore[type-arg]
+ ) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_route(
path,
func,
return decorator
- def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
+ def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_websocket_route(path, func, name=name)
return func
return decorator
- def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
+ def add_event_handler(self, event_type: str, func: Callable[[], Any]) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
if event_type == "startup":
else:
self.on_shutdown.append(func)
- def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
+ def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
warnings.warn(
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
+ def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
self.add_event_handler(event_type, func)
return func
import inspect
import re
-import typing
+from typing import Any, Callable, NamedTuple
from starlette.requests import Request
from starlette.responses import Response
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
- def render(self, content: typing.Any) -> bytes:
+ def render(self, content: Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
-class EndpointInfo(typing.NamedTuple):
+class EndpointInfo(NamedTuple):
path: str
http_method: str
- func: typing.Callable[..., typing.Any]
+ func: Callable[..., Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
- def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
+ def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
return _remove_converter_pattern.sub("}", path)
- def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
+ def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
class SchemaGenerator(BaseSchemaGenerator):
- def __init__(self, base_schema: dict[str, typing.Any]) -> None:
+ def __init__(self, base_schema: dict[str, Any]) -> None:
self.base_schema = base_schema
- def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
+ def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)
import importlib.util
import os
import stat
-import typing
from email.utils import parsedate
+from typing import Union
import anyio
import anyio.to_thread
from starlette.responses import FileResponse, RedirectResponse, Response
from starlette.types import Receive, Scope, Send
-PathLike = typing.Union[str, "os.PathLike[str]"]
+PathLike = Union[str, "os.PathLike[str]"]
class NotModifiedResponse(Response):
from __future__ import annotations
-import typing
import warnings
+from collections.abc import Mapping, Sequence
from os import PathLike
+from typing import Any, Callable, cast, overload
from starlette.background import BackgroundTask
from starlette.datastructures import URL
class _TemplateResponse(HTMLResponse):
def __init__(
self,
- template: typing.Any,
- context: dict[str, typing.Any],
+ template: Any,
+ context: dict[str, Any],
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
):
return templates.TemplateResponse("index.html", {"request": request})
"""
- @typing.overload
+ @overload
def __init__(
self,
- directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
+ directory: str | PathLike[str] | Sequence[str | PathLike[str]],
*,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
- **env_options: typing.Any,
+ context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
+ **env_options: Any,
) -> None: ...
- @typing.overload
+ @overload
def __init__(
self,
*,
env: jinja2.Environment,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
+ context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
) -> None: ...
def __init__(
self,
- directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
+ directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None,
*,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
+ context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
env: jinja2.Environment | None = None,
- **env_options: typing.Any,
+ **env_options: Any,
) -> None:
if env_options:
warnings.warn(
def _create_env(
self,
- directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
- **env_options: typing.Any,
+ directory: str | PathLike[str] | Sequence[str | PathLike[str]],
+ **env_options: Any,
) -> jinja2.Environment:
loader = jinja2.FileSystemLoader(directory)
env_options.setdefault("loader", loader)
def _setup_env_defaults(self, env: jinja2.Environment) -> None:
@pass_context
def url_for(
- context: dict[str, typing.Any],
+ context: dict[str, Any],
name: str,
/,
- **path_params: typing.Any,
+ **path_params: Any,
) -> URL:
request: Request = context["request"]
return request.url_for(name, **path_params)
def get_template(self, name: str) -> jinja2.Template:
return self.env.get_template(name)
- @typing.overload
+ @overload
def TemplateResponse(
self,
request: Request,
name: str,
- context: dict[str, typing.Any] | None = None,
+ context: dict[str, Any] | None = None,
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse: ...
- @typing.overload
+ @overload
def TemplateResponse(
self,
name: str,
- context: dict[str, typing.Any] | None = None,
+ context: dict[str, Any] | None = None,
status_code: int = 200,
- headers: typing.Mapping[str, str] | None = None,
+ headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> _TemplateResponse:
# Deprecated usage
...
- def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
+ def TemplateResponse(self, *args: Any, **kwargs: Any) -> _TemplateResponse:
if args:
if isinstance(args[0], str): # the first argument is template name (old style)
warnings.warn(
context = kwargs.get("context", {})
request = kwargs.get("request", context.get("request"))
- name = typing.cast(str, kwargs["name"])
+ name = cast(str, kwargs["name"])
status_code = kwargs.get("status_code", 200)
headers = kwargs.get("headers")
media_type = kwargs.get("media_type")
import json
import math
import sys
-import typing
import warnings
+from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapping, Sequence
from concurrent.futures import Future
+from contextlib import AbstractContextManager
from types import GeneratorType
+from typing import (
+ Any,
+ Callable,
+ Literal,
+ TypedDict,
+ Union,
+ cast,
+)
from urllib.parse import unquote, urljoin
import anyio
"You can install this with:\n"
" $ pip install httpx\n"
)
-_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
+_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
-ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
-ASGI2App = typing.Callable[[Scope], ASGIInstance]
-ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
+ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
+ASGI2App = Callable[[Scope], ASGIInstance]
+ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
-_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]]
+_RequestData = Mapping[str, Union[str, Iterable[str], bytes]]
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
await instance(receive, send)
-class _AsyncBackend(typing.TypedDict):
+class _AsyncBackend(TypedDict):
backend: str
- backend_options: dict[str, typing.Any]
+ backend_options: dict[str, Any]
class _Upgrade(Exception):
self.exit_stack = stack.pop_all()
return self
- def __exit__(self, *args: typing.Any) -> bool | None:
+ def __exit__(self, *args: Any) -> bool | None:
return self.exit_stack.__exit__(*args)
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
- def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
+ def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
def receive_text(self) -> str:
message = self.receive()
self._raise_on_close(message)
- return typing.cast(str, message["text"])
+ return cast(str, message["text"])
def receive_bytes(self) -> bytes:
message = self.receive()
self._raise_on_close(message)
- return typing.cast(bytes, message["bytes"])
+ return cast(bytes, message["bytes"])
- def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
+ def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
root_path: str = "",
*,
client: tuple[str, int],
- app_state: dict[str, typing.Any],
+ app_state: dict[str, Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
# Include other request headers.
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
- scope: dict[str, typing.Any]
+ scope: dict[str, Any]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
if subprotocol is None:
- subprotocols: typing.Sequence[str] = []
+ subprotocols: Sequence[str] = []
else:
subprotocols = [value.strip() for value in subprotocol.split(",")]
scope = {
request_complete = False
response_started = False
response_complete: anyio.Event
- raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
+ raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
template = None
context = None
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
- backend: typing.Literal["asyncio", "trio"] = "asyncio",
- backend_options: dict[str, typing.Any] | None = None,
+ backend: Literal["asyncio", "trio"] = "asyncio",
+ backend_options: dict[str, Any] | None = None,
cookies: httpx._types.CookieTypes | None = None,
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
if _is_asgi3(app):
asgi_app = app
else:
- app = typing.cast(ASGI2App, app) # type: ignore[assignment]
+ app = cast(ASGI2App, app) # type: ignore[assignment]
asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
self.app = asgi_app
- self.app_state: dict[str, typing.Any] = {}
+ self.app_state: dict[str, Any] = {}
transport = _TestClientTransport(
self.app,
portal_factory=self._portal_factory,
)
@contextlib.contextmanager
- def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
+ def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
- json: typing.Any = None,
+ json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
if timeout is not httpx.USE_CLIENT_DEFAULT:
warnings.warn(
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().get(
url,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().options(
url,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().head(
url,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
- json: typing.Any = None,
+ json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().post(
url,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
- json: typing.Any = None,
+ json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().put(
url,
content: httpx._types.RequestContent | None = None,
data: _RequestData | None = None,
files: httpx._types.RequestFiles | None = None,
- json: typing.Any = None,
+ json: Any = None,
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().patch(
url,
auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
- extensions: dict[str, typing.Any] | None = None,
+ extensions: dict[str, Any] | None = None,
) -> httpx.Response:
return super().delete(
url,
def websocket_connect(
self,
url: str,
- subprotocols: typing.Sequence[str] | None = None,
- **kwargs: typing.Any,
+ subprotocols: Sequence[str] | None = None,
+ **kwargs: Any,
) -> WebSocketTestSession:
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
def reset_portal() -> None:
self.portal = None
- send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = (
+ send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
anyio.create_memory_object_stream(math.inf)
)
- receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = (
- anyio.create_memory_object_stream(math.inf)
+ receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
+ math.inf
)
for channel in (*send, *receive):
stack.callback(channel.close)
return self
- def __exit__(self, *args: typing.Any) -> None:
+ def __exit__(self, *args: Any) -> None:
self.exit_stack.close()
async def lifespan(self) -> None:
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
- async def receive() -> typing.Any:
+ async def receive() -> Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
await receive()
async def wait_shutdown(self) -> None:
- async def receive() -> typing.Any:
+ async def receive() -> Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
-import typing
+from collections.abc import Awaitable, Mapping, MutableMapping
+from contextlib import AbstractAsyncContextManager
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from starlette.websockets import WebSocket
-AppType = typing.TypeVar("AppType")
+AppType = TypeVar("AppType")
-Scope = typing.MutableMapping[str, typing.Any]
-Message = typing.MutableMapping[str, typing.Any]
+Scope = MutableMapping[str, Any]
+Message = MutableMapping[str, Any]
-Receive = typing.Callable[[], typing.Awaitable[Message]]
-Send = typing.Callable[[Message], typing.Awaitable[None]]
+Receive = Callable[[], Awaitable[Message]]
+Send = Callable[[Message], Awaitable[None]]
-ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
+ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
-StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
-StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
-Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
+StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]]
+StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]]
+Lifespan = Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
-HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
-WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
-ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
+HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"]
+WebSocketExceptionHandler = Callable[["WebSocket", Exception], Awaitable[None]]
+ExceptionHandler = Union[HTTPExceptionHandler, WebSocketExceptionHandler]
import enum
import json
-import typing
+from collections.abc import AsyncIterator, Iterable
+from typing import Any, cast
from starlette.requests import HTTPConnection
from starlette.responses import Response
async def accept(
self,
subprotocol: str | None = None,
- headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
+ headers: Iterable[tuple[bytes, bytes]] | None = None,
) -> None:
headers = headers or []
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
- return typing.cast(str, message["text"])
+ return cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
- return typing.cast(bytes, message["bytes"])
+ return cast(bytes, message["bytes"])
- async def receive_json(self, mode: str = "text") -> typing.Any:
+ async def receive_json(self, mode: str = "text") -> Any:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
text = message["bytes"].decode("utf-8")
return json.loads(text)
- async def iter_text(self) -> typing.AsyncIterator[str]:
+ async def iter_text(self) -> AsyncIterator[str]:
try:
while True:
yield await self.receive_text()
except WebSocketDisconnect:
pass
- async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
+ async def iter_bytes(self) -> AsyncIterator[bytes]:
try:
while True:
yield await self.receive_bytes()
except WebSocketDisconnect:
pass
- async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
+ async def iter_json(self) -> AsyncIterator[Any]:
try:
while True:
yield await self.receive_json()
async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})
- async def send_json(self, data: typing.Any, mode: str = "text") -> None:
+ async def send_json(self, data: Any, mode: str = "text") -> None:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
import os
-import typing
from pathlib import Path
from typing import Any, Optional
config = Config(path, environ={"DEBUG": "true"})
- def cast_to_int(v: typing.Any) -> int:
+ def cast_to_int(v: Any) -> int:
return int(v)
DEBUG = config("DEBUG", cast=bool)
-import typing
from collections.abc import Generator
+from typing import Any
import pytest
from pytest import MonkeyPatch
from starlette import _exception_handler
# Replace run_in_threadpool with a function that raises an error
- def mock_run_in_threadpool(*args: typing.Any, **kwargs: typing.Any) -> None:
+ def mock_run_in_threadpool(*args: Any, **kwargs: Any) -> None:
pytest.fail("run_in_threadpool should not be called for HTTP exceptions") # pragma: no cover
# Apply the monkeypatch only during this test
from __future__ import annotations
import os
-import typing
-from contextlib import nullcontext as does_not_raise
+from contextlib import AbstractContextManager, nullcontext as does_not_raise
from pathlib import Path
+from typing import Any
import pytest
from tests.types import TestClientFactory
-class ForceMultipartDict(dict[typing.Any, typing.Any]):
+class ForceMultipartDict(dict[Any, Any]):
def __bool__(self) -> bool:
return True
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form()
- output: dict[str, typing.Any] = {}
+ output: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form()
- output: dict[str, list[typing.Any]] = {}
+ output: dict[str, list[Any]] = {}
for key, value in data.multi_items():
if key not in output:
output[key] = []
async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form()
- output: dict[str, typing.Any] = {}
+ output: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
- output: dict[str, typing.Any] = {}
+ output: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
)
def test_missing_boundary_parameter(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_missing_name_parameter_on_content_disposition(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_too_many_fields_raise(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_too_many_files_raise(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_too_many_files_single_field_raise(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_too_many_files_and_fields_raise(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_max_fields_is_customizable_low_raises(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_max_files_is_customizable_low_raises(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_max_part_size_exceeds_limit(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
)
def test_max_part_size_exceeds_custom_limit(
app: ASGIApp,
- expectation: typing.ContextManager[Exception],
+ expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
import contextlib
import functools
import json
-import typing
import uuid
+from collections.abc import AsyncGenerator, AsyncIterator, Generator
+from typing import Callable, TypedDict
import pytest
@pytest.fixture
def client(
test_client_factory: TestClientFactory,
-) -> typing.Generator[TestClient, None, None]:
+) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client
def test_standalone_route_does_not_match(
- test_client_factory: typing.Callable[..., TestClient],
+ test_client_factory: Callable[..., TestClient],
) -> None:
app = Route("/", PlainTextResponse("Hello, World!"))
client = test_client_factory(app)
shutdown_called = False
@contextlib.asynccontextmanager
- async def lifespan(app: Starlette) -> typing.AsyncGenerator[None, None]:
+ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
nonlocal lifespan_called
lifespan_called = True
yield
@contextlib.asynccontextmanager
async def lifespan(
app: ASGIApp,
- ) -> typing.AsyncGenerator[dict[str, str], None]:
+ ) -> AsyncGenerator[dict[str, str], None]:
yield {"foo": "bar"}
app = Router(
startup_complete = False
shutdown_complete = False
- class State(typing.TypedDict):
+ class State(TypedDict):
count: int
items: list[int]
return PlainTextResponse("hello, world")
@contextlib.asynccontextmanager
- async def lifespan(app: Starlette) -> typing.AsyncIterator[State]:
+ async def lifespan(app: Starlette) -> AsyncIterator[State]:
nonlocal startup_complete, shutdown_complete
startup_complete = True
state = State(count=0, items=[])
pytest.param(lambda request: ..., "<lambda>", id="lambda"), # pragma: no branch
],
)
-def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None:
+def test_route_name(endpoint: Callable[..., Response], expected_name: str) -> None:
assert Route(path="/", endpoint=endpoint).name == expected_name
def test_add_route_to_app_after_mount(
- test_client_factory: typing.Callable[..., TestClient],
+ test_client_factory: Callable[..., TestClient],
) -> None:
"""Checks that Mount will pick up routes
added to the underlying app after it is mounted
def test_mounted_middleware_does_not_catch_exception(
- test_client_factory: typing.Callable[..., TestClient],
+ test_client_factory: Callable[..., TestClient],
) -> None:
# https://github.com/encode/starlette/pull/1649#discussion_r960236107
def exc(request: Request) -> Response:
import stat
import tempfile
import time
-import typing
from pathlib import Path
+from typing import Any
import anyio
import pytest
test_client_factory: TestClientFactory,
monkeypatch: pytest.MonkeyPatch,
) -> None:
- def mock_timeout(*args: typing.Any, **kwargs: typing.Any) -> None:
+ def mock_timeout(*args: Any, **kwargs: Any) -> None:
raise TimeoutError
path = os.path.join(tmpdir, "example.txt")