]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Add exception handler for `WebSocketRequestValidationError` (which also allows...
authorKristján Valur Jónsson <sweskman@gmail.com>
Sun, 11 Jun 2023 19:08:14 +0000 (19:08 +0000)
committerGitHub <noreply@github.com>
Sun, 11 Jun 2023 19:08:14 +0000 (21:08 +0200)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/applications.py
fastapi/exception_handlers.py
fastapi/routing.py
tests/test_ws_router.py

index 8b3a74d3c833d7fb7d91c15080afea550b157e8a..d5ea1d72af6b6cafcf3a7e2d221d0a887583e248 100644 (file)
@@ -19,8 +19,9 @@ from fastapi.encoders import DictIntStrAny, SetIntStr
 from fastapi.exception_handlers import (
     http_exception_handler,
     request_validation_exception_handler,
+    websocket_request_validation_exception_handler,
 )
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
 from fastapi.logger import logger
 from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
 from fastapi.openapi.docs import (
@@ -145,6 +146,11 @@ class FastAPI(Starlette):
         self.exception_handlers.setdefault(
             RequestValidationError, request_validation_exception_handler
         )
+        self.exception_handlers.setdefault(
+            WebSocketRequestValidationError,
+            # Starlette still has incorrect type specification for the handlers
+            websocket_request_validation_exception_handler,  # type: ignore
+        )
 
         self.user_middleware: List[Middleware] = (
             [] if middleware is None else list(middleware)
index 4d7ea5ec2e44bab45b444881175403685a5876fc..6c2ba7fedf9337260824b62987e65301e4fed129 100644 (file)
@@ -1,10 +1,11 @@
 from fastapi.encoders import jsonable_encoder
-from fastapi.exceptions import RequestValidationError
+from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
 from fastapi.utils import is_body_allowed_for_status_code
+from fastapi.websockets import WebSocket
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
 
 
 async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
@@ -23,3 +24,11 @@ async def request_validation_exception_handler(
         status_code=HTTP_422_UNPROCESSABLE_ENTITY,
         content={"detail": jsonable_encoder(exc.errors())},
     )
+
+
+async def websocket_request_validation_exception_handler(
+    websocket: WebSocket, exc: WebSocketRequestValidationError
+) -> None:
+    await websocket.close(
+        code=WS_1008_POLICY_VIOLATION, reason=jsonable_encoder(exc.errors())
+    )
index 06c71bffadd983e94801cbb1b7d040af42f8e603..7f1936f7f9ee627e5eeeb02199693c6a04d31016 100644 (file)
@@ -56,7 +56,6 @@ from starlette.routing import (
     request_response,
     websocket_session,
 )
-from starlette.status import WS_1008_POLICY_VIOLATION
 from starlette.types import ASGIApp, Lifespan, Scope
 from starlette.websockets import WebSocket
 
@@ -283,7 +282,6 @@ def get_websocket_app(
         )
         values, errors, _, _2, _3 = solved_result
         if errors:
-            await websocket.close(code=WS_1008_POLICY_VIOLATION)
             raise WebSocketRequestValidationError(errors)
         assert dependant.call is not None, "dependant.call must be a function"
         await dependant.call(**values)
index c312821e969064a53410cf93358dd04d1cd5d6f8..240a42bb0c97e6c4c545410d3c965867251eb1e4 100644 (file)
@@ -1,4 +1,16 @@
-from fastapi import APIRouter, Depends, FastAPI, WebSocket
+import functools
+
+import pytest
+from fastapi import (
+    APIRouter,
+    Depends,
+    FastAPI,
+    Header,
+    WebSocket,
+    WebSocketDisconnect,
+    status,
+)
+from fastapi.middleware import Middleware
 from fastapi.testclient import TestClient
 
 router = APIRouter()
@@ -63,9 +75,44 @@ async def router_native_prefix_ws(websocket: WebSocket):
     await websocket.close()
 
 
-app.include_router(router)
-app.include_router(prefix_router, prefix="/prefix")
-app.include_router(native_prefix_route)
+async def ws_dependency_err():
+    raise NotImplementedError()
+
+
+@router.websocket("/depends-err/")
+async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
+    pass  # pragma: no cover
+
+
+async def ws_dependency_validate(x_missing: str = Header()):
+    pass  # pragma: no cover
+
+
+@router.websocket("/depends-validate/")
+async def router_ws_depends_validate(
+    websocket: WebSocket, data=Depends(ws_dependency_validate)
+):
+    pass  # pragma: no cover
+
+
+class CustomError(Exception):
+    pass
+
+
+@router.websocket("/custom_error/")
+async def router_ws_custom_error(websocket: WebSocket):
+    raise CustomError()
+
+
+def make_app(app=None, **kwargs):
+    app = app or FastAPI(**kwargs)
+    app.include_router(router)
+    app.include_router(prefix_router, prefix="/prefix")
+    app.include_router(native_prefix_route)
+    return app
+
+
+app = make_app(app)
 
 
 def test_app():
@@ -125,3 +172,100 @@ def test_router_with_params():
         assert data == "path/to/file"
         data = websocket.receive_text()
         assert data == "a_query_param"
+
+
+def test_wrong_uri():
+    """
+    Verify that a websocket connection to a non-existent endpoing returns in a shutdown
+    """
+    client = TestClient(app)
+    with pytest.raises(WebSocketDisconnect) as e:
+        with client.websocket_connect("/no-router/"):
+            pass  # pragma: no cover
+    assert e.value.code == status.WS_1000_NORMAL_CLOSURE
+
+
+def websocket_middleware(middleware_func):
+    """
+    Helper to create a Starlette pure websocket middleware
+    """
+
+    def middleware_constructor(app):
+        @functools.wraps(app)
+        async def wrapped_app(scope, receive, send):
+            if scope["type"] != "websocket":
+                return await app(scope, receive, send)  # pragma: no cover
+
+            async def call_next():
+                return await app(scope, receive, send)
+
+            websocket = WebSocket(scope, receive=receive, send=send)
+            return await middleware_func(websocket, call_next)
+
+        return wrapped_app
+
+    return middleware_constructor
+
+
+def test_depend_validation():
+    """
+    Verify that a validation in a dependency invokes the correct exception handler
+    """
+    caught = []
+
+    @websocket_middleware
+    async def catcher(websocket, call_next):
+        try:
+            return await call_next()
+        except Exception as e:  # pragma: no cover
+            caught.append(e)
+            raise
+
+    myapp = make_app(middleware=[Middleware(catcher)])
+
+    client = TestClient(myapp)
+    with pytest.raises(WebSocketDisconnect) as e:
+        with client.websocket_connect("/depends-validate/"):
+            pass  # pragma: no cover
+    # the validation error does produce a close message
+    assert e.value.code == status.WS_1008_POLICY_VIOLATION
+    # and no error is leaked
+    assert caught == []
+
+
+def test_depend_err_middleware():
+    """
+    Verify that it is possible to write custom WebSocket middleware to catch errors
+    """
+
+    @websocket_middleware
+    async def errorhandler(websocket: WebSocket, call_next):
+        try:
+            return await call_next()
+        except Exception as e:
+            await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e))
+
+    myapp = make_app(middleware=[Middleware(errorhandler)])
+    client = TestClient(myapp)
+    with pytest.raises(WebSocketDisconnect) as e:
+        with client.websocket_connect("/depends-err/"):
+            pass  # pragma: no cover
+    assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE
+    assert "NotImplementedError" in e.value.reason
+
+
+def test_depend_err_handler():
+    """
+    Verify that it is possible to write custom WebSocket middleware to catch errors
+    """
+
+    async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
+        await websocket.close(1002, "foo")
+
+    myapp = make_app(exception_handlers={CustomError: custom_handler})
+    client = TestClient(myapp)
+    with pytest.raises(WebSocketDisconnect) as e:
+        with client.websocket_connect("/custom_error/"):
+            pass  # pragma: no cover
+    assert e.value.code == 1002
+    assert "foo" in e.value.reason