]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Fix dependency overrides in WebSockets (#1122)
authoramitlissack <amit@opentrons.com>
Mon, 30 Mar 2020 18:45:05 +0000 (14:45 -0400)
committerGitHub <noreply@github.com>
Mon, 30 Mar 2020 18:45:05 +0000 (20:45 +0200)
* add tests to test_ws_router to test dependencies and dependency overrides.

* supply dependency_overrides_provider to APIWebSocketRoute upon creation

fastapi/routing.py
tests/test_ws_router.py

index b90935e15ffe76ce3903f4b4bd66e842ab7a2d4c..1ec0b693c87b0165e760a0cc9b71f3558d02afac 100644 (file)
@@ -498,7 +498,12 @@ class APIRouter(routing.Router):
     def add_api_websocket_route(
         self, path: str, endpoint: Callable, name: str = None
     ) -> None:
-        route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
+        route = APIWebSocketRoute(
+            path,
+            endpoint=endpoint,
+            name=name,
+            dependency_overrides_provider=self.dependency_overrides_provider,
+        )
         self.routes.append(route)
 
     def websocket(self, path: str, name: str = None) -> Callable:
index fd19e650a9ace44dec428573f6b018834b36368a..dd0456127222aeec84a74968df8a2702a1a54a12 100644 (file)
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, FastAPI, WebSocket
+from fastapi import APIRouter, Depends, FastAPI, WebSocket
 from fastapi.testclient import TestClient
 
 router = APIRouter()
@@ -34,6 +34,19 @@ async def routerindex(websocket: WebSocket):
     await websocket.close()
 
 
+async def ws_dependency():
+    return "Socket Dependency"
+
+
+@router.websocket("/router-ws-depends/")
+async def router_ws_decorator_depends(
+    websocket: WebSocket, data=Depends(ws_dependency)
+):
+    await websocket.accept()
+    await websocket.send_text(data)
+    await websocket.close()
+
+
 app.include_router(router)
 app.include_router(prefix_router, prefix="/prefix")
 
@@ -64,3 +77,16 @@ def test_router2():
     with client.websocket_connect("/router2") as websocket:
         data = websocket.receive_text()
         assert data == "Hello, router!"
+
+
+def test_router_ws_depends():
+    client = TestClient(app)
+    with client.websocket_connect("/router-ws-depends/") as websocket:
+        assert websocket.receive_text() == "Socket Dependency"
+
+
+def test_router_ws_depends_with_override():
+    client = TestClient(app)
+    app.dependency_overrides[ws_dependency] = lambda: "Override"
+    with client.websocket_connect("/router-ws-depends/") as websocket:
+        assert websocket.receive_text() == "Override"