]
dependencies = [
"anyio>=3.6.2,<5",
- "typing_extensions>=3.10.0; python_version < '3.10'",
+ "typing_extensions>=4.10.0; python_version < '3.13'",
]
[project.optional-dependencies]
if is_async_callable(handler):
response = await handler(conn, exc)
else:
- response = await run_in_threadpool(handler, conn, exc) # type: ignore
+ response = await run_in_threadpool(handler, conn, exc)
if response is not None:
await response(scope, receive, sender)
from starlette.types import Scope
-if sys.version_info >= (3, 10): # pragma: no cover
- from typing import TypeGuard
+if sys.version_info >= (3, 13): # pragma: no cover
+ from typing import TypeIs
else: # pragma: no cover
- from typing_extensions import TypeGuard
+ from typing_extensions import TypeIs
has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
@overload
-def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
+def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
@overload
-def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
+def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
def is_async_callable(obj: Any) -> Any:
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
- exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {}
+ exception_handlers: dict[Any, ExceptionHandler] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
import inspect
import sys
import traceback
-from typing import Any, Callable
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
STYLES = """
p {
def __init__(
self,
app: ASGIApp,
- handler: Callable[[Request, Exception], Any] | None = None,
+ handler: ExceptionHandler | None = None,
debug: bool = False,
) -> None:
self.app = app
from __future__ import annotations
from collections.abc import Mapping
-from typing import Any, Callable
+from typing import Any
from starlette._exception_handler import (
ExceptionHandlers,
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
from starlette.websockets import WebSocket
def __init__(
self,
app: ASGIApp,
- handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None,
+ handlers: Mapping[Any, ExceptionHandler] | None = None,
debug: bool = False,
) -> None:
self.app = app
def add_exception_handler(
self,
exc_class_or_status_code: int | type[Exception],
- handler: Callable[[Request, Exception], Response],
+ handler: ExceptionHandler,
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
and returns an ASGI application.
"""
f: Callable[[Request], Awaitable[Response]] = (
- func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore
+ func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# This should succeed because http_exception is async and won't use run_in_threadpool
response = client.get("/not_acceptable")
assert response.status_code == 406
+
+
+def test_handlers_annotations() -> None:
+ """Check that async exception handlers are accepted by type checkers.
+
+ We annotate the handlers' exceptions with plain `Exception` to avoid variance issues
+ when using other exception types.
+ """
+
+ async def async_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
+ raise NotImplementedError
+
+ def sync_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
+ raise NotImplementedError
+
+ ExceptionMiddleware(router, handlers={Exception: sync_catch_all_handler})
+ ExceptionMiddleware(router, handlers={Exception: async_catch_all_handler})