]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Add support for `dependencies` in WebSocket routes (#4534)
authorPaulo Costa <me@paulo.costa.nom.br>
Sun, 11 Jun 2023 20:35:39 +0000 (17:35 -0300)
committerGitHub <noreply@github.com>
Sun, 11 Jun 2023 20:35:39 +0000 (20:35 +0000)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/applications.py
fastapi/routing.py
tests/test_ws_dependencies.py [new file with mode: 0644]

index d5ea1d72af6b6cafcf3a7e2d221d0a887583e248..298aca921d70c265cc9de95916002ac4b1955abf 100644 (file)
@@ -401,15 +401,34 @@ class FastAPI(Starlette):
         return decorator
 
     def add_api_websocket_route(
-        self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
+        self,
+        path: str,
+        endpoint: Callable[..., Any],
+        name: Optional[str] = None,
+        *,
+        dependencies: Optional[Sequence[Depends]] = None,
     ) -> None:
-        self.router.add_api_websocket_route(path, endpoint, name=name)
+        self.router.add_api_websocket_route(
+            path,
+            endpoint,
+            name=name,
+            dependencies=dependencies,
+        )
 
     def websocket(
-        self, path: str, name: Optional[str] = None
+        self,
+        path: str,
+        name: Optional[str] = None,
+        *,
+        dependencies: Optional[Sequence[Depends]] = None,
     ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         def decorator(func: DecoratedCallable) -> DecoratedCallable:
-            self.add_api_websocket_route(path, func, name=name)
+            self.add_api_websocket_route(
+                path,
+                func,
+                name=name,
+                dependencies=dependencies,
+            )
             return func
 
         return decorator
index 7f1936f7f9ee627e5eeeb02199693c6a04d31016..af628f32d7d13f1643245b9b38ad2d1080b349c8 100644 (file)
@@ -296,13 +296,21 @@ class APIWebSocketRoute(routing.WebSocketRoute):
         endpoint: Callable[..., Any],
         *,
         name: Optional[str] = None,
+        dependencies: Optional[Sequence[params.Depends]] = None,
         dependency_overrides_provider: Optional[Any] = None,
     ) -> None:
         self.path = path
         self.endpoint = endpoint
         self.name = get_name(endpoint) if name is None else name
+        self.dependencies = list(dependencies or [])
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
         self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+        for depends in self.dependencies[::-1]:
+            self.dependant.dependencies.insert(
+                0,
+                get_parameterless_sub_dependant(depends=depends, path=self.path_format),
+            )
+
         self.app = websocket_session(
             get_websocket_app(
                 dependant=self.dependant,
@@ -416,10 +424,7 @@ class APIRoute(routing.Route):
         else:
             self.response_field = None  # type: ignore
             self.secure_cloned_response_field = None
-        if dependencies:
-            self.dependencies = list(dependencies)
-        else:
-            self.dependencies = []
+        self.dependencies = list(dependencies or [])
         self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
         # if a "form feed" character (page break) is found in the description text,
         # truncate description text to the content preceding the first "form feed"
@@ -514,7 +519,7 @@ class APIRouter(routing.Router):
             ), "A path prefix must not end with '/', as the routes will start with '/'"
         self.prefix = prefix
         self.tags: List[Union[str, Enum]] = tags or []
-        self.dependencies = list(dependencies or []) or []
+        self.dependencies = list(dependencies or [])
         self.deprecated = deprecated
         self.include_in_schema = include_in_schema
         self.responses = responses or {}
@@ -688,21 +693,37 @@ class APIRouter(routing.Router):
         return decorator
 
     def add_api_websocket_route(
-        self, path: str, endpoint: Callable[..., Any], name: Optional[str] = None
+        self,
+        path: str,
+        endpoint: Callable[..., Any],
+        name: Optional[str] = None,
+        *,
+        dependencies: Optional[Sequence[params.Depends]] = None,
     ) -> None:
+        current_dependencies = self.dependencies.copy()
+        if dependencies:
+            current_dependencies.extend(dependencies)
+
         route = APIWebSocketRoute(
             self.prefix + path,
             endpoint=endpoint,
             name=name,
+            dependencies=current_dependencies,
             dependency_overrides_provider=self.dependency_overrides_provider,
         )
         self.routes.append(route)
 
     def websocket(
-        self, path: str, name: Optional[str] = None
+        self,
+        path: str,
+        name: Optional[str] = None,
+        *,
+        dependencies: Optional[Sequence[params.Depends]] = None,
     ) -> Callable[[DecoratedCallable], DecoratedCallable]:
         def decorator(func: DecoratedCallable) -> DecoratedCallable:
-            self.add_api_websocket_route(path, func, name=name)
+            self.add_api_websocket_route(
+                path, func, name=name, dependencies=dependencies
+            )
             return func
 
         return decorator
@@ -817,8 +838,16 @@ class APIRouter(routing.Router):
                     name=route.name,
                 )
             elif isinstance(route, APIWebSocketRoute):
+                current_dependencies = []
+                if dependencies:
+                    current_dependencies.extend(dependencies)
+                if route.dependencies:
+                    current_dependencies.extend(route.dependencies)
                 self.add_api_websocket_route(
-                    prefix + route.path, route.endpoint, name=route.name
+                    prefix + route.path,
+                    route.endpoint,
+                    dependencies=current_dependencies,
+                    name=route.name,
                 )
             elif isinstance(route, routing.WebSocketRoute):
                 self.add_websocket_route(
diff --git a/tests/test_ws_dependencies.py b/tests/test_ws_dependencies.py
new file mode 100644 (file)
index 0000000..ccb1c4b
--- /dev/null
@@ -0,0 +1,73 @@
+import json
+from typing import List
+
+from fastapi import APIRouter, Depends, FastAPI, WebSocket
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+
+def dependency_list() -> List[str]:
+    return []
+
+
+DepList = Annotated[List[str], Depends(dependency_list)]
+
+
+def create_dependency(name: str):
+    def fun(deps: DepList):
+        deps.append(name)
+
+    return Depends(fun)
+
+
+router = APIRouter(dependencies=[create_dependency("router")])
+prefix_router = APIRouter(dependencies=[create_dependency("prefix_router")])
+app = FastAPI(dependencies=[create_dependency("app")])
+
+
+@app.websocket("/", dependencies=[create_dependency("index")])
+async def index(websocket: WebSocket, deps: DepList):
+    await websocket.accept()
+    await websocket.send_text(json.dumps(deps))
+    await websocket.close()
+
+
+@router.websocket("/router", dependencies=[create_dependency("routerindex")])
+async def routerindex(websocket: WebSocket, deps: DepList):
+    await websocket.accept()
+    await websocket.send_text(json.dumps(deps))
+    await websocket.close()
+
+
+@prefix_router.websocket("/", dependencies=[create_dependency("routerprefixindex")])
+async def routerprefixindex(websocket: WebSocket, deps: DepList):
+    await websocket.accept()
+    await websocket.send_text(json.dumps(deps))
+    await websocket.close()
+
+
+app.include_router(router, dependencies=[create_dependency("router2")])
+app.include_router(
+    prefix_router, prefix="/prefix", dependencies=[create_dependency("prefix_router2")]
+)
+
+
+def test_index():
+    client = TestClient(app)
+    with client.websocket_connect("/") as websocket:
+        data = json.loads(websocket.receive_text())
+        assert data == ["app", "index"]
+
+
+def test_routerindex():
+    client = TestClient(app)
+    with client.websocket_connect("/router") as websocket:
+        data = json.loads(websocket.receive_text())
+        assert data == ["app", "router2", "router", "routerindex"]
+
+
+def test_routerprefixindex():
+    client = TestClient(app)
+    with client.websocket_connect("/prefix/") as websocket:
+        data = json.loads(websocket.receive_text())
+        assert data == ["app", "prefix_router2", "prefix_router", "routerprefixindex"]