]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add `headers` parameter to `HTTPException` (#1435)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Wed, 26 Jan 2022 15:06:15 +0000 (16:06 +0100)
committerGitHub <noreply@github.com>
Wed, 26 Jan 2022 15:06:15 +0000 (16:06 +0100)
* Add `headers` parameter to `HTTPException`

* Update exceptions docs

docs/exceptions.md
starlette/exceptions.py
tests/test_exceptions.py

index d083a2da792d734b59806b9adb055849ef0e118a..bf460d2296110ee2008f36fd13e2deb3ca78b055 100644 (file)
@@ -4,6 +4,8 @@ how you return responses when errors or handled exceptions occur.
 
 ```python
 from starlette.applications import Starlette
+from starlette.exceptions import HTTPException
+from starlette.requests import Request
 from starlette.responses import HTMLResponse
 
 
@@ -11,10 +13,10 @@ HTML_404_PAGE = ...
 HTML_500_PAGE = ...
 
 
-async def not_found(request, exc):
+async def not_found(request: Request, exc: HTTPException):
     return HTMLResponse(content=HTML_404_PAGE, status_code=exc.status_code)
 
-async def server_error(request, exc):
+async def server_error(request: Request, exc: HTTPException):
     return HTMLResponse(content=HTML_500_PAGE, status_code=exc.status_code)
 
 
@@ -40,7 +42,7 @@ In particular you might want to override how the built-in `HTTPException` class
 is handled. For example, to use JSON style responses:
 
 ```python
-async def http_exception(request, exc):
+async def http_exception(request: Request, exc: HTTPException):
     return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
 
 exception_handlers = {
@@ -48,6 +50,18 @@ exception_handlers = {
 }
 ```
 
+The `HTTPException` is also equipped with the `headers` argument. Which allows the propagation
+of the headers to the response class:
+
+```python
+async def http_exception(request: Request, exc: HTTPException):
+    return JSONResponse(
+        {"detail": exc.detail},
+        status_code=exc.status_code,
+        headers=exc.headers
+    )
+```
+
 ## Errors and handled exceptions
 
 It is important to differentiate between handled exceptions and errors.
@@ -76,7 +90,7 @@ The `HTTPException` class provides a base class that you can use for any
 handled exceptions. The `ExceptionMiddleware` implementation defaults to
 returning plain-text HTTP responses for any `HTTPException`.
 
-* `HTTPException(status_code, detail=None)`
+* `HTTPException(status_code, detail=None, headers=None)`
 
 You should only raise `HTTPException` inside routing or endpoints. Middleware
 classes should instead just return appropriate responses directly.
index fcb0776cc76fc25cde5d1331c3dd02b61ab27a41..8f28b6e2df88c87e4cd5beed1b11812dddf022c9 100644 (file)
@@ -9,11 +9,14 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 
 class HTTPException(Exception):
-    def __init__(self, status_code: int, detail: str = None) -> None:
+    def __init__(
+        self, status_code: int, detail: str = None, headers: dict = None
+    ) -> None:
         if detail is None:
             detail = http.HTTPStatus(status_code).phrase
         self.status_code = status_code
         self.detail = detail
+        self.headers = headers
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
@@ -99,5 +102,7 @@ class ExceptionMiddleware:
 
     def http_exception(self, request: Request, exc: HTTPException) -> Response:
         if exc.status_code in {204, 304}:
-            return Response(status_code=exc.status_code)
-        return PlainTextResponse(exc.detail, status_code=exc.status_code)
+            return Response(status_code=exc.status_code, headers=exc.headers)
+        return PlainTextResponse(
+            exc.detail, status_code=exc.status_code, headers=exc.headers
+        )
index cef03359f1fdf8d7e2bf081032d539f8a7a9ee27..80307a521a2ab37a7238565fc79293606118ee97 100644 (file)
@@ -21,6 +21,10 @@ def not_modified(request):
     raise HTTPException(status_code=304)
 
 
+def with_headers(request):
+    raise HTTPException(status_code=200, headers={"x-potato": "always"})
+
+
 class HandledExcAfterResponse:
     async def __call__(self, scope, receive, send):
         response = PlainTextResponse("OK", status_code=200)
@@ -34,6 +38,7 @@ router = Router(
         Route("/not_acceptable", endpoint=not_acceptable),
         Route("/no_content", endpoint=no_content),
         Route("/not_modified", endpoint=not_modified),
+        Route("/with_headers", endpoint=with_headers),
         Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
         WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
     ]
@@ -67,6 +72,12 @@ def test_not_modified(client):
     assert response.text == ""
 
 
+def test_with_headers(client):
+    response = client.get("/with_headers")
+    assert response.status_code == 200
+    assert response.headers["x-potato"] == "always"
+
+
 def test_websockets_should_raise(client):
     with pytest.raises(RuntimeError):
         with client.websocket_connect("/runtime_error"):