from fastapi.exception_handlers import (
http_exception_handler,
request_validation_exception_handler,
+ websocket_request_validation_exception_handler,
)
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import (
self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler
)
+ self.exception_handlers.setdefault(
+ WebSocketRequestValidationError,
+ # Starlette still has incorrect type specification for the handlers
+ websocket_request_validation_exception_handler, # type: ignore
+ )
self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
from fastapi.encoders import jsonable_encoder
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code
+from fastapi.websockets import WebSocket
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": jsonable_encoder(exc.errors())},
)
+
+
+async def websocket_request_validation_exception_handler(
+ websocket: WebSocket, exc: WebSocketRequestValidationError
+) -> None:
+ await websocket.close(
+ code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
+ )
-from fastapi import APIRouter, Depends, FastAPI, WebSocket
+import functools
+
+import pytest
+from fastapi import (
+ APIRouter,
+ Depends,
+ FastAPI,
+ Header,
+ WebSocket,
+ WebSocketDisconnect,
+ status,
+)
+from fastapi.middleware import Middleware
from fastapi.testclient import TestClient
router = APIRouter()
await websocket.close()
-app.include_router(router)
-app.include_router(prefix_router, prefix="/prefix")
-app.include_router(native_prefix_route)
+async def ws_dependency_err():
+ raise NotImplementedError()
+
+
+@router.websocket("/depends-err/")
+async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
+ pass # pragma: no cover
+
+
+async def ws_dependency_validate(x_missing: str = Header()):
+ pass # pragma: no cover
+
+
+@router.websocket("/depends-validate/")
+async def router_ws_depends_validate(
+ websocket: WebSocket, data=Depends(ws_dependency_validate)
+):
+ pass # pragma: no cover
+
+
+class CustomError(Exception):
+ pass
+
+
+@router.websocket("/custom_error/")
+async def router_ws_custom_error(websocket: WebSocket):
+ raise CustomError()
+
+
+def make_app(app=None, **kwargs):
+ app = app or FastAPI(**kwargs)
+ app.include_router(router)
+ app.include_router(prefix_router, prefix="/prefix")
+ app.include_router(native_prefix_route)
+ return app
+
+
+app = make_app(app)
def test_app():
assert data == "path/to/file"
data = websocket.receive_text()
assert data == "a_query_param"
+
+
+def test_wrong_uri():
+ """
+ Verify that a websocket connection to a non-existent endpoing returns in a shutdown
+ """
+ client = TestClient(app)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/no-router/"):
+ pass # pragma: no cover
+ assert e.value.code == status.WS_1000_NORMAL_CLOSURE
+
+
+def websocket_middleware(middleware_func):
+ """
+ Helper to create a Starlette pure websocket middleware
+ """
+
+ def middleware_constructor(app):
+ @functools.wraps(app)
+ async def wrapped_app(scope, receive, send):
+ if scope["type"] != "websocket":
+ return await app(scope, receive, send) # pragma: no cover
+
+ async def call_next():
+ return await app(scope, receive, send)
+
+ websocket = WebSocket(scope, receive=receive, send=send)
+ return await middleware_func(websocket, call_next)
+
+ return wrapped_app
+
+ return middleware_constructor
+
+
+def test_depend_validation():
+ """
+ Verify that a validation in a dependency invokes the correct exception handler
+ """
+ caught = []
+
+ @websocket_middleware
+ async def catcher(websocket, call_next):
+ try:
+ return await call_next()
+ except Exception as e: # pragma: no cover
+ caught.append(e)
+ raise
+
+ myapp = make_app(middleware=[Middleware(catcher)])
+
+ client = TestClient(myapp)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/depends-validate/"):
+ pass # pragma: no cover
+ # the validation error does produce a close message
+ assert e.value.code == status.WS_1008_POLICY_VIOLATION
+ # and no error is leaked
+ assert caught == []
+
+
+def test_depend_err_middleware():
+ """
+ Verify that it is possible to write custom WebSocket middleware to catch errors
+ """
+
+ @websocket_middleware
+ async def errorhandler(websocket: WebSocket, call_next):
+ try:
+ return await call_next()
+ except Exception as e:
+ await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))
+
+ myapp = make_app(middleware=[Middleware(errorhandler)])
+ client = TestClient(myapp)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/depends-err/"):
+ pass # pragma: no cover
+ assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
+ assert "NotImplementedError" in e.value.reason
+
+
+def test_depend_err_handler():
+ """
+ Verify that it is possible to write custom WebSocket middleware to catch errors
+ """
+
+ async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
+ await websocket.close(1002, "foo")
+
+ myapp = make_app(exception_handlers={CustomError: custom_handler})
+ client = TestClient(myapp)
+ with pytest.raises(WebSocketDisconnect) as e:
+ with client.websocket_connect("/custom_error/"):
+ pass # pragma: no cover
+ assert e.value.code == 1002
+ assert "foo" in e.value.reason