]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
test websocket-exception-on-http
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sun, 13 Oct 2024 15:19:17 +0000 (17:19 +0200)
committerMarcelo Trylesinski <marcelotryle@gmail.com>
Sun, 13 Oct 2024 15:19:17 +0000 (17:19 +0200)
main.py [new file with mode: 0644]
starlette/_exception_handler.py
starlette/middleware/exceptions.py
tests/test_applications.py
tests/test_exceptions.py

diff --git a/main.py b/main.py
new file mode 100644 (file)
index 0000000..d5899a8
--- /dev/null
+++ b/main.py
@@ -0,0 +1,13 @@
+from starlette.applications import Starlette
+from starlette.exceptions import HTTPException, WebSocketException
+from starlette.routing import WebSocketRoute
+from starlette.websockets import WebSocket
+
+
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    raise WebSocketException(code=1001)
+    raise HTTPException(400)
+
+
+app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)])
index 4fbc86394d1cea6c1d8d2af9e7edd1afe24f4f3b..72438408d0305e8491c6751a42af29b80341dfed 100644 (file)
@@ -1,21 +1,13 @@
 from __future__ import annotations
 
 import typing
+import warnings
 
 from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
 from starlette.requests import Request
-from starlette.types import (
-    ASGIApp,
-    ExceptionHandler,
-    HTTPExceptionHandler,
-    Message,
-    Receive,
-    Scope,
-    Send,
-    WebSocketExceptionHandler,
-)
+from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
 from starlette.websockets import WebSocket
 
 ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
@@ -30,28 +22,33 @@ def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -
 
 
 def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
-    exception_handlers: ExceptionHandlers
-    status_handlers: StatusHandlers
-    try:
-        exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
-    except KeyError:
-        exception_handlers, status_handlers = {}, {}
-
     async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
         response_started = False
+        websocket_accepted = False
+        print("websocket_accepted", websocket_accepted)
 
         async def sender(message: Message) -> None:
             nonlocal response_started
+            nonlocal websocket_accepted
 
-            if message["type"] == "http.response.start":
+            if message["type"] in ("http.response.start", "websocket.http.response.start"):
                 response_started = True
+            elif message["type"] == "websocket.accept":
+                websocket_accepted = True
+            print(f"message: {message}")
             await send(message)
 
         try:
             await app(scope, receive, sender)
         except Exception as exc:
-            handler = None
+            exception_handlers: ExceptionHandlers
+            status_handlers: StatusHandlers
+            try:
+                exception_handlers, status_handlers = scope["starlette.exception_handlers"]
+            except KeyError:
+                exception_handlers, status_handlers = {}, {}
 
+            handler = None
             if isinstance(exc, HTTPException):
                 handler = status_handlers.get(exc.status_code)
 
@@ -62,24 +59,17 @@ def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASG
                 raise exc
 
             if response_started:
-                msg = "Caught handled exception, but response already started."
-                raise RuntimeError(msg) from exc
-
-            if scope["type"] == "http":
-                nonlocal conn
-                handler = typing.cast(HTTPExceptionHandler, handler)
-                conn = typing.cast(Request, conn)
-                if is_async_callable(handler):
-                    response = await handler(conn, exc)
-                else:
-                    response = await run_in_threadpool(handler, conn, exc)
+                raise RuntimeError("Caught handled exception, but response already started.") from exc
+
+            print("run before the conditional", websocket_accepted)
+            if not websocket_accepted and isinstance(exc, WebSocketException):
+                warnings.warn("WebSocketException used before the websocket connection was accepted.", UserWarning)
+
+            if is_async_callable(handler):
+                response = await handler(conn, exc)
+            else:
+                response = await run_in_threadpool(handler, conn, exc)  # type: ignore
+            if response is not None:
                 await response(scope, receive, sender)
-            elif scope["type"] == "websocket":
-                handler = typing.cast(WebSocketExceptionHandler, handler)
-                conn = typing.cast(WebSocket, conn)
-                if is_async_callable(handler):
-                    await handler(conn, exc)
-                else:
-                    await run_in_threadpool(handler, conn, exc)
 
     return wrapped_app
index d708929e3bae38a055885aa9dd4573b25c87aa3e..139884a38a901bb8e747626b01a327216d14b600 100644 (file)
@@ -2,11 +2,7 @@ from __future__ import annotations
 
 import typing
 
-from starlette._exception_handler import (
-    ExceptionHandlers,
-    StatusHandlers,
-    wrap_app_handling_exceptions,
-)
+from starlette._exception_handler import ExceptionHandlers, StatusHandlers, wrap_app_handling_exceptions
 from starlette.exceptions import HTTPException, WebSocketException
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse, Response
@@ -45,13 +41,9 @@ class ExceptionMiddleware:
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] not in ("http", "websocket"):
-            await self.app(scope, receive, send)
-            return
+            return await self.app(scope, receive, send)
 
-        scope["starlette.exception_handlers"] = (
-            self._exception_handlers,
-            self._status_handlers,
-        )
+        scope["starlette.exception_handlers"] = (self._exception_handlers, self._status_handlers)
 
         conn: Request | WebSocket
         if scope["type"] == "http":
index 86c713c38a64be87a8841b029cf9d0d7c83b66b2..afdcecd3d0b051847c2b23210ab6e29d0ec191af 100644 (file)
@@ -3,7 +3,7 @@ from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import AsyncGenerator, AsyncIterator, Generator
 
-import anyio
+import anyio.from_thread
 import pytest
 
 from starlette import status
@@ -212,21 +212,13 @@ def test_500(test_client_factory: TestClientFactory) -> None:
 def test_websocket_raise_websocket_exception(client: TestClient) -> None:
     with client.websocket_connect("/ws-raise-websocket") as session:
         response = session.receive()
-        assert response == {
-            "type": "websocket.close",
-            "code": status.WS_1003_UNSUPPORTED_DATA,
-            "reason": "",
-        }
+        assert response == {"type": "websocket.close", "code": status.WS_1003_UNSUPPORTED_DATA, "reason": ""}
 
 
 def test_websocket_raise_custom_exception(client: TestClient) -> None:
     with client.websocket_connect("/ws-raise-custom") as session:
         response = session.receive()
-        assert response == {
-            "type": "websocket.close",
-            "code": status.WS_1013_TRY_AGAIN_LATER,
-            "reason": "",
-        }
+        assert response == {"type": "websocket.close", "code": status.WS_1013_TRY_AGAIN_LATER, "reason": ""}
 
 
 def test_middleware(test_client_factory: TestClientFactory) -> None:
@@ -254,10 +246,7 @@ def test_routes() -> None:
                 ]
             ),
         ),
-        Host(
-            "{subdomain}.example.org",
-            app=Router(routes=[Route("/", endpoint=custom_subdomain)]),
-        ),
+        Host("{subdomain}.example.org", app=Router(routes=[Route("/", endpoint=custom_subdomain)])),
     ]
 
 
@@ -266,11 +255,7 @@ def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None
     with open(path, "w") as file:
         file.write("<file content>")
 
-    app = Starlette(
-        routes=[
-            Mount("/static", StaticFiles(directory=tmpdir)),
-        ]
-    )
+    app = Starlette(routes=[Mount("/static", StaticFiles(directory=tmpdir))])
 
     client = test_client_factory(app)
 
@@ -287,11 +272,7 @@ def test_app_debug(test_client_factory: TestClientFactory) -> None:
     async def homepage(request: Request) -> None:
         raise RuntimeError()
 
-    app = Starlette(
-        routes=[
-            Route("/", homepage),
-        ],
-    )
+    app = Starlette(routes=[Route("/", homepage)])
     app.debug = True
 
     client = test_client_factory(app, raise_server_exceptions=False)
@@ -305,11 +286,7 @@ def test_app_add_route(test_client_factory: TestClientFactory) -> None:
     async def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Hello, World!")
 
-    app = Starlette(
-        routes=[
-            Route("/", endpoint=homepage),
-        ]
-    )
+    app = Starlette(routes=[Route("/", endpoint=homepage)])
 
     client = test_client_factory(app)
     response = client.get("/")
@@ -323,11 +300,7 @@ def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None
         await session.send_text("Hello, world!")
         await session.close()
 
-    app = Starlette(
-        routes=[
-            WebSocketRoute("/ws", endpoint=websocket_endpoint),
-        ]
-    )
+    app = Starlette(routes=[WebSocketRoute("/ws", endpoint=websocket_endpoint)])
     client = test_client_factory(app)
 
     with client.websocket_connect("/ws") as session:
@@ -348,10 +321,7 @@ def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
         cleanup_complete = True
 
     with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
-        app = Starlette(
-            on_startup=[run_startup],
-            on_shutdown=[run_cleanup],
-        )
+        app = Starlette(on_startup=[run_startup], on_shutdown=[run_cleanup])
 
     assert not startup_complete
     assert not cleanup_complete
index eebee3759f92f5530fbc8589007da88f5c668cbe..92470988d426520b6b7f32b2a1af14790abe8917 100644 (file)
@@ -63,11 +63,7 @@ router = Router(
         Route("/with_headers", endpoint=with_headers),
         Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
         WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
-        Route(
-            "/consume_body_in_endpoint_and_handler",
-            endpoint=read_body_and_raise_exc,
-            methods=["POST"],
-        ),
+        Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]),
     ]
 )
 
@@ -114,10 +110,7 @@ def test_websockets_should_raise(client: TestClient) -> None:
             pass  # pragma: no cover
 
 
-def test_handled_exc_after_response(
-    test_client_factory: TestClientFactory,
-    client: TestClient,
-) -> None:
+def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None:
     # A 406 HttpException is raised *after* the response has already been sent.
     # The exception middleware should raise a RuntimeError.
     with pytest.raises(RuntimeError):
@@ -132,7 +125,7 @@ def test_handled_exc_after_response(
 
 
 def test_force_500_response(test_client_factory: TestClientFactory) -> None:
-    # use a sentinal variable to make sure we actually
+    # use a sentinel variable to make sure we actually
     # make it into the endpoint and don't get a 500
     # from an incorrect ASGI app signature or something
     called = False