From: Amin Alaee Date: Sat, 22 Jan 2022 16:11:32 +0000 (+0100) Subject: Add reason to WebSocket closure (#1417) X-Git-Tag: 0.18.0~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=34d9f0f10f2e4706551acf340eb1eaf4d8c87b49;p=thirdparty%2Fstarlette.git Add reason to WebSocket closure (#1417) Co-authored-by: Marcelo Trylesinski --- diff --git a/docs/websockets.md b/docs/websockets.md index 43406ace..1128bce4 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da ### Closing the connection -* `await websocket.close(code=1000)` +* `await websocket.close(code=1000, reason=None)` ### Sending and receiving messages diff --git a/starlette/testclient.py b/starlette/testclient.py index 0b4bc78d..c951767b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -352,7 +352,9 @@ class WebSocketTestSession: def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": - raise WebSocketDisconnect(message.get("code", 1000)) + raise WebSocketDisconnect( + message.get("code", 1000), message.get("reason", "") + ) def send(self, message: Message) -> None: self._receive_queue.put(message) diff --git a/starlette/websockets.py b/starlette/websockets.py index bf4cca83..da740604 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum): class WebSocketDisconnect(Exception): - def __init__(self, code: int = 1000) -> None: + def __init__(self, code: int = 1000, reason: str = None) -> None: self.code = code + self.reason = reason or "" class WebSocket(HTTPConnection): @@ -146,13 +147,18 @@ class WebSocket(HTTPConnection): else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) - async def close(self, code: int = 1000) -> None: - await self.send({"type": "websocket.close", "code": code}) + async def close(self, code: int = 1000, reason: str = None) -> None: + await self.send( + {"type": "websocket.close", "code": code, "reason": reason or ""} + ) class WebSocketClose: - def __init__(self, code: int = 1000) -> None: + def __init__(self, code: int = 1000, reason: str = None) -> None: self.code = code + self.reason = reason or "" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await send({"type": "websocket.close", "code": self.code}) + await send( + {"type": "websocket.close", "code": self.code, "reason": self.reason} + ) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index f3242d11..b11685cb 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -405,3 +405,20 @@ def test_websocket_scope_interface(): assert websocket == websocket assert websocket in {websocket} assert {websocket} == {websocket} + + +def test_websocket_close_reason(test_client_factory) -> None: + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away") + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + with pytest.raises(WebSocketDisconnect) as exc: + websocket.receive_text() + assert exc.value.code == status.WS_1001_GOING_AWAY + assert exc.value.reason == "Going Away"