]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Allow async exception handlers to type-check (#2949)
authorJonathan Berthias <jvberthias@gmail.com>
Fri, 13 Jun 2025 12:39:24 +0000 (14:39 +0200)
committerGitHub <noreply@github.com>
Fri, 13 Jun 2025 12:39:24 +0000 (14:39 +0200)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
pyproject.toml
starlette/_exception_handler.py
starlette/_utils.py
starlette/applications.py
starlette/middleware/errors.py
starlette/middleware/exceptions.py
starlette/routing.py
tests/test_exceptions.py

index e86f49d123967a9bff7d92e4026b9484f78e087c..56c9900f452595dfe3df6866292a532ccc74526d 100644 (file)
@@ -27,7 +27,7 @@ classifiers = [
 ]
 dependencies = [
     "anyio>=3.6.2,<5",
-    "typing_extensions>=3.10.0; python_version < '3.10'",
+    "typing_extensions>=4.10.0; python_version < '3.13'",
 ]
 
 [project.optional-dependencies]
index bcb96c9f20f4e6d0a993c3fd93ae4b9f58f02c8e..40761b68c98cb40a3d24ff1e6067f3a653800b94 100644 (file)
@@ -58,7 +58,7 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG
             if is_async_callable(handler):
                 response = await handler(conn, exc)
             else:
-                response = await run_in_threadpool(handler, conn, exc)  # type: ignore
+                response = await run_in_threadpool(handler, conn, exc)
             if response is not None:
                 await response(scope, receive, sender)
 
index a35ca82d9cfcf891fd67b39f11b3c0e2f00fc8c9..5ac985de1dd1c49f0cdaf931de15da102162ebea 100644 (file)
@@ -9,10 +9,10 @@ from typing import Any, Callable, Generic, Protocol, TypeVar, overload
 
 from starlette.types import Scope
 
-if sys.version_info >= (3, 10):  # pragma: no cover
-    from typing import TypeGuard
+if sys.version_info >= (3, 13):  # pragma: no cover
+    from typing import TypeIs
 else:  # pragma: no cover
-    from typing_extensions import TypeGuard
+    from typing_extensions import TypeIs
 
 has_exceptiongroups = True
 if sys.version_info < (3, 11):  # pragma: no cover
@@ -26,11 +26,11 @@ AwaitableCallable = Callable[..., Awaitable[T]]
 
 
 @overload
-def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
+def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
 
 
 @overload
-def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
+def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
 
 
 def is_async_callable(obj: Any) -> Any:
index 32f6e560a425e4acb9db351ef3f87a4f761190e5..62c0eb165eef282b6603f6ef17839214a8c531d8 100644 (file)
@@ -80,7 +80,7 @@ class Starlette:
     def build_middleware_stack(self) -> ASGIApp:
         debug = self.debug
         error_handler = None
-        exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {}
+        exception_handlers: dict[Any, ExceptionHandler] = {}
 
         for key, value in self.exception_handlers.items():
             if key in (500, Exception):
index 60b96a5c736504f52e85d2985933e850bab0e471..bbc10728af8f809c12e39d5016585922f66665cc 100644 (file)
@@ -4,13 +4,12 @@ import html
 import inspect
 import sys
 import traceback
-from typing import Any, Callable
 
 from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.requests import Request
 from starlette.responses import HTMLResponse, PlainTextResponse, Response
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
 
 STYLES = """
 p {
@@ -140,7 +139,7 @@ class ServerErrorMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        handler: Callable[[Request, Exception], Any] | None = None,
+        handler: ExceptionHandler | None = None,
         debug: bool = False,
     ) -> None:
         self.app = app
index 864c2238d9a7532b670b76ab2ce3229cafdcaeb3..6c98558f4db50007dd3ddee1892454a5a73f6abd 100644 (file)
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Mapping
-from typing import Any, Callable
+from typing import Any
 
 from starlette._exception_handler import (
     ExceptionHandlers,
@@ -11,7 +11,7 @@ from starlette._exception_handler import (
 from starlette.exceptions import HTTPException, WebSocketException
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse, Response
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
 from starlette.websockets import WebSocket
 
 
@@ -19,7 +19,7 @@ class ExceptionMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None,
+        handlers: Mapping[Any, ExceptionHandler] | None = None,
         debug: bool = False,
     ) -> None:
         self.app = app
@@ -36,7 +36,7 @@ class ExceptionMiddleware:
     def add_exception_handler(
         self,
         exc_class_or_status_code: int | type[Exception],
-        handler: Callable[[Request, Exception], Response],
+        handler: ExceptionHandler,
     ) -> None:
         if isinstance(exc_class_or_status_code, int):
             self._status_handlers[exc_class_or_status_code] = handler
index db7921e1f154214852d272da765055bdac997d45..96c2df929fcfc6431a716f43de76b4d029737ccb 100644 (file)
@@ -65,7 +65,7 @@ def request_response(
     and returns an ASGI application.
     """
     f: Callable[[Request], Awaitable[Response]] = (
-        func if is_async_callable(func) else functools.partial(run_in_threadpool, func)  # type:ignore
+        func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
     )
 
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
index 3b48f81444e87caf8231e1f342b909a4f1d6f746..d554ac045741f004d5a1bcc1376dcc015bebbb37 100644 (file)
@@ -205,3 +205,20 @@ def test_http_exception_does_not_use_threadpool(client: TestClient, monkeypatch:
     # This should succeed because http_exception is async and won't use run_in_threadpool
     response = client.get("/not_acceptable")
     assert response.status_code == 406
+
+
+def test_handlers_annotations() -> None:
+    """Check that async exception handlers are accepted by type checkers.
+
+    We annotate the handlers' exceptions with plain `Exception` to avoid variance issues
+    when using other exception types.
+    """
+
+    async def async_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
+        raise NotImplementedError
+
+    def sync_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
+        raise NotImplementedError
+
+    ExceptionMiddleware(router, handlers={Exception: sync_catch_all_handler})
+    ExceptionMiddleware(router, handlers={Exception: async_catch_all_handler})