)
```
+You might also want to override how `WebSocketException` is handled:
+
+```python
+async def websocket_exception(websocket: WebSocket, exc: WebSocketException):
+ await websocket.close(code=1008)
+
+exception_handlers = {
+ WebSocketException: websocket_exception
+}
+```
+
## Errors and handled exceptions
It is important to differentiate between handled exceptions and errors.
You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.
+
+## WebSocketException
+
+You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.
+
+* `WebSocketException(code=1008, reason=None)`
+
+You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
import typing
import warnings
-__all__ = ("HTTPException",)
+__all__ = ("HTTPException", "WebSocketException")
class HTTPException(Exception):
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"
+class WebSocketException(Exception):
+ def __init__(self, code: int, reason: typing.Optional[str] = None) -> None:
+ self.code = code
+ self.reason = reason or ""
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
+
+
__deprecated__ = "ExceptionMiddleware"
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.websockets import WebSocket
class ExceptionMiddleware:
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
- ] = {HTTPException: self.http_exception}
+ ] = {
+ HTTPException: self.http_exception,
+ WebSocketException: self.websocket_exception,
+ }
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)
return None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if scope["type"] != "http":
+ if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
- request = Request(scope, receive=receive)
- if is_async_callable(handler):
- response = await handler(request, exc)
- else:
- response = await run_in_threadpool(handler, request, exc)
- await response(scope, receive, sender)
+ if scope["type"] == "http":
+ request = Request(scope, receive=receive)
+ if is_async_callable(handler):
+ response = await handler(request, exc)
+ else:
+ response = await run_in_threadpool(handler, request, exc)
+ await response(scope, receive, sender)
+ elif scope["type"] == "websocket":
+ websocket = WebSocket(scope, receive=receive, send=send)
+ if is_async_callable(handler):
+ await handler(websocket, exc)
+ else:
+ await run_in_threadpool(handler, websocket, exc)
def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)
+
+ async def websocket_exception(
+ self, websocket: WebSocket, exc: WebSocketException
+ ) -> None:
+ await websocket.close(code=exc.code, reason=exc.reason)
import os
from contextlib import asynccontextmanager
+import anyio
import pytest
+from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
+from starlette.websockets import WebSocket
async def error_500(request, exc):
await session.close()
+async def websocket_raise_websocket(websocket: WebSocket):
+ await websocket.accept()
+ raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
+
+
+class CustomWSException(Exception):
+ pass
+
+
+async def websocket_raise_custom(websocket: WebSocket):
+ await websocket.accept()
+ raise CustomWSException()
+
+
+def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
+ anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
+
+
users = Router(
routes=[
Route("/", endpoint=all_users_page),
500: error_500,
405: method_not_allowed,
HTTPException: http_exception,
+ CustomWSException: custom_ws_exception_handler,
}
middleware = [
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
+ WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
+ WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain),
],
assert response.json() == {"detail": "Server Error"}
+def test_websocket_raise_websocket_exception(client):
+ with client.websocket_connect("/ws-raise-websocket") as session:
+ response = session.receive()
+ assert response == {
+ "type": "websocket.close",
+ "code": status.WS_1003_UNSUPPORTED_DATA,
+ "reason": "",
+ }
+
+
+def test_websocket_raise_custom_exception(client):
+ with client.websocket_connect("/ws-raise-custom") as session:
+ response = session.receive()
+ assert response == {
+ "type": "websocket.close",
+ "code": status.WS_1013_TRY_AGAIN_LATER,
+ "reason": "",
+ }
+
+
def test_middleware(test_client_factory):
client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func")
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
+ WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
+ WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount(
"/users",
app=Router(
import pytest
-from starlette.exceptions import HTTPException
+from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
assert response.text == ""
-def test_repr():
+def test_http_repr():
assert repr(HTTPException(404)) == (
"HTTPException(status_code=404, detail='Not Found')"
)
)
+def test_websocket_repr():
+ assert repr(WebSocketException(1008, reason="Policy Violation")) == (
+ "WebSocketException(code=1008, reason='Policy Violation')"
+ )
+
+ class CustomWebSocketException(WebSocketException):
+ pass
+
+ assert (
+ repr(CustomWebSocketException(1013, reason="Something custom"))
+ == "CustomWebSocketException(code=1013, reason='Something custom')"
+ )
+
+
def test_exception_middleware_deprecation() -> None:
# this test should be removed once the deprecation shim is removed
with pytest.warns(DeprecationWarning):