]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add reason to WebSocket closure (#1417)
authorAmin Alaee <mohammadamin.alaee@gmail.com>
Sat, 22 Jan 2022 16:11:32 +0000 (17:11 +0100)
committerGitHub <noreply@github.com>
Sat, 22 Jan 2022 16:11:32 +0000 (17:11 +0100)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
docs/websockets.md
starlette/testclient.py
starlette/websockets.py
tests/test_websockets.py

index 43406aceda1f559c7f8e55e2f5b6c78c5223fb24..1128bce43f73231f7e6533daf057c7e503ee2036 100644 (file)
@@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da
 
 ### Closing the connection
 
-* `await websocket.close(code=1000)`
+* `await websocket.close(code=1000, reason=None)`
 
 ### Sending and receiving messages
 
index 0b4bc78d1273f87fcc643c2d4064155b263863fe..c951767b46ea3c533e7d73b4e7e627d7f605ec0b 100644 (file)
@@ -352,7 +352,9 @@ class WebSocketTestSession:
 
     def _raise_on_close(self, message: Message) -> None:
         if message["type"] == "websocket.close":
-            raise WebSocketDisconnect(message.get("code", 1000))
+            raise WebSocketDisconnect(
+                message.get("code", 1000), message.get("reason", "")
+            )
 
     def send(self, message: Message) -> None:
         self._receive_queue.put(message)
index bf4cca83fa3c4998b1580110ad05ac0f893805cd..da7406047bdecf1c2ddadd1d7bd3f2f2320c51f0 100644 (file)
@@ -13,8 +13,9 @@ class WebSocketState(enum.Enum):
 
 
 class WebSocketDisconnect(Exception):
-    def __init__(self, code: int = 1000) -> None:
+    def __init__(self, code: int = 1000, reason: str = None) -> None:
         self.code = code
+        self.reason = reason or ""
 
 
 class WebSocket(HTTPConnection):
@@ -146,13 +147,18 @@ class WebSocket(HTTPConnection):
         else:
             await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
 
-    async def close(self, code: int = 1000) -> None:
-        await self.send({"type": "websocket.close", "code": code})
+    async def close(self, code: int = 1000, reason: str = None) -> None:
+        await self.send(
+            {"type": "websocket.close", "code": code, "reason": reason or ""}
+        )
 
 
 class WebSocketClose:
-    def __init__(self, code: int = 1000) -> None:
+    def __init__(self, code: int = 1000, reason: str = None) -> None:
         self.code = code
+        self.reason = reason or ""
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        await send({"type": "websocket.close", "code": self.code})
+        await send(
+            {"type": "websocket.close", "code": self.code, "reason": self.reason}
+        )
index f3242d115eb8c45bc33ea24dd04b6c8ef974010e..b11685cbc3ba526c27435448dcc4d143c80e1be6 100644 (file)
@@ -405,3 +405,20 @@ def test_websocket_scope_interface():
     assert websocket == websocket
     assert websocket in {websocket}
     assert {websocket} == {websocket}
+
+
+def test_websocket_close_reason(test_client_factory) -> None:
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
+
+        return asgi
+
+    client = test_client_factory(app)
+    with client.websocket_connect("/") as websocket:
+        with pytest.raises(WebSocketDisconnect) as exc:
+            websocket.receive_text()
+        assert exc.value.code == status.WS_1001_GOING_AWAY
+        assert exc.value.reason == "Going Away"