]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Allow to raise `HTTPException` before `websocket.accept()` (#2725)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Tue, 15 Oct 2024 07:50:49 +0000 (09:50 +0200)
committerGitHub <noreply@github.com>
Tue, 15 Oct 2024 07:50:49 +0000 (09:50 +0200)
* Allow to raise `HTTPException` before `websocket.accept()`

* move <<

* Add documentation

docs/exceptions.md
starlette/_exception_handler.py
starlette/testclient.py
tests/test_applications.py
tests/test_exceptions.py

index f97f1af8937bb19c8304734e1b498797e56963fe..2a351e7711ee6ef60f2b469fba47ee2d92a56ed8 100644 (file)
@@ -115,14 +115,30 @@ In order to deal with this behaviour correctly, the middleware stack of a
 
 ## HTTPException
 
-The `HTTPException` class provides a base class that you can use for any
-handled exceptions. The `ExceptionMiddleware` implementation defaults to
-returning plain-text HTTP responses for any `HTTPException`.
+The `HTTPException` class provides a base class that you can use for any handled exceptions.
+The `ExceptionMiddleware` implementation defaults to returning plain-text HTTP responses for any `HTTPException`.
 
 * `HTTPException(status_code, detail=None, headers=None)`
 
-You should only raise `HTTPException` inside routing or endpoints. Middleware
-classes should instead just return appropriate responses directly.
+You should only raise `HTTPException` inside routing or endpoints.
+Middleware classes should instead just return appropriate responses directly.
+
+You can use an `HTTPException` on a WebSocket endpoint in case it's raised before `websocket.accept()`.
+The connection is not upgraded to a WebSocket connection, and the proper HTTP response is returned.
+
+```python
+from starlette.applications import Starlette
+from starlette.exceptions import HTTPException
+from starlette.routing import WebSocketRoute
+from starlette.websockets import WebSocket
+
+
+async def websocket_endpoint(websocket: WebSocket):
+    raise HTTPException(status_code=400, detail="Bad request")
+
+
+app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)])
+```
 
 ## WebSocketException
 
index 4fbc86394d1cea6c1d8d2af9e7edd1afe24f4f3b..baf6e2f50e8e3e518ab1901e4a6a5c117246aba3 100644 (file)
@@ -6,16 +6,7 @@ from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
-from starlette.types import (
-    ASGIApp,
-    ExceptionHandler,
-    HTTPExceptionHandler,
-    Message,
-    Receive,
-    Scope,
-    Send,
-    WebSocketExceptionHandler,
-)
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
 from starlette.websockets import WebSocket
 
 ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
@@ -62,24 +53,13 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG
                 raise exc
 
             if response_started:
-                msg = "Caught handled exception, but response already started."
-                raise RuntimeError(msg) from exc
-
-            if scope["type"] == "http":
-                nonlocal conn
-                handler = typing.cast(HTTPExceptionHandler, handler)
-                conn = typing.cast(Request, conn)
-                if is_async_callable(handler):
-                    response = await handler(conn, exc)
-                else:
-                    response = await run_in_threadpool(handler, conn, exc)
+                raise RuntimeError("Caught handled exception, but response already started.") from exc
+
+            if is_async_callable(handler):
+                response = await handler(conn, exc)
+            else:
+                response = await run_in_threadpool(handler, conn, exc)  # type: ignore
+            if response is not None:
                 await response(scope, receive, sender)
-            elif scope["type"] == "websocket":
-                handler = typing.cast(WebSocketExceptionHandler, handler)
-                conn = typing.cast(WebSocket, conn)
-                if is_async_callable(handler):
-                    await handler(conn, exc)
-                else:
-                    await run_in_threadpool(handler, conn, exc)
 
     return wrapped_app
index 1a2d101a01af937c01af8db1b4e8e6bf2e2cda58..5143c4c57f18fb7365a798362c4dff1ca9f83e6e 100644 (file)
@@ -178,11 +178,7 @@ class WebSocketTestSession:
                 body.append(message["body"])
                 if not message.get("more_body", False):
                     break
-            raise WebSocketDenialResponse(
-                status_code=status_code,
-                headers=headers,
-                content=b"".join(body),
-            )
+            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
 
     def send(self, message: Message) -> None:
         self._receive_queue.put(message)
index 86c713c38a64be87a8841b029cf9d0d7c83b66b2..0560444383ab81fb33d20554b00cfc965ae0800a 100644 (file)
@@ -3,7 +3,7 @@ from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import AsyncGenerator, AsyncIterator, Generator
 
-import anyio
+import anyio.from_thread
 import pytest
 
 from starlette import status
@@ -17,7 +17,7 @@ from starlette.requests import Request
 from starlette.responses import JSONResponse, PlainTextResponse
 from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
 from starlette.staticfiles import StaticFiles
-from starlette.testclient import TestClient
+from starlette.testclient import TestClient, WebSocketDenialResponse
 from starlette.types import ASGIApp, Receive, Scope, Send
 from starlette.websockets import WebSocket
 from tests.types import TestClientFactory
@@ -71,11 +71,15 @@ async def websocket_endpoint(session: WebSocket) -> None:
     await session.close()
 
 
-async def websocket_raise_websocket(websocket: WebSocket) -> None:
+async def websocket_raise_websocket_exception(websocket: WebSocket) -> None:
     await websocket.accept()
     raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
 
 
+async def websocket_raise_http_exception(websocket: WebSocket) -> None:
+    raise HTTPException(status_code=401, detail="Unauthorized")
+
+
 class CustomWSException(Exception):
     pass
 
@@ -118,7 +122,8 @@ app = Starlette(
         Route("/class", endpoint=Homepage),
         Route("/500", endpoint=runtime_error),
         WebSocketRoute("/ws", endpoint=websocket_endpoint),
-        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
+        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
+        WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
         WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
         Mount("/users", app=users),
         Host("{subdomain}.example.org", app=subdomain),
@@ -219,6 +224,14 @@ def test_websocket_raise_websocket_exception(client: TestClient) -> None:
         }
 
 
+def test_websocket_raise_http_exception(client: TestClient) -> None:
+    with pytest.raises(WebSocketDenialResponse) as exc:
+        with client.websocket_connect("/ws-raise-http"):
+            pass  # pragma: no cover
+    assert exc.value.status_code == 401
+    assert exc.value.content == b'{"detail":"Unauthorized"}'
+
+
 def test_websocket_raise_custom_exception(client: TestClient) -> None:
     with client.websocket_connect("/ws-raise-custom") as session:
         response = session.receive()
@@ -243,7 +256,8 @@ def test_routes() -> None:
         Route("/class", endpoint=Homepage),
         Route("/500", endpoint=runtime_error, methods=["GET"]),
         WebSocketRoute("/ws", endpoint=websocket_endpoint),
-        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
+        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
+        WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
         WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
         Mount(
             "/users",
index eebee3759f92f5530fbc8589007da88f5c668cbe..ffc8883b3c721a1f7f27819fa7d758a63ce92cba 100644 (file)
@@ -63,11 +63,7 @@ router = Router(
         Route("/with_headers", endpoint=with_headers),
         Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
         WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
-        Route(
-            "/consume_body_in_endpoint_and_handler",
-            endpoint=read_body_and_raise_exc,
-            methods=["POST"],
-        ),
+        Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]),
     ]
 )
 
@@ -114,13 +110,10 @@ def test_websockets_should_raise(client: TestClient) -> None:
             pass  # pragma: no cover
 
 
-def test_handled_exc_after_response(
-    test_client_factory: TestClientFactory,
-    client: TestClient,
-) -> None:
+def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None:
     # A 406 HttpException is raised *after* the response has already been sent.
     # The exception middleware should raise a RuntimeError.
-    with pytest.raises(RuntimeError):
+    with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."):
         client.get("/handled_exc_after_response")
 
     # If `raise_server_exceptions=False` then the test client will still allow
@@ -132,7 +125,7 @@ def test_handled_exc_after_response(
 
 
 def test_force_500_response(test_client_factory: TestClientFactory) -> None:
-    # use a sentinal variable to make sure we actually
+    # use a sentinel variable to make sure we actually
     # make it into the endpoint and don't get a 500
     # from an incorrect ASGI app signature or something
     called = False