From: Marcelo Trylesinski Date: Sun, 13 Oct 2024 15:19:17 +0000 (+0200) Subject: test X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fwebsocket-exception-on-http;p=thirdparty%2Fstarlette.git test --- diff --git a/main.py b/main.py new file mode 100644 index 00000000..d5899a89 --- /dev/null +++ b/main.py @@ -0,0 +1,13 @@ +from starlette.applications import Starlette +from starlette.exceptions import HTTPException, WebSocketException +from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket + + +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + raise WebSocketException(code=1001) + raise HTTPException(400) + + +app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)]) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 4fbc8639..72438408 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -1,21 +1,13 @@ from __future__ import annotations import typing +import warnings from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool -from starlette.exceptions import HTTPException +from starlette.exceptions import HTTPException, WebSocketException from starlette.requests import Request -from starlette.types import ( - ASGIApp, - ExceptionHandler, - HTTPExceptionHandler, - Message, - Receive, - Scope, - Send, - WebSocketExceptionHandler, -) +from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send from starlette.websockets import WebSocket ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler] @@ -30,28 +22,33 @@ def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) - def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp: - exception_handlers: ExceptionHandlers - status_handlers: StatusHandlers - try: - exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"] - except KeyError: - exception_handlers, status_handlers = {}, {} - async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False + websocket_accepted = False + print("websocket_accepted", websocket_accepted) async def sender(message: Message) -> None: nonlocal response_started + nonlocal websocket_accepted - if message["type"] == "http.response.start": + if message["type"] in ("http.response.start", "websocket.http.response.start"): response_started = True + elif message["type"] == "websocket.accept": + websocket_accepted = True + print(f"message: {message}") await send(message) try: await app(scope, receive, sender) except Exception as exc: - handler = None + exception_handlers: ExceptionHandlers + status_handlers: StatusHandlers + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + handler = None if isinstance(exc, HTTPException): handler = status_handlers.get(exc.status_code) @@ -62,24 +59,17 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG raise exc if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc - - if scope["type"] == "http": - nonlocal conn - handler = typing.cast(HTTPExceptionHandler, handler) - conn = typing.cast(Request, conn) - if is_async_callable(handler): - response = await handler(conn, exc) - else: - response = await run_in_threadpool(handler, conn, exc) + raise RuntimeError("Caught handled exception, but response already started.") from exc + + print("run before the conditional", websocket_accepted) + if not websocket_accepted and isinstance(exc, WebSocketException): + warnings.warn("WebSocketException used before the websocket connection was accepted.", UserWarning) + + if is_async_callable(handler): + response = await handler(conn, exc) + else: + response = await run_in_threadpool(handler, conn, exc) # type: ignore + if response is not None: await response(scope, receive, sender) - elif scope["type"] == "websocket": - handler = typing.cast(WebSocketExceptionHandler, handler) - conn = typing.cast(WebSocket, conn) - if is_async_callable(handler): - await handler(conn, exc) - else: - await run_in_threadpool(handler, conn, exc) return wrapped_app diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index d708929e..139884a3 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -2,11 +2,7 @@ from __future__ import annotations import typing -from starlette._exception_handler import ( - ExceptionHandlers, - StatusHandlers, - wrap_app_handling_exceptions, -) +from starlette._exception_handler import ExceptionHandlers, StatusHandlers, wrap_app_handling_exceptions from starlette.exceptions import HTTPException, WebSocketException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response @@ -45,13 +41,9 @@ class ExceptionMiddleware: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): - await self.app(scope, receive, send) - return + return await self.app(scope, receive, send) - scope["starlette.exception_handlers"] = ( - self._exception_handlers, - self._status_handlers, - ) + scope["starlette.exception_handlers"] = (self._exception_handlers, self._status_handlers) conn: Request | WebSocket if scope["type"] == "http": diff --git a/tests/test_applications.py b/tests/test_applications.py index 86c713c3..afdcecd3 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -3,7 +3,7 @@ from contextlib import asynccontextmanager from pathlib import Path from typing import AsyncGenerator, AsyncIterator, Generator -import anyio +import anyio.from_thread import pytest from starlette import status @@ -212,21 +212,13 @@ def test_500(test_client_factory: TestClientFactory) -> None: def test_websocket_raise_websocket_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-websocket") as session: response = session.receive() - assert response == { - "type": "websocket.close", - "code": status.WS_1003_UNSUPPORTED_DATA, - "reason": "", - } + assert response == {"type": "websocket.close", "code": status.WS_1003_UNSUPPORTED_DATA, "reason": ""} def test_websocket_raise_custom_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-custom") as session: response = session.receive() - assert response == { - "type": "websocket.close", - "code": status.WS_1013_TRY_AGAIN_LATER, - "reason": "", - } + assert response == {"type": "websocket.close", "code": status.WS_1013_TRY_AGAIN_LATER, "reason": ""} def test_middleware(test_client_factory: TestClientFactory) -> None: @@ -254,10 +246,7 @@ def test_routes() -> None: ] ), ), - Host( - "{subdomain}.example.org", - app=Router(routes=[Route("/", endpoint=custom_subdomain)]), - ), + Host("{subdomain}.example.org", app=Router(routes=[Route("/", endpoint=custom_subdomain)])), ] @@ -266,11 +255,7 @@ def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None with open(path, "w") as file: file.write("") - app = Starlette( - routes=[ - Mount("/static", StaticFiles(directory=tmpdir)), - ] - ) + app = Starlette(routes=[Mount("/static", StaticFiles(directory=tmpdir))]) client = test_client_factory(app) @@ -287,11 +272,7 @@ def test_app_debug(test_client_factory: TestClientFactory) -> None: async def homepage(request: Request) -> None: raise RuntimeError() - app = Starlette( - routes=[ - Route("/", homepage), - ], - ) + app = Starlette(routes=[Route("/", homepage)]) app.debug = True client = test_client_factory(app, raise_server_exceptions=False) @@ -305,11 +286,7 @@ def test_app_add_route(test_client_factory: TestClientFactory) -> None: async def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, World!") - app = Starlette( - routes=[ - Route("/", endpoint=homepage), - ] - ) + app = Starlette(routes=[Route("/", endpoint=homepage)]) client = test_client_factory(app) response = client.get("/") @@ -323,11 +300,7 @@ def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None await session.send_text("Hello, world!") await session.close() - app = Starlette( - routes=[ - WebSocketRoute("/ws", endpoint=websocket_endpoint), - ] - ) + app = Starlette(routes=[WebSocketRoute("/ws", endpoint=websocket_endpoint)]) client = test_client_factory(app) with client.websocket_connect("/ws") as session: @@ -348,10 +321,7 @@ def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None: cleanup_complete = True with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): - app = Starlette( - on_startup=[run_startup], - on_shutdown=[run_cleanup], - ) + app = Starlette(on_startup=[run_startup], on_shutdown=[run_cleanup]) assert not startup_complete assert not cleanup_complete diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index eebee375..92470988 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -63,11 +63,7 @@ router = Router( Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), - Route( - "/consume_body_in_endpoint_and_handler", - endpoint=read_body_and_raise_exc, - methods=["POST"], - ), + Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), ] ) @@ -114,10 +110,7 @@ def test_websockets_should_raise(client: TestClient) -> None: pass # pragma: no cover -def test_handled_exc_after_response( - test_client_factory: TestClientFactory, - client: TestClient, -) -> None: +def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None: # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. with pytest.raises(RuntimeError): @@ -132,7 +125,7 @@ def test_handled_exc_after_response( def test_force_500_response(test_client_factory: TestClientFactory) -> None: - # use a sentinal variable to make sure we actually + # use a sentinel variable to make sure we actually # make it into the endpoint and don't get a 500 # from an incorrect ASGI app signature or something called = False