]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Make error handler run always (#761)
authorAviram Hassan <aviramyhassan@gmail.com>
Tue, 1 Feb 2022 14:08:24 +0000 (16:08 +0200)
committerGitHub <noreply@github.com>
Tue, 1 Feb 2022 14:08:24 +0000 (15:08 +0100)
* Error handler call always

* Add tests

* Add docs

* Only run response callable if response didn't start

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
docs/exceptions.md
starlette/middleware/errors.py
tests/middleware/test_errors.py

index bf460d2296110ee2008f36fd13e2deb3ca78b055..9818a20455daa2e39d452d74dc382cc3b6a135bc 100644 (file)
@@ -75,6 +75,24 @@ should bubble through the entire middleware stack as exceptions. Any error
 logging middleware should ensure that it re-raises the exception all the
 way up to the server.
 
+In practical terms, the error handled used is `exception_handler[500]` or `exception_handler[Exception]`.
+Both keys `500` and `Exception` can be used. See below:
+
+```python
+async def handle_error(request: Request, exc: HTTPException):
+    # Perform some logic
+    return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
+
+exception_handlers = {
+    Exception: handle_error  # or "500: handle_error"
+}
+```
+
+It's important to notice that in case a [`BackgroundTask`](https://www.starlette.io/background/) raises an exception,
+it will be handled by the `handle_error` function, but at that point, the response was already sent. In other words,
+the response created by `handle_error` will be discarded. In case the error happens before the response was sent, then
+it will use the response object - in the above example, the returned `JSONResponse`.
+
 In order to deal with this behaviour correctly, the middleware stack of a
 `Starlette` application is configured like this:
 
index 30f5570ca88f2900f04dd1d4845eef70ef4489c7..474c9afc0b62609b70b4fff378c7330b2dc09881 100644 (file)
@@ -158,21 +158,21 @@ class ServerErrorMiddleware:
         try:
             await self.app(scope, receive, _send)
         except Exception as exc:
-            if not response_started:
-                request = Request(scope)
-                if self.debug:
-                    # In debug mode, return traceback responses.
-                    response = self.debug_response(request, exc)
-                elif self.handler is None:
-                    # Use our default 500 error handler.
-                    response = self.error_response(request, exc)
+            request = Request(scope)
+            if self.debug:
+                # In debug mode, return traceback responses.
+                response = self.debug_response(request, exc)
+            elif self.handler is None:
+                # Use our default 500 error handler.
+                response = self.error_response(request, exc)
+            else:
+                # Use an installed 500 error handler.
+                if asyncio.iscoroutinefunction(self.handler):
+                    response = await self.handler(request, exc)
                 else:
-                    # Use an installed 500 error handler.
-                    if asyncio.iscoroutinefunction(self.handler):
-                        response = await self.handler(request, exc)
-                    else:
-                        response = await run_in_threadpool(self.handler, request, exc)
+                    response = await run_in_threadpool(self.handler, request, exc)
 
+            if not response_started:
                 await response(scope, receive, send)
 
             # We always continue to raise the exception.
index 2c926a9b2d6605b1aa0b465a1f0f03bd89ef3574..392c2ba16a6d8074314aa4267c6f4071047b1366 100644 (file)
@@ -1,7 +1,10 @@
 import pytest
 
+from starlette.applications import Starlette
+from starlette.background import BackgroundTask
 from starlette.middleware.errors import ServerErrorMiddleware
 from starlette.responses import JSONResponse, Response
+from starlette.routing import Route
 
 
 def test_handler(test_client_factory):
@@ -68,3 +71,28 @@ def test_debug_not_http(test_client_factory):
         client = test_client_factory(app)
         with client.websocket_connect("/"):
             pass  # pragma: nocover
+
+
+def test_background_task(test_client_factory):
+    accessed_error_handler = False
+
+    def error_handler(request, exc):
+        nonlocal accessed_error_handler
+        accessed_error_handler = True
+
+    def raise_exception():
+        raise Exception("Something went wrong")
+
+    async def endpoint(request):
+        task = BackgroundTask(raise_exception)
+        return Response(status_code=204, background=task)
+
+    app = Starlette(
+        routes=[Route("/", endpoint=endpoint)],
+        exception_handlers={Exception: error_handler},
+    )
+
+    client = test_client_factory(app, raise_server_exceptions=False)
+    response = client.get("/")
+    assert response.status_code == 204
+    assert accessed_error_handler