]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add support for WebSockets with dependencies, params, etc #166 (#178)
authorJames Kaplan <jekirl@users.noreply.github.com>
Fri, 24 May 2019 16:41:41 +0000 (09:41 -0700)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 24 May 2019 16:41:41 +0000 (20:41 +0400)
12 files changed:
docs/src/websockets/__init__.py [new file with mode: 0644]
docs/src/websockets/tutorial001.py
docs/src/websockets/tutorial002.py [new file with mode: 0644]
docs/tutorial/websockets.md
fastapi/applications.py
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/routing.py
tests/test_tutorial/test_websockets/__init__.py [new file with mode: 0644]
tests/test_tutorial/test_websockets/test_tutorial001.py [new file with mode: 0644]
tests/test_tutorial/test_websockets/test_tutorial002.py [new file with mode: 0644]
tests/test_ws_router.py

diff --git a/docs/src/websockets/__init__.py b/docs/src/websockets/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
index 2713550d36d44ee765ad7f9d509bcde66ea4ad34..3adfd49c10ba051fca1f194f943bd3421cc21001 100644 (file)
@@ -44,10 +44,9 @@ async def get():
     return HTMLResponse(html)
 
 
-@app.websocket_route("/ws")
+@app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     while True:
         data = await websocket.receive_text()
         await websocket.send_text(f"Message text was: {data}")
-    await websocket.close()
diff --git a/docs/src/websockets/tutorial002.py b/docs/src/websockets/tutorial002.py
new file mode 100644 (file)
index 0000000..f57b927
--- /dev/null
@@ -0,0 +1,78 @@
+from fastapi import Cookie, Depends, FastAPI, Header
+from starlette.responses import HTMLResponse
+from starlette.status import WS_1008_POLICY_VIOLATION
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+html = """
+<!DOCTYPE html>
+<html>
+    <head>
+        <title>Chat</title>
+    </head>
+    <body>
+        <h1>WebSocket Chat</h1>
+        <form action="" onsubmit="sendMessage(event)">
+            <label>Item ID: <input type="text" id="itemId" autocomplete="off" value="foo"/></label>
+            <button onclick="connect(event)">Connect</button>
+            <br>
+            <label>Message: <input type="text" id="messageText" autocomplete="off"/></label>
+            <button>Send</button>
+        </form>
+        <ul id='messages'>
+        </ul>
+        <script>
+        var ws = null;
+            function connect(event) {
+                var input = document.getElementById("itemId")
+                ws = new WebSocket("ws://localhost:8000/items/" + input.value + "/ws");
+                ws.onmessage = function(event) {
+                    var messages = document.getElementById('messages')
+                    var message = document.createElement('li')
+                    var content = document.createTextNode(event.data)
+                    message.appendChild(content)
+                    messages.appendChild(message)
+                };
+            }
+            function sendMessage(event) {
+                var input = document.getElementById("messageText")
+                ws.send(input.value)
+                input.value = ''
+                event.preventDefault()
+            }
+        </script>
+    </body>
+</html>
+"""
+
+
+@app.get("/")
+async def get():
+    return HTMLResponse(html)
+
+
+async def get_cookie_or_client(
+    websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
+):
+    if session is None and x_client is None:
+        await websocket.close(code=WS_1008_POLICY_VIOLATION)
+    return session or x_client
+
+
+@app.websocket("/items/{item_id}/ws")
+async def websocket_endpoint(
+    websocket: WebSocket,
+    item_id: int,
+    q: str = None,
+    cookie_or_client: str = Depends(get_cookie_or_client),
+):
+    await websocket.accept()
+    while True:
+        data = await websocket.receive_text()
+        await websocket.send_text(
+            f"Session Cookie or X-Client Header value is: {cookie_or_client}"
+        )
+        if q is not None:
+            await websocket.send_text(f"Query parameter q is: {q}")
+        await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")
index 9bdb39a32edaea1c7b37e18ae3c1456f0b775068..16bba8ee3da990ca0abff109090665036f5f91a7 100644 (file)
@@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w
 {!./src/websockets/tutorial001.py!}
 ```
 
-## Create a `websocket_route`
+## Create a `websocket`
 
-In your **FastAPI** application, create a `websocket_route`:
+In your **FastAPI** application, create a `websocket`:
 
 ```Python hl_lines="3 47 48"
 {!./src/websockets/tutorial001.py!}
@@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`:
 !!! tip
     In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
 
-    That is not required, but it's recommended as it will provide you completion and checks inside the function.
-
-
-!!! info
-    This `websocket_route` we are using comes directly from <a href="https://www.starlette.io/applications/" target="_blank">Starlette</a>. 
-    
-    That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
-
-
 ## Await for messages and send messages
 
 In your WebSocket route you can `await` for messages and send messages.
@@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages.
 
 You can receive and send binary, text, and JSON data.
 
+## Using `Depends` and others
+
+In WebSocket endpoints you can import from `fastapi` and use:
+
+* `Depends`
+* `Security`
+* `Cookie`
+* `Header`
+* `Path`
+* `Query`
+
+They work the same way as for other FastAPI endpoints/*path operations*:
+
+```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
+{!./src/websockets/tutorial002.py!}
+```
+
+!!! info
+    In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
+
+    You can use a closing code from the <a href="https://tools.ietf.org/html/rfc6455#section-7.4.1" target="_blank">valid codes defined in the specification</a>.
+
+    In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the <a href="https://github.com/encode/starlette/pull/527" target="_blank">PR #527</a> in Starlette.
+
+## More info
+
 To learn more about the options, check Starlette's documentation for:
 
 * <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.
index dd5633dd28074bbdb85041a90361f796aad34487..7041e91d6edc8e135926bb75f662b3ba832e9049 100644 (file)
@@ -203,6 +203,18 @@ class FastAPI(Starlette):
 
         return decorator
 
+    def add_api_websocket_route(
+        self, path: str, endpoint: Callable, name: str = None
+    ) -> None:
+        self.router.add_api_websocket_route(path, endpoint, name=name)
+
+    def websocket(self, path: str, name: str = None) -> Callable:
+        def decorator(func: Callable) -> Callable:
+            self.add_api_websocket_route(path, func, name=name)
+            return func
+
+        return decorator
+
     def include_router(
         self,
         router: routing.APIRouter,
index 8bba5e369c71cae52bf308f01edf5290bb0c139e..67eb094e82f3cbb12a7925e7c4a8f9994ab2eba3 100644 (file)
@@ -26,6 +26,7 @@ class Dependant:
         name: str = None,
         call: Callable = None,
         request_param_name: str = None,
+        websocket_param_name: str = None,
         background_tasks_param_name: str = None,
         security_scopes_param_name: str = None,
         security_scopes: List[str] = None,
@@ -38,6 +39,7 @@ class Dependant:
         self.dependencies = dependencies or []
         self.security_requirements = security_schemes or []
         self.request_param_name = request_param_name
+        self.websocket_param_name = websocket_param_name
         self.background_tasks_param_name = background_tasks_param_name
         self.security_scopes = security_scopes
         self.security_scopes_param_name = security_scopes_param_name
index 0530fd2093cc5aeb8e079b4cd30c345774a09e14..194187f28ca06898cd828abdfb93e47bba22cdcf 100644 (file)
@@ -33,6 +33,7 @@ from starlette.background import BackgroundTasks
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
 from starlette.requests import Request
+from starlette.websockets import WebSocket
 
 param_supported_types = (
     str,
@@ -184,6 +185,8 @@ def get_dependant(
             )
         elif lenient_issubclass(param.annotation, Request):
             dependant.request_param_name = param_name
+        elif lenient_issubclass(param.annotation, WebSocket):
+            dependant.websocket_param_name = param_name
         elif lenient_issubclass(param.annotation, BackgroundTasks):
             dependant.background_tasks_param_name = param_name
         elif lenient_issubclass(param.annotation, SecurityScopes):
@@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool:
 
 async def solve_dependencies(
     *,
-    request: Request,
+    request: Union[Request, WebSocket],
     dependant: Dependant,
     body: Dict[str, Any] = None,
     background_tasks: BackgroundTasks = None,
@@ -326,8 +329,10 @@ async def solve_dependencies(
         )
         values.update(body_values)
         errors.extend(body_errors)
-    if dependant.request_param_name:
+    if dependant.request_param_name and isinstance(request, Request):
         values[dependant.request_param_name] = request
+    elif dependant.websocket_param_name and isinstance(request, WebSocket):
+        values[dependant.websocket_param_name] = request
     if dependant.background_tasks_param_name:
         if background_tasks is None:
             background_tasks = BackgroundTasks()
index ef8d9bed212e3702aff887da02d10949e508f78b..c902bb2add42cc927eaae54632f90b093551e4dd 100644 (file)
@@ -1,6 +1,7 @@
 import asyncio
 import inspect
 import logging
+import re
 from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 from fastapi import params
@@ -21,8 +22,14 @@ from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
-from starlette.routing import compile_path, get_name, request_response
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.routing import (
+    compile_path,
+    get_name,
+    request_response,
+    websocket_session,
+)
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
+from starlette.websockets import WebSocket
 
 
 def serialize_response(*, field: Field = None, response: Response) -> Any:
@@ -97,6 +104,35 @@ def get_app(
     return app
 
 
+def get_websocket_app(dependant: Dependant) -> Callable:
+    async def app(websocket: WebSocket) -> None:
+        values, errors, _ = await solve_dependencies(
+            request=websocket, dependant=dependant
+        )
+        if errors:
+            await websocket.close(code=WS_1008_POLICY_VIOLATION)
+            errors_out = ValidationError(errors)
+            raise HTTPException(
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
+            )
+        assert dependant.call is not None, "dependant.call must me a function"
+        await dependant.call(**values)
+
+    return app
+
+
+class APIWebSocketRoute(routing.WebSocketRoute):
+    def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
+        self.path = path
+        self.endpoint = endpoint
+        self.name = get_name(endpoint) if name is None else name
+        self.dependant = get_dependant(path=path, call=self.endpoint)
+        self.app = websocket_session(get_websocket_app(dependant=self.dependant))
+        regex = "^" + path + "$"
+        regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
+        self.path_regex, self.path_format, self.param_convertors = compile_path(path)
+
+
 class APIRoute(routing.Route):
     def __init__(
         self,
@@ -281,6 +317,19 @@ class APIRouter(routing.Router):
 
         return decorator
 
+    def add_api_websocket_route(
+        self, path: str, endpoint: Callable, name: str = None
+    ) -> None:
+        route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
+        self.routes.append(route)
+
+    def websocket(self, path: str, name: str = None) -> Callable:
+        def decorator(func: Callable) -> Callable:
+            self.add_api_websocket_route(path, func, name=name)
+            return func
+
+        return decorator
+
     def include_router(
         self,
         router: "APIRouter",
@@ -326,6 +375,10 @@ class APIRouter(routing.Router):
                     include_in_schema=route.include_in_schema,
                     name=route.name,
                 )
+            elif isinstance(route, APIWebSocketRoute):
+                self.add_api_websocket_route(
+                    prefix + route.path, route.endpoint, name=route.name
+                )
             elif isinstance(route, routing.WebSocketRoute):
                 self.add_websocket_route(
                     prefix + route.path, route.endpoint, name=route.name
diff --git a/tests/test_tutorial/test_websockets/__init__.py b/tests/test_tutorial/test_websockets/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tests/test_tutorial/test_websockets/test_tutorial001.py b/tests/test_tutorial/test_websockets/test_tutorial001.py
new file mode 100644 (file)
index 0000000..e886140
--- /dev/null
@@ -0,0 +1,25 @@
+import pytest
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocketDisconnect
+from websockets.tutorial001 import app
+
+client = TestClient(app)
+
+
+def test_main():
+    response = client.get("/")
+    assert response.status_code == 200
+    assert b"<!DOCTYPE html>" in response.content
+
+
+def test_websocket():
+    with pytest.raises(WebSocketDisconnect):
+        with client.websocket_connect("/ws") as websocket:
+            message = "Message one"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}"
+            message = "Message two"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}"
diff --git a/tests/test_tutorial/test_websockets/test_tutorial002.py b/tests/test_tutorial/test_websockets/test_tutorial002.py
new file mode 100644 (file)
index 0000000..063f83c
--- /dev/null
@@ -0,0 +1,83 @@
+import pytest
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocketDisconnect
+from websockets.tutorial002 import app
+
+client = TestClient(app)
+
+
+def test_main():
+    response = client.get("/")
+    assert response.status_code == 200
+    assert b"<!DOCTYPE html>" in response.content
+
+
+def test_websocket_with_cookie():
+    with pytest.raises(WebSocketDisconnect):
+        with client.websocket_connect(
+            "/items/1/ws", cookies={"session": "fakesession"}
+        ) as websocket:
+            message = "Message one"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == "Session Cookie or X-Client Header value is: fakesession"
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}, for item ID: 1"
+            message = "Message two"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == "Session Cookie or X-Client Header value is: fakesession"
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}, for item ID: 1"
+
+
+def test_websocket_with_header():
+    with pytest.raises(WebSocketDisconnect):
+        with client.websocket_connect(
+            "/items/2/ws", headers={"X-Client": "xmen"}
+        ) as websocket:
+            message = "Message one"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == "Session Cookie or X-Client Header value is: xmen"
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}, for item ID: 2"
+            message = "Message two"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == "Session Cookie or X-Client Header value is: xmen"
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}, for item ID: 2"
+
+
+def test_websocket_with_header_and_query():
+    with pytest.raises(WebSocketDisconnect):
+        with client.websocket_connect(
+            "/items/2/ws?q=baz", headers={"X-Client": "xmen"}
+        ) as websocket:
+            message = "Message one"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == "Session Cookie or X-Client Header value is: xmen"
+            data = websocket.receive_text()
+            assert data == "Query parameter q is: baz"
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}, for item ID: 2"
+            message = "Message two"
+            websocket.send_text(message)
+            data = websocket.receive_text()
+            assert data == "Session Cookie or X-Client Header value is: xmen"
+            data = websocket.receive_text()
+            assert data == "Query parameter q is: baz"
+            data = websocket.receive_text()
+            assert data == f"Message text was: {message}, for item ID: 2"
+
+
+def test_websocket_no_credentials():
+    with pytest.raises(WebSocketDisconnect):
+        client.websocket_connect("/items/2/ws")
+
+
+def test_websocket_invalid_data():
+    with pytest.raises(WebSocketDisconnect):
+        client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"})
index d3c69ca1f866f08e30026b6492dd50e6a46296f1..019d7cbc7cc1f262a6c68699050bd4e7b2a59c96 100644 (file)
@@ -28,6 +28,13 @@ async def routerprefixindex(websocket: WebSocket):
     await websocket.close()
 
 
+@router.websocket("/router2")
+async def routerindex(websocket: WebSocket):
+    await websocket.accept()
+    await websocket.send_text("Hello, router!")
+    await websocket.close()
+
+
 app.include_router(router)
 app.include_router(prefix_router, prefix="/prefix")
 
@@ -51,3 +58,10 @@ def test_prefix_router():
     with client.websocket_connect("/prefix/") as websocket:
         data = websocket.receive_text()
         assert data == "Hello, router with prefix!"
+
+
+def test_router2():
+    client = TestClient(app)
+    with client.websocket_connect("/router2") as websocket:
+        data = websocket.receive_text()
+        assert data == "Hello, router!"