]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add `WebSocketException` and support for WS handlers (#1263)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Mon, 5 Sep 2022 12:31:59 +0000 (14:31 +0200)
committerGitHub <noreply@github.com>
Mon, 5 Sep 2022 12:31:59 +0000 (14:31 +0200)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Co-authored-by: Tom Christie <tom@tomchristie.com>
docs/exceptions.md
starlette/exceptions.py
starlette/middleware/exceptions.py
tests/test_applications.py
tests/test_exceptions.py

index 9818a20455daa2e39d452d74dc382cc3b6a135bc..f97f1af8937bb19c8304734e1b498797e56963fe 100644 (file)
@@ -62,6 +62,17 @@ async def http_exception(request: Request, exc: HTTPException):
     )
 ```
 
+You might also want to override how `WebSocketException` is handled:
+
+```python
+async def websocket_exception(websocket: WebSocket, exc: WebSocketException):
+    await websocket.close(code=1008)
+
+exception_handlers = {
+    WebSocketException: websocket_exception
+}
+```
+
 ## Errors and handled exceptions
 
 It is important to differentiate between handled exceptions and errors.
@@ -112,3 +123,11 @@ returning plain-text HTTP responses for any `HTTPException`.
 
 You should only raise `HTTPException` inside routing or endpoints. Middleware
 classes should instead just return appropriate responses directly.
+
+## WebSocketException
+
+You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.
+
+* `WebSocketException(code=1008, reason=None)`
+
+You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
index 2b5acddb53df956a20b21325a45c16ebd653f505..87da735915de4752dc40346b11d2b25394a5bb77 100644 (file)
@@ -2,7 +2,7 @@ import http
 import typing
 import warnings
 
-__all__ = ("HTTPException",)
+__all__ = ("HTTPException", "WebSocketException")
 
 
 class HTTPException(Exception):
@@ -23,6 +23,16 @@ class HTTPException(Exception):
         return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
 
 
+class WebSocketException(Exception):
+    def __init__(self, code: int, reason: typing.Optional[str] = None) -> None:
+        self.code = code
+        self.reason = reason or ""
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
+
+
 __deprecated__ = "ExceptionMiddleware"
 
 
index 42fd41ae2fd528aeffd675bdc062483f8e0a0e28..cd729417043e2267ba7e5efaa345d66531388070 100644 (file)
@@ -2,10 +2,11 @@ import typing
 
 from starlette._utils import is_async_callable
 from starlette.concurrency import run_in_threadpool
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse, Response
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.websockets import WebSocket
 
 
 class ExceptionMiddleware:
@@ -22,7 +23,10 @@ class ExceptionMiddleware:
         self._status_handlers: typing.Dict[int, typing.Callable] = {}
         self._exception_handlers: typing.Dict[
             typing.Type[Exception], typing.Callable
-        ] = {HTTPException: self.http_exception}
+        ] = {
+            HTTPException: self.http_exception,
+            WebSocketException: self.websocket_exception,
+        }
         if handlers is not None:
             for key, value in handlers.items():
                 self.add_exception_handler(key, value)
@@ -47,7 +51,7 @@ class ExceptionMiddleware:
         return None
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        if scope["type"] != "http":
+        if scope["type"] not in ("http", "websocket"):
             await self.app(scope, receive, send)
             return
 
@@ -78,12 +82,19 @@ class ExceptionMiddleware:
                 msg = "Caught handled exception, but response already started."
                 raise RuntimeError(msg) from exc
 
-            request = Request(scope, receive=receive)
-            if is_async_callable(handler):
-                response = await handler(request, exc)
-            else:
-                response = await run_in_threadpool(handler, request, exc)
-            await response(scope, receive, sender)
+            if scope["type"] == "http":
+                request = Request(scope, receive=receive)
+                if is_async_callable(handler):
+                    response = await handler(request, exc)
+                else:
+                    response = await run_in_threadpool(handler, request, exc)
+                await response(scope, receive, sender)
+            elif scope["type"] == "websocket":
+                websocket = WebSocket(scope, receive=receive, send=send)
+                if is_async_callable(handler):
+                    await handler(websocket, exc)
+                else:
+                    await run_in_threadpool(handler, websocket, exc)
 
     def http_exception(self, request: Request, exc: HTTPException) -> Response:
         if exc.status_code in {204, 304}:
@@ -91,3 +102,8 @@ class ExceptionMiddleware:
         return PlainTextResponse(
             exc.detail, status_code=exc.status_code, headers=exc.headers
         )
+
+    async def websocket_exception(
+        self, websocket: WebSocket, exc: WebSocketException
+    ) -> None:
+        await websocket.close(code=exc.code, reason=exc.reason)
index 0d0ede571b2b86e8939e04200fd37b6f327f9eb0..2cee601b0a979693edd8c44868ed21cb42ae1f86 100644 (file)
@@ -1,16 +1,19 @@
 import os
 from contextlib import asynccontextmanager
 
+import anyio
 import pytest
 
+from starlette import status
 from starlette.applications import Starlette
 from starlette.endpoints import HTTPEndpoint
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
 from starlette.middleware import Middleware
 from starlette.middleware.trustedhost import TrustedHostMiddleware
 from starlette.responses import JSONResponse, PlainTextResponse
 from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
 from starlette.staticfiles import StaticFiles
+from starlette.websockets import WebSocket
 
 
 async def error_500(request, exc):
@@ -61,6 +64,24 @@ async def websocket_endpoint(session):
     await session.close()
 
 
+async def websocket_raise_websocket(websocket: WebSocket):
+    await websocket.accept()
+    raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
+
+
+class CustomWSException(Exception):
+    pass
+
+
+async def websocket_raise_custom(websocket: WebSocket):
+    await websocket.accept()
+    raise CustomWSException()
+
+
+def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
+    anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
+
+
 users = Router(
     routes=[
         Route("/", endpoint=all_users_page),
@@ -78,6 +99,7 @@ exception_handlers = {
     500: error_500,
     405: method_not_allowed,
     HTTPException: http_exception,
+    CustomWSException: custom_ws_exception_handler,
 }
 
 middleware = [
@@ -91,6 +113,8 @@ app = Starlette(
         Route("/class", endpoint=Homepage),
         Route("/500", endpoint=runtime_error),
         WebSocketRoute("/ws", endpoint=websocket_endpoint),
+        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
+        WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
         Mount("/users", app=users),
         Host("{subdomain}.example.org", app=subdomain),
     ],
@@ -180,6 +204,26 @@ def test_500(test_client_factory):
     assert response.json() == {"detail": "Server Error"}
 
 
+def test_websocket_raise_websocket_exception(client):
+    with client.websocket_connect("/ws-raise-websocket") as session:
+        response = session.receive()
+        assert response == {
+            "type": "websocket.close",
+            "code": status.WS_1003_UNSUPPORTED_DATA,
+            "reason": "",
+        }
+
+
+def test_websocket_raise_custom_exception(client):
+    with client.websocket_connect("/ws-raise-custom") as session:
+        response = session.receive()
+        assert response == {
+            "type": "websocket.close",
+            "code": status.WS_1013_TRY_AGAIN_LATER,
+            "reason": "",
+        }
+
+
 def test_middleware(test_client_factory):
     client = test_client_factory(app, base_url="http://incorrecthost")
     response = client.get("/func")
@@ -194,6 +238,8 @@ def test_routes():
         Route("/class", endpoint=Homepage),
         Route("/500", endpoint=runtime_error, methods=["GET"]),
         WebSocketRoute("/ws", endpoint=websocket_endpoint),
+        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
+        WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
         Mount(
             "/users",
             app=Router(
index 9acd421540e6f48ba69c06752f3b3ffbba54d196..05583a430b602f1fb8089cf38a368103c2a4d511 100644 (file)
@@ -2,7 +2,7 @@ import warnings
 
 import pytest
 
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
 from starlette.middleware.exceptions import ExceptionMiddleware
 from starlette.responses import PlainTextResponse
 from starlette.routing import Route, Router, WebSocketRoute
@@ -119,7 +119,7 @@ def test_force_500_response(test_client_factory):
     assert response.text == ""
 
 
-def test_repr():
+def test_http_repr():
     assert repr(HTTPException(404)) == (
         "HTTPException(status_code=404, detail='Not Found')"
     )
@@ -135,6 +135,20 @@ def test_repr():
     )
 
 
+def test_websocket_repr():
+    assert repr(WebSocketException(1008, reason="Policy Violation")) == (
+        "WebSocketException(code=1008, reason='Policy Violation')"
+    )
+
+    class CustomWebSocketException(WebSocketException):
+        pass
+
+    assert (
+        repr(CustomWebSocketException(1013, reason="Something custom"))
+        == "CustomWebSocketException(code=1013, reason='Something custom')"
+    )
+
+
 def test_exception_middleware_deprecation() -> None:
     # this test should be removed once the deprecation shim is removed
     with pytest.warns(DeprecationWarning):