From: Marcelo Trylesinski Date: Tue, 15 Oct 2024 07:50:49 +0000 (+0200) Subject: Allow to raise `HTTPException` before `websocket.accept()` (#2725) X-Git-Tag: 0.41.0~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=99b6938363c2e5b05a53ef5558c511fed61fbfb4;p=thirdparty%2Fstarlette.git Allow to raise `HTTPException` before `websocket.accept()` (#2725) * Allow to raise `HTTPException` before `websocket.accept()` * move << * Add documentation --- diff --git a/docs/exceptions.md b/docs/exceptions.md index f97f1af8..2a351e77 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -115,14 +115,30 @@ In order to deal with this behaviour correctly, the middleware stack of a ## HTTPException -The `HTTPException` class provides a base class that you can use for any -handled exceptions. The `ExceptionMiddleware` implementation defaults to -returning plain-text HTTP responses for any `HTTPException`. +The `HTTPException` class provides a base class that you can use for any handled exceptions. +The `ExceptionMiddleware` implementation defaults to returning plain-text HTTP responses for any `HTTPException`. * `HTTPException(status_code, detail=None, headers=None)` -You should only raise `HTTPException` inside routing or endpoints. Middleware -classes should instead just return appropriate responses directly. +You should only raise `HTTPException` inside routing or endpoints. +Middleware classes should instead just return appropriate responses directly. + +You can use an `HTTPException` on a WebSocket endpoint in case it's raised before `websocket.accept()`. +The connection is not upgraded to a WebSocket connection, and the proper HTTP response is returned. + +```python +from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket + + +async def websocket_endpoint(websocket: WebSocket): + raise HTTPException(status_code=400, detail="Bad request") + + +app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)]) +``` ## WebSocketException diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 4fbc8639..baf6e2f5 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -6,16 +6,7 @@ from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.types import ( - ASGIApp, - ExceptionHandler, - HTTPExceptionHandler, - Message, - Receive, - Scope, - Send, - WebSocketExceptionHandler, -) +from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send from starlette.websockets import WebSocket ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler] @@ -62,24 +53,13 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG raise exc if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc - - if scope["type"] == "http": - nonlocal conn - handler = typing.cast(HTTPExceptionHandler, handler) - conn = typing.cast(Request, conn) - if is_async_callable(handler): - response = await handler(conn, exc) - else: - response = await run_in_threadpool(handler, conn, exc) + raise RuntimeError("Caught handled exception, but response already started.") from exc + + if is_async_callable(handler): + response = await handler(conn, exc) + else: + response = await run_in_threadpool(handler, conn, exc) # type: ignore + if response is not None: await response(scope, receive, sender) - elif scope["type"] == "websocket": - handler = typing.cast(WebSocketExceptionHandler, handler) - conn = typing.cast(WebSocket, conn) - if is_async_callable(handler): - await handler(conn, exc) - else: - await run_in_threadpool(handler, conn, exc) return wrapped_app diff --git a/starlette/testclient.py b/starlette/testclient.py index 1a2d101a..5143c4c5 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -178,11 +178,7 @@ class WebSocketTestSession: body.append(message["body"]) if not message.get("more_body", False): break - raise WebSocketDenialResponse( - status_code=status_code, - headers=headers, - content=b"".join(body), - ) + raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) def send(self, message: Message) -> None: self._receive_queue.put(message) diff --git a/tests/test_applications.py b/tests/test_applications.py index 86c713c3..05604443 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -3,7 +3,7 @@ from contextlib import asynccontextmanager from pathlib import Path from typing import AsyncGenerator, AsyncIterator, Generator -import anyio +import anyio.from_thread import pytest from starlette import status @@ -17,7 +17,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient +from starlette.testclient import TestClient, WebSocketDenialResponse from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket from tests.types import TestClientFactory @@ -71,11 +71,15 @@ async def websocket_endpoint(session: WebSocket) -> None: await session.close() -async def websocket_raise_websocket(websocket: WebSocket) -> None: +async def websocket_raise_websocket_exception(websocket: WebSocket) -> None: await websocket.accept() raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) +async def websocket_raise_http_exception(websocket: WebSocket) -> None: + raise HTTPException(status_code=401, detail="Unauthorized") + + class CustomWSException(Exception): pass @@ -118,7 +122,8 @@ app = Starlette( Route("/class", endpoint=Homepage), Route("/500", endpoint=runtime_error), WebSocketRoute("/ws", endpoint=websocket_endpoint), - WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket), + WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception), + WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception), WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), Mount("/users", app=users), Host("{subdomain}.example.org", app=subdomain), @@ -219,6 +224,14 @@ def test_websocket_raise_websocket_exception(client: TestClient) -> None: } +def test_websocket_raise_http_exception(client: TestClient) -> None: + with pytest.raises(WebSocketDenialResponse) as exc: + with client.websocket_connect("/ws-raise-http"): + pass # pragma: no cover + assert exc.value.status_code == 401 + assert exc.value.content == b'{"detail":"Unauthorized"}' + + def test_websocket_raise_custom_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-custom") as session: response = session.receive() @@ -243,7 +256,8 @@ def test_routes() -> None: Route("/class", endpoint=Homepage), Route("/500", endpoint=runtime_error, methods=["GET"]), WebSocketRoute("/ws", endpoint=websocket_endpoint), - WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket), + WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception), + WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception), WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), Mount( "/users", diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index eebee375..ffc8883b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -63,11 +63,7 @@ router = Router( Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), - Route( - "/consume_body_in_endpoint_and_handler", - endpoint=read_body_and_raise_exc, - methods=["POST"], - ), + Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), ] ) @@ -114,13 +110,10 @@ def test_websockets_should_raise(client: TestClient) -> None: pass # pragma: no cover -def test_handled_exc_after_response( - test_client_factory: TestClientFactory, - client: TestClient, -) -> None: +def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None: # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."): client.get("/handled_exc_after_response") # If `raise_server_exceptions=False` then the test client will still allow @@ -132,7 +125,7 @@ def test_handled_exc_after_response( def test_force_500_response(test_client_factory: TestClientFactory) -> None: - # use a sentinal variable to make sure we actually + # use a sentinel variable to make sure we actually # make it into the endpoint and don't get a 500 # from an incorrect ASGI app signature or something called = False