From: amitlissack Date: Mon, 30 Mar 2020 18:45:05 +0000 (-0400) Subject: :bug: Fix dependency overrides in WebSockets (#1122) X-Git-Tag: 0.53.2~4 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=02441ff0313d5b471b662293244c53e712f1243f;p=thirdparty%2Ffastapi%2Ffastapi.git :bug: Fix dependency overrides in WebSockets (#1122) * add tests to test_ws_router to test dependencies and dependency overrides. * supply dependency_overrides_provider to APIWebSocketRoute upon creation --- diff --git a/fastapi/routing.py b/fastapi/routing.py index b90935e15f..1ec0b693c8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -498,7 +498,12 @@ class APIRouter(routing.Router): def add_api_websocket_route( self, path: str, endpoint: Callable, name: str = None ) -> None: - route = APIWebSocketRoute(path, endpoint=endpoint, name=name) + route = APIWebSocketRoute( + path, + endpoint=endpoint, + name=name, + dependency_overrides_provider=self.dependency_overrides_provider, + ) self.routes.append(route) def websocket(self, path: str, name: str = None) -> Callable: diff --git a/tests/test_ws_router.py b/tests/test_ws_router.py index fd19e650a9..dd04561272 100644 --- a/tests/test_ws_router.py +++ b/tests/test_ws_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, FastAPI, WebSocket +from fastapi import APIRouter, Depends, FastAPI, WebSocket from fastapi.testclient import TestClient router = APIRouter() @@ -34,6 +34,19 @@ async def routerindex(websocket: WebSocket): await websocket.close() +async def ws_dependency(): + return "Socket Dependency" + + +@router.websocket("/router-ws-depends/") +async def router_ws_decorator_depends( + websocket: WebSocket, data=Depends(ws_dependency) +): + await websocket.accept() + await websocket.send_text(data) + await websocket.close() + + app.include_router(router) app.include_router(prefix_router, prefix="/prefix") @@ -64,3 +77,16 @@ def test_router2(): with client.websocket_connect("/router2") as websocket: data = websocket.receive_text() assert data == "Hello, router!" + + +def test_router_ws_depends(): + client = TestClient(app) + with client.websocket_connect("/router-ws-depends/") as websocket: + assert websocket.receive_text() == "Socket Dependency" + + +def test_router_ws_depends_with_override(): + client = TestClient(app) + app.dependency_overrides[ws_dependency] = lambda: "Override" + with client.websocket_connect("/router-ws-depends/") as websocket: + assert websocket.receive_text() == "Override"