return HTMLResponse(html)
-@app.websocket_route("/ws")
+@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
data = await websocket.receive_text()
await websocket.send_text(f"Message text was: {data}")
- await websocket.close()
--- /dev/null
+from fastapi import Cookie, Depends, FastAPI, Header
+from starlette.responses import HTMLResponse
+from starlette.status import WS_1008_POLICY_VIOLATION
+from starlette.websockets import WebSocket
+
+app = FastAPI()
+
+html = """
+<!DOCTYPE html>
+<html>
+ <head>
+ <title>Chat</title>
+ </head>
+ <body>
+ <h1>WebSocket Chat</h1>
+ <form action="" onsubmit="sendMessage(event)">
+ <label>Item ID: <input type="text" id="itemId" autocomplete="off" value="foo"/></label>
+ <button onclick="connect(event)">Connect</button>
+ <br>
+ <label>Message: <input type="text" id="messageText" autocomplete="off"/></label>
+ <button>Send</button>
+ </form>
+ <ul id='messages'>
+ </ul>
+ <script>
+ var ws = null;
+ function connect(event) {
+ var input = document.getElementById("itemId")
+ ws = new WebSocket("ws://localhost:8000/items/" + input.value + "/ws");
+ ws.onmessage = function(event) {
+ var messages = document.getElementById('messages')
+ var message = document.createElement('li')
+ var content = document.createTextNode(event.data)
+ message.appendChild(content)
+ messages.appendChild(message)
+ };
+ }
+ function sendMessage(event) {
+ var input = document.getElementById("messageText")
+ ws.send(input.value)
+ input.value = ''
+ event.preventDefault()
+ }
+ </script>
+ </body>
+</html>
+"""
+
+
+@app.get("/")
+async def get():
+ return HTMLResponse(html)
+
+
+async def get_cookie_or_client(
+ websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
+):
+ if session is None and x_client is None:
+ await websocket.close(code=WS_1008_POLICY_VIOLATION)
+ return session or x_client
+
+
+@app.websocket("/items/{item_id}/ws")
+async def websocket_endpoint(
+ websocket: WebSocket,
+ item_id: int,
+ q: str = None,
+ cookie_or_client: str = Depends(get_cookie_or_client),
+):
+ await websocket.accept()
+ while True:
+ data = await websocket.receive_text()
+ await websocket.send_text(
+ f"Session Cookie or X-Client Header value is: {cookie_or_client}"
+ )
+ if q is not None:
+ await websocket.send_text(f"Query parameter q is: {q}")
+ await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")
{!./src/websockets/tutorial001.py!}
```
-## Create a `websocket_route`
+## Create a `websocket`
-In your **FastAPI** application, create a `websocket_route`:
+In your **FastAPI** application, create a `websocket`:
```Python hl_lines="3 47 48"
{!./src/websockets/tutorial001.py!}
!!! tip
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
- That is not required, but it's recommended as it will provide you completion and checks inside the function.
-
-
-!!! info
- This `websocket_route` we are using comes directly from <a href="https://www.starlette.io/applications/" target="_blank">Starlette</a>.
-
- That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
-
-
## Await for messages and send messages
In your WebSocket route you can `await` for messages and send messages.
You can receive and send binary, text, and JSON data.
+## Using `Depends` and others
+
+In WebSocket endpoints you can import from `fastapi` and use:
+
+* `Depends`
+* `Security`
+* `Cookie`
+* `Header`
+* `Path`
+* `Query`
+
+They work the same way as for other FastAPI endpoints/*path operations*:
+
+```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
+{!./src/websockets/tutorial002.py!}
+```
+
+!!! info
+ In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
+
+ You can use a closing code from the <a href="https://tools.ietf.org/html/rfc6455#section-7.4.1" target="_blank">valid codes defined in the specification</a>.
+
+ In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the <a href="https://github.com/encode/starlette/pull/527" target="_blank">PR #527</a> in Starlette.
+
+## More info
+
To learn more about the options, check Starlette's documentation for:
* <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.
return decorator
+ def add_api_websocket_route(
+ self, path: str, endpoint: Callable, name: str = None
+ ) -> None:
+ self.router.add_api_websocket_route(path, endpoint, name=name)
+
+ def websocket(self, path: str, name: str = None) -> Callable:
+ def decorator(func: Callable) -> Callable:
+ self.add_api_websocket_route(path, func, name=name)
+ return func
+
+ return decorator
+
def include_router(
self,
router: routing.APIRouter,
name: str = None,
call: Callable = None,
request_param_name: str = None,
+ websocket_param_name: str = None,
background_tasks_param_name: str = None,
security_scopes_param_name: str = None,
security_scopes: List[str] = None,
self.dependencies = dependencies or []
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
+ self.websocket_param_name = websocket_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes
self.security_scopes_param_name = security_scopes_param_name
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request
+from starlette.websockets import WebSocket
param_supported_types = (
str,
)
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
+ elif lenient_issubclass(param.annotation, WebSocket):
+ dependant.websocket_param_name = param_name
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param_name
elif lenient_issubclass(param.annotation, SecurityScopes):
async def solve_dependencies(
*,
- request: Request,
+ request: Union[Request, WebSocket],
dependant: Dependant,
body: Dict[str, Any] = None,
background_tasks: BackgroundTasks = None,
)
values.update(body_values)
errors.extend(body_errors)
- if dependant.request_param_name:
+ if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
+ elif dependant.websocket_param_name and isinstance(request, WebSocket):
+ values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
import asyncio
import inspect
import logging
+import re
from typing import Any, Callable, Dict, List, Optional, Type, Union
from fastapi import params
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
-from starlette.routing import compile_path, get_name, request_response
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.routing import (
+ compile_path,
+ get_name,
+ request_response,
+ websocket_session,
+)
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
+from starlette.websockets import WebSocket
def serialize_response(*, field: Field = None, response: Response) -> Any:
return app
+def get_websocket_app(dependant: Dependant) -> Callable:
+ async def app(websocket: WebSocket) -> None:
+ values, errors, _ = await solve_dependencies(
+ request=websocket, dependant=dependant
+ )
+ if errors:
+ await websocket.close(code=WS_1008_POLICY_VIOLATION)
+ errors_out = ValidationError(errors)
+ raise HTTPException(
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
+ )
+ assert dependant.call is not None, "dependant.call must me a function"
+ await dependant.call(**values)
+
+ return app
+
+
+class APIWebSocketRoute(routing.WebSocketRoute):
+ def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
+ self.path = path
+ self.endpoint = endpoint
+ self.name = get_name(endpoint) if name is None else name
+ self.dependant = get_dependant(path=path, call=self.endpoint)
+ self.app = websocket_session(get_websocket_app(dependant=self.dependant))
+ regex = "^" + path + "$"
+ regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
+ self.path_regex, self.path_format, self.param_convertors = compile_path(path)
+
+
class APIRoute(routing.Route):
def __init__(
self,
return decorator
+ def add_api_websocket_route(
+ self, path: str, endpoint: Callable, name: str = None
+ ) -> None:
+ route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
+ self.routes.append(route)
+
+ def websocket(self, path: str, name: str = None) -> Callable:
+ def decorator(func: Callable) -> Callable:
+ self.add_api_websocket_route(path, func, name=name)
+ return func
+
+ return decorator
+
def include_router(
self,
router: "APIRouter",
include_in_schema=route.include_in_schema,
name=route.name,
)
+ elif isinstance(route, APIWebSocketRoute):
+ self.add_api_websocket_route(
+ prefix + route.path, route.endpoint, name=route.name
+ )
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name
--- /dev/null
+import pytest
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocketDisconnect
+from websockets.tutorial001 import app
+
+client = TestClient(app)
+
+
+def test_main():
+ response = client.get("/")
+ assert response.status_code == 200
+ assert b"<!DOCTYPE html>" in response.content
+
+
+def test_websocket():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect("/ws") as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}"
--- /dev/null
+import pytest
+from starlette.testclient import TestClient
+from starlette.websockets import WebSocketDisconnect
+from websockets.tutorial002 import app
+
+client = TestClient(app)
+
+
+def test_main():
+ response = client.get("/")
+ assert response.status_code == 200
+ assert b"<!DOCTYPE html>" in response.content
+
+
+def test_websocket_with_cookie():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "/items/1/ws", cookies={"session": "fakesession"}
+ ) as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: fakesession"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 1"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: fakesession"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 1"
+
+
+def test_websocket_with_header():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "/items/2/ws", headers={"X-Client": "xmen"}
+ ) as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+
+
+def test_websocket_with_header_and_query():
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "/items/2/ws?q=baz", headers={"X-Client": "xmen"}
+ ) as websocket:
+ message = "Message one"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == "Query parameter q is: baz"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+ message = "Message two"
+ websocket.send_text(message)
+ data = websocket.receive_text()
+ assert data == "Session Cookie or X-Client Header value is: xmen"
+ data = websocket.receive_text()
+ assert data == "Query parameter q is: baz"
+ data = websocket.receive_text()
+ assert data == f"Message text was: {message}, for item ID: 2"
+
+
+def test_websocket_no_credentials():
+ with pytest.raises(WebSocketDisconnect):
+ client.websocket_connect("/items/2/ws")
+
+
+def test_websocket_invalid_data():
+ with pytest.raises(WebSocketDisconnect):
+ client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"})
await websocket.close()
+@router.websocket("/router2")
+async def routerindex(websocket: WebSocket):
+ await websocket.accept()
+ await websocket.send_text("Hello, router!")
+ await websocket.close()
+
+
app.include_router(router)
app.include_router(prefix_router, prefix="/prefix")
with client.websocket_connect("/prefix/") as websocket:
data = websocket.receive_text()
assert data == "Hello, router with prefix!"
+
+
+def test_router2():
+ client = TestClient(app)
+ with client.websocket_connect("/router2") as websocket:
+ data = websocket.receive_text()
+ assert data == "Hello, router!"