combine-as-imports = true
[tool.mypy]
-disallow_untyped_defs = true
+strict = true
ignore_missing_imports = true
-show_error_codes = true
+python_version = "3.8"
[[tool.mypy.overrides]]
module = "starlette.testclient.*"
-no_implicit_optional = false
+implicit_optional = true
-[[tool.mypy.overrides]]
-module = "tests.*"
-disallow_untyped_defs = false
-check_untyped_defs = true
+# TODO: Uncomment the following configuration when
+# https://github.com/python/mypy/issues/10045 is solved. In the meantime,
+# we are calling `mypy tests` directly. Check `scripts/check` for more info.
+# [[tool.mypy.overrides]]
+# module = "tests.*"
+# disallow_untyped_defs = false
+# check_untyped_defs = true
[tool.pytest.ini_options]
addopts = "-rxXs --strict-config --strict-markers"
./scripts/sync-version
${PREFIX}black --check --diff $SOURCE_FILES
-${PREFIX}mypy $SOURCE_FILES
+# TODO: Use `[[tool.mypy.overrides]]` on the `pyproject.toml` when the mypy issue is solved:
+# github.com/python/mypy/issues/10045. Check github.com/encode/starlette/pull/2180 for more info.
+${PREFIX}mypy starlette
+${PREFIX}mypy tests --disable-error-code no-untyped-def --disable-error-code no-untyped-call
${PREFIX}ruff check $SOURCE_FILES
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
-from starlette.responses import Response
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
-Handler = typing.Callable[..., typing.Any]
-ExceptionHandlers = typing.Dict[typing.Any, Handler]
-StatusHandlers = typing.Dict[int, Handler]
+ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
+StatusHandlers = typing.Dict[int, ExceptionHandler]
def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
-) -> typing.Optional[Handler]:
+) -> typing.Optional[ExceptionHandler]:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
raise RuntimeError(msg) from exc
if scope["type"] == "http":
- response: Response
if is_async_callable(handler):
response = await handler(conn, exc)
else:
import asyncio
import functools
+import sys
import typing
+if sys.version_info >= (3, 10): # pragma: no cover
+ from typing import TypeGuard
+else: # pragma: no cover
+ from typing_extensions import TypeGuard
-def is_async_callable(obj: typing.Any) -> bool:
+T = typing.TypeVar("T")
+AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
+
+
+@typing.overload
+def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]:
+ ...
+
+
+@typing.overload
+def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]:
+ ...
+
+
+def is_async_callable(obj: typing.Any) -> typing.Any:
while isinstance(obj, functools.partial):
obj = obj.func
+from __future__ import annotations
+
import typing
import warnings
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
-from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
+from starlette.websockets import WebSocket
AppType = typing.TypeVar("AppType", bound="Starlette")
def __init__(
self: "AppType",
debug: bool = False,
- routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
- middleware: typing.Optional[typing.Sequence[Middleware]] = None,
- exception_handlers: typing.Optional[
- typing.Mapping[
- typing.Any,
- typing.Callable[
- [Request, Exception],
- typing.Union[Response, typing.Awaitable[Response]],
- ],
- ]
- ] = None,
- on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
- on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
+ routes: typing.Sequence[BaseRoute] | None = None,
+ middleware: typing.Sequence[Middleware] | None = None,
+ exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
+ on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
+ on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
lifespan: typing.Optional[Lifespan["AppType"]] = None,
) -> None:
# The lifespan context function is a newer style that replaces
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)
- def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
- return self.router.on_event(event_type)
+ def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
+ return self.router.on_event(event_type) # pragma: nocover
- def mount(
- self, path: str, app: ASGIApp, name: typing.Optional[str] = None
- ) -> None: # pragma: nocover
- self.router.mount(path, app=app, name=name)
+ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
+ self.router.mount(path, app=app, name=name) # pragma: no cover
- def host(
- self, host: str, app: ASGIApp, name: typing.Optional[str] = None
- ) -> None: # pragma: no cover
- self.router.host(host, app=app, name=name)
+ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
+ self.router.host(host, app=app, name=name) # pragma: no cover
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
if self.middleware_stack is not None: # pragma: no cover
def add_exception_handler(
self,
- exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
- handler: typing.Callable,
+ exc_class_or_status_code: int | typing.Type[Exception],
+ handler: ExceptionHandler,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
def add_event_handler(
- self, event_type: str, func: typing.Callable
+ self, event_type: str, func: typing.Callable # type: ignore[type-arg]
) -> None: # pragma: no cover
self.router.add_event_handler(event_type, func)
def add_route(
self,
path: str,
- route: typing.Callable,
+ route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
)
def add_websocket_route(
- self, path: str, route: typing.Callable, name: typing.Optional[str] = None
+ self,
+ path: str,
+ route: typing.Callable[[WebSocket], typing.Awaitable[None]],
+ name: str | None = None,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
def exception_handler(
- self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]]
- ) -> typing.Callable:
+ self, exc_class_or_status_code: int | typing.Type[Exception]
+ ) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_exception_handler(exc_class_or_status_code, func)
return func
def route(
self,
path: str,
- methods: typing.Optional[typing.List[str]] = None,
- name: typing.Optional[str] = None,
+ methods: typing.List[str] | None = None,
+ name: str | None = None,
include_in_schema: bool = True,
- ) -> typing.Callable:
+ ) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.router.add_route(
path,
func,
return decorator
def websocket_route(
- self, path: str, name: typing.Optional[str] = None
- ) -> typing.Callable:
+ self, path: str, name: str | None = None
+ ) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
- def middleware(self, middleware_type: str) -> typing.Callable:
+ def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
middleware_type == "http"
), 'Currently only middleware("http") is supported.'
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func
import functools
import inspect
+import sys
import typing
from urllib.parse import urlencode
+if sys.version_info >= (3, 10): # pragma: no cover
+ from typing import ParamSpec
+else: # pragma: no cover
+ from typing_extensions import ParamSpec
+
from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
-from starlette.responses import RedirectResponse, Response
+from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket
-_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
+_P = ParamSpec("_P")
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
scopes: typing.Union[str, typing.Sequence[str]],
status_code: int = 403,
redirect: typing.Optional[str] = None,
-) -> typing.Callable[[_CallableType], _CallableType]:
+) -> typing.Callable[
+ [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
+]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(
+ func: typing.Callable[_P, typing.Any]
+ ) -> typing.Callable[_P, typing.Any]:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
- async def websocket_wrapper(
- *args: typing.Any, **kwargs: typing.Any
- ) -> None:
+ async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
websocket = kwargs.get(
"websocket", args[idx] if idx < len(args) else None
)
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
- async def async_wrapper(
- *args: typing.Any, **kwargs: typing.Any
- ) -> Response:
+ async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
else:
# Handle sync request/response functions.
@functools.wraps(func)
- def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
+ def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
request = kwargs.get("request", args[idx] if idx < len(args) else None)
assert isinstance(request, Request)
return sync_wrapper
- return decorator # type: ignore[return-value]
+ return decorator
class AuthenticationError(Exception):
import functools
-import sys
import typing
import warnings
-import anyio
-
-if sys.version_info >= (3, 10): # pragma: no cover
- from typing import ParamSpec
-else: # pragma: no cover
- from typing_extensions import ParamSpec
-
+import anyio.to_thread
T = typing.TypeVar("T")
-P = ParamSpec("P")
-async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
+async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501
warnings.warn(
"run_until_first_complete is deprecated "
"and will be removed in a future version.",
async with anyio.create_task_group() as task_group:
- async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(run, functools.partial(func, **kwargs))
+# TODO: We should use `ParamSpec` here, but mypy doesn't support it yet.
+# Check https://github.com/python/mypy/issues/12278 for more details.
async def run_in_threadpool(
- func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
+ func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
import os
import typing
-from collections.abc import MutableMapping
from pathlib import Path
pass
-class Environ(MutableMapping):
- def __init__(self, environ: typing.MutableMapping = os.environ):
+class Environ(typing.MutableMapping[str, str]):
+ def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
self._environ = environ
- self._has_been_read: typing.Set[typing.Any] = set()
+ self._has_been_read: typing.Set[str] = set()
- def __getitem__(self, key: typing.Any) -> typing.Any:
+ def __getitem__(self, key: str) -> str:
self._has_been_read.add(key)
return self._environ.__getitem__(key)
- def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
+ def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
raise EnvironError(
f"Attempting to set environ['{key}'], but the value has already been "
)
self._environ.__setitem__(key, value)
- def __delitem__(self, key: typing.Any) -> None:
+ def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
raise EnvironError(
f"Attempting to delete environ['{key}'], but the value has already "
)
self._environ.__delitem__(key)
- def __iter__(self) -> typing.Iterator:
+ def __iter__(self) -> typing.Iterator[str]:
return iter(self._environ)
def __len__(self) -> int:
def __call__(
self,
key: str,
- cast: typing.Optional[typing.Callable] = None,
+ cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
default: typing.Any = undefined,
) -> typing.Any:
return self.get(key, cast, default)
def get(
self,
key: str,
- cast: typing.Optional[typing.Callable] = None,
+ cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
default: typing.Any = undefined,
) -> typing.Any:
key = self.env_prefix + key
return file_values
def _perform_cast(
- self, key: str, value: typing.Any, cast: typing.Optional[typing.Callable] = None
+ self,
+ key: str,
+ value: typing.Any,
+ cast: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
) -> typing.Any:
if cast is None or value is None:
return value
raise NotImplementedError() # pragma: no cover
-class StringConvertor(Convertor):
+class StringConvertor(Convertor[str]):
regex = "[^/]+"
def convert(self, value: str) -> str:
return value
-class PathConvertor(Convertor):
+class PathConvertor(Convertor[str]):
regex = ".*"
def convert(self, value: str) -> str:
return str(value)
-class IntegerConvertor(Convertor):
+class IntegerConvertor(Convertor[int]):
regex = "[0-9]+"
def convert(self, value: str) -> int:
return str(value)
-class FloatConvertor(Convertor):
+class FloatConvertor(Convertor[float]):
regex = r"[0-9]+(\.[0-9]+)?"
def convert(self, value: str) -> float:
return ("%0.20f" % value).rstrip("0").rstrip(".")
-class UUIDConvertor(Convertor):
+class UUIDConvertor(Convertor[uuid.UUID]):
regex = "[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
def convert(self, value: str) -> uuid.UUID:
return str(value)
-CONVERTOR_TYPES = {
+CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = {
"str": StringConvertor(),
"path": PathConvertor(),
"int": IntegerConvertor(),
}
-def register_url_convertor(key: str, convertor: Convertor) -> None:
+def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None:
CONVERTOR_TYPES[key] = convertor
import typing
-from collections.abc import Sequence
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
return bool(self._value)
-class CommaSeparatedStrings(Sequence):
+class CommaSeparatedStrings(typing.Sequence[str]):
def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
if kwargs:
value = (
ImmutableMultiDict(value).multi_items()
- + ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
+ + ImmutableMultiDict(kwargs).multi_items()
)
if not value:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
- def popitem(self) -> typing.Tuple:
+ def popitem(self) -> typing.Tuple[typing.Any, typing.Any]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value
- def poplist(self, key: typing.Any) -> typing.List:
+ def poplist(self, key: typing.Any) -> typing.List[typing.Any]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
return self[key]
- def setlist(self, key: typing.Any, values: typing.List) -> None:
+ def setlist(self, key: typing.Any, values: typing.List[typing.Any]) -> None:
if not values:
self.pop(key, None)
else:
self,
*args: typing.Union[
"MultiDict",
- typing.Mapping,
+ typing.Mapping[typing.Any, typing.Any],
typing.List[typing.Tuple[typing.Any, typing.Any]],
],
**kwargs: typing.Any,
def __init__(
self,
*args: typing.Union[
- "ImmutableMultiDict",
- typing.Mapping,
+ "ImmutableMultiDict[typing.Any, typing.Any]",
+ typing.Mapping[typing.Any, typing.Any],
typing.List[typing.Tuple[typing.Any, typing.Any]],
str,
bytes,
if getattr(self, method.lower(), None) is not None
]
- def __await__(self) -> typing.Generator:
+ def __await__(self) -> typing.Generator[typing.Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
self.receive = receive
self.send = send
- def __await__(self) -> typing.Generator:
+ def __await__(self) -> typing.Generator[typing.Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
self,
status_code: int,
detail: typing.Optional[str] = None,
- headers: typing.Optional[dict] = None,
+ headers: typing.Optional[typing.Dict[str, str]] = None,
) -> None:
if detail is None:
detail = http.HTTPStatus(status_code).phrase
self._charset = ""
self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: typing.List[MultipartPart] = []
- self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []
+ self._files_to_close_on_error: typing.List[SpooledTemporaryFile[bytes]] = []
def on_part_begin(self) -> None:
self._current_part = MultipartPart()
self.cls = cls
self.options = options
- def __iter__(self) -> typing.Iterator:
+ def __iter__(self) -> typing.Iterator[typing.Any]:
as_tuple = (self.cls, self.options)
return iter(as_tuple)
def __init__(
self,
app: ASGIApp,
- handler: typing.Optional[typing.Callable] = None,
+ handler: typing.Optional[
+ typing.Callable[[Request, Exception], typing.Any]
+ ] = None,
debug: bool = False,
) -> None:
self.app = app
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
- WebSocketException: self.websocket_exception, # type: ignore[dict-item]
+ WebSocketException: self.websocket_exception,
}
if handlers is not None:
for key, value in handlers.items():
)
-def build_environ(scope: Scope, body: bytes) -> dict:
+def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
class WSGIMiddleware:
- def __init__(self, app: typing.Callable) -> None:
+ def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
- def __init__(self, app: typing.Callable, scope: Scope) -> None:
+ def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
self.app = app
self.scope = scope
self.status = None
},
)
- def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
+ def wsgi(
+ self,
+ environ: typing.Dict[str, typing.Any],
+ start_response: typing.Callable[..., typing.Any],
+ ) -> None:
for chunk in self.app(environ, start_response):
anyio.from_thread.run(
self.stream_send.send,
assert (
"session" in self.scope
), "SessionMiddleware must be installed to access request.session"
- return self.scope["session"]
+ return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
@property
def method(self) -> str:
- return self.scope["method"]
+ return typing.cast(str, self.scope["method"])
@property
def receive(self) -> Receive:
self.body = self.render(content)
self.init_headers(headers)
- def render(self, content: typing.Any) -> bytes:
+ def render(self, content: typing.Union[str, bytes, None]) -> bytes:
if content is None:
return b""
if isinstance(content, bytes):
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
-from starlette.responses import PlainTextResponse, RedirectResponse
+from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
return inspect.iscoroutinefunction(obj)
-def request_response(func: typing.Callable) -> ASGIApp:
+def request_response(
+ func: typing.Callable[[Request], typing.Union[typing.Awaitable[Response], Response]]
+) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
- is_coroutine = is_async_callable(func)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
- if is_coroutine:
+ if is_async_callable(func):
response = await func(request)
else:
response = await run_in_threadpool(func, request)
return app
-def websocket_session(func: typing.Callable) -> ASGIApp:
+def websocket_session(
+ func: typing.Callable[[WebSocket], typing.Awaitable[None]]
+) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
"""
return app
-def get_name(endpoint: typing.Callable) -> str:
+def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
return endpoint.__name__
return endpoint.__class__.__name__
def replace_params(
path: str,
- param_convertors: typing.Dict[str, Convertor],
+ param_convertors: typing.Dict[str, Convertor[typing.Any]],
path_params: typing.Dict[str, str],
-) -> typing.Tuple[str, dict]:
+) -> typing.Tuple[str, typing.Dict[str, str]]:
for key, value in list(path_params.items()):
if "{" + key + "}" in path:
convertor = param_convertors[key]
def compile_path(
path: str,
-) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
+) -> typing.Tuple[typing.Pattern[str], str, typing.Dict[str, Convertor[typing.Any]]]:
"""
Given a path string, like: "/{username:str}",
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
def __init__(
self,
path: str,
- endpoint: typing.Callable,
+ endpoint: typing.Callable[..., typing.Any],
*,
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
class WebSocketRoute(BaseRoute):
def __init__(
- self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
+ self,
+ path: str,
+ endpoint: typing.Callable[..., typing.Any],
+ *,
+ name: typing.Optional[str] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
def _wrap_gen_lifespan_context(
- lifespan_context: typing.Callable[[typing.Any], typing.Generator]
-) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
+ lifespan_context: typing.Callable[
+ [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
+ ]
+) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
- def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
+ def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
redirect_slashes: bool = True,
default: typing.Optional[ASGIApp] = None,
- on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
- on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
+ on_startup: typing.Optional[
+ typing.Sequence[typing.Callable[[], typing.Any]]
+ ] = None,
+ on_shutdown: typing.Optional[
+ typing.Sequence[typing.Callable[[], typing.Any]]
+ ] = None,
# the generic to Lifespan[AppType] is the type of the top level application
# which the router cannot know statically, so we use typing.Any
lifespan: typing.Optional[Lifespan[typing.Any]] = None,
)
if lifespan is None:
- self.lifespan_context: Lifespan = _DefaultLifespan(self)
+ self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
- lifespan, # type: ignore[arg-type]
+ lifespan,
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
- lifespan, # type: ignore[arg-type]
+ lifespan,
)
else:
self.lifespan_context = lifespan
def add_route(
self,
path: str,
- endpoint: typing.Callable,
+ endpoint: typing.Callable[
+ [Request], typing.Union[typing.Awaitable[Response], Response]
+ ],
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
self.routes.append(route)
def add_websocket_route(
- self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
+ self,
+ path: str,
+ endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
+ name: typing.Optional[str] = None,
) -> None: # pragma: no cover
route = WebSocketRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
- ) -> typing.Callable:
+ ) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_route(
path,
func,
def websocket_route(
self, path: str, name: typing.Optional[str] = None
- ) -> typing.Callable:
+ ) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_websocket_route(path, func, name=name)
return func
return decorator
def add_event_handler(
- self, event_type: str, func: typing.Callable
+ self, event_type: str, func: typing.Callable[[], typing.Any]
) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
else:
self.on_shutdown.append(func)
- def on_event(self, event_type: str) -> typing.Callable:
+ def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
"The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable:
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
self.add_event_handler(event_type, func)
return func
class EndpointInfo(typing.NamedTuple):
path: str
http_method: str
- func: typing.Callable
+ func: typing.Callable[..., typing.Any]
class BaseSchemaGenerator:
- def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
+ def get_schema(
+ self, routes: typing.List[BaseRoute]
+ ) -> typing.Dict[str, typing.Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(
- func
method ready to extract the docstring
"""
- endpoints_info: list = []
+ endpoints_info: typing.List[EndpointInfo] = []
for route in routes:
if isinstance(route, (Mount, Host)):
"""
return re.sub(r":\w+}", "}", path)
- def parse_docstring(self, func_or_method: typing.Callable) -> dict:
+ def parse_docstring(
+ self, func_or_method: typing.Callable[..., typing.Any]
+ ) -> typing.Dict[str, typing.Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
class SchemaGenerator(BaseSchemaGenerator):
- def __init__(self, base_schema: dict) -> None:
+ def __init__(self, base_schema: typing.Dict[str, typing.Any]) -> None:
self.base_schema = base_schema
- def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
+ def get_schema(
+ self, routes: typing.List[BaseRoute]
+ ) -> typing.Dict[str, typing.Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
- return os.path.normpath(os.path.join(*scope["path"].split("/")))
+ return os.path.normpath(os.path.join(*scope["path"].split("/"))) # type: ignore[no-any-return] # noqa: E501
async def get_response(self, path: str, scope: Scope) -> Response:
"""
def __init__(
self,
template: typing.Any,
- context: dict,
+ context: typing.Dict[str, typing.Any],
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
@typing.overload
def __init__(
self,
- directory: typing.Union[
- str,
- PathLike,
- typing.Sequence[typing.Union[str, PathLike]],
- ],
+ directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]", # noqa: E501
*,
context_processors: typing.Optional[
typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
def __init__(
self,
- directory: typing.Union[
- str, PathLike, typing.Sequence[typing.Union[str, PathLike]], None
- ] = None,
+ directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]], None]" = None, # noqa: E501
*,
context_processors: typing.Optional[
typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
def _create_env(
self,
- directory: typing.Union[
- str, PathLike, typing.Sequence[typing.Union[str, PathLike]]
- ],
+ directory: "typing.Union[str, PathLike[typing.AnyStr], typing.Sequence[typing.Union[str, PathLike[typing.AnyStr]]]]", # noqa: E501
**env_options: typing.Any,
) -> "jinja2.Environment":
@pass_context
- def url_for(context: dict, name: str, /, **path_params: typing.Any) -> URL:
- request = context["request"]
+ def url_for(
+ context: typing.Dict[str, typing.Any],
+ name: str,
+ /,
+ **path_params: typing.Any,
+ ) -> URL:
+ request: Request = context["request"]
return request.url_for(name, **path_params)
loader = jinja2.FileSystemLoader(directory)
self,
request: Request,
name: str,
- context: typing.Optional[dict] = None,
+ context: typing.Optional[typing.Dict[str, typing.Any]] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
def TemplateResponse(
self,
name: str,
- context: typing.Optional[dict] = None,
+ context: typing.Optional[typing.Dict[str, typing.Any]] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
- self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
- self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
+ self._receive_queue: "queue.Queue[Message]" = queue.Queue()
+ self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue()
self.extra_headers = None
def __enter__(self) -> "WebSocketTestSession":
def receive_text(self) -> str:
message = self.receive()
self._raise_on_close(message)
- return message["text"]
+ return typing.cast(str, message["text"])
def receive_bytes(self) -> bytes:
message = self.receive()
self._raise_on_close(message)
- return message["bytes"]
+ return typing.cast(bytes, message["bytes"])
def receive_json(self, mode: str = "text") -> typing.Any:
assert mode in ["text", "binary"]
root_path: str = "",
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
- cookies: httpx._client.CookieTypes = None,
+ cookies: httpx._types.CookieTypes = None,
headers: typing.Dict[str, str] = None,
follow_redirects: bool = True,
) -> None:
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
method,
url,
content=content,
- data=data, # type: ignore[arg-type]
+ data=data,
files=files,
json=json,
params=params,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
return super().post(
url,
content=content,
- data=data, # type: ignore[arg-type]
+ data=data,
files=files,
json=json,
params=params,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
return super().put(
url,
content=content,
- data=data, # type: ignore[arg-type]
+ data=data,
files=files,
json=json,
params=params,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
return super().patch(
url,
content=content,
- data=data, # type: ignore[arg-type]
+ data=data,
files=files,
json=json,
params=params,
follow_redirects: typing.Optional[bool] = None,
allow_redirects: typing.Optional[bool] = None,
timeout: typing.Union[
- httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ httpx._types.TimeoutTypes, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT,
extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> httpx.Response:
import typing
+if typing.TYPE_CHECKING:
+ from starlette.requests import Request
+ from starlette.responses import Response
+ from starlette.websockets import WebSocket
+
AppType = typing.TypeVar("AppType")
Scope = typing.MutableMapping[str, typing.Any]
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
+
+HTTPExceptionHandler = typing.Callable[
+ ["Request", Exception], typing.Union["Response", typing.Awaitable["Response"]]
+]
+WebSocketExceptionHandler = typing.Callable[
+ ["WebSocket", Exception], typing.Awaitable[None]
+]
+ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
)
message = await self.receive()
self._raise_on_disconnect(message)
- return message["text"]
+ return typing.cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
)
message = await self.receive()
self._raise_on_disconnect(message)
- return message["bytes"]
+ return typing.cast(bytes, message["bytes"])
async def receive_json(self, mode: str = "text") -> typing.Any:
if mode not in {"text", "binary"}:
convertors.CONVERTOR_TYPES = convert_types
-class DateTimeConvertor(Convertor):
+class DateTimeConvertor(Convertor[datetime]):
regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(.[0-9]+)?"
def convert(self, value: str) -> datetime:
import pytest
from starlette.applications import Starlette
-from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode
+from starlette.datastructures import UploadFile
+from starlette.formparsers import MultiPartException, _user_safe_decode
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
-class ForceMultipartDict(dict):
+class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]):
def __bool__(self):
return True
async def multi_items_app(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
- output: typing.Dict[str, list] = {}
+ output: typing.Dict[str, typing.List[typing.Any]] = {}
for key, value in data.multi_items():
if key not in output:
output[key] = []
import anyio
import pytest
-from starlette.datastructures import Address
-from starlette.requests import ClientDisconnect, Request, State
+from starlette.datastructures import Address, State
+from starlette.requests import ClientDisconnect, Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.types import Message, Scope
import datetime as dt
import os
import time
+import typing
from http.cookies import SimpleCookie
import anyio
client = test_client_factory(app)
response = client.get("/")
- cookie: SimpleCookie = SimpleCookie(response.headers.get("set-cookie"))
+ cookie: "SimpleCookie[typing.Any]" = SimpleCookie(
+ response.headers.get("set-cookie")
+ )
assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT"
pytest.param(lambda request: ..., "<lambda>", id="lambda"),
],
)
-def test_route_name(endpoint: typing.Callable, expected_name: str):
+def test_route_name(endpoint: typing.Callable[..., typing.Any], expected_name: str):
assert Route(path="/", endpoint=endpoint).name == expected_name