]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Raise `WebSocketDisconnect` when `WebSocket.send()` excepts `IOError` (#2425)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sat, 20 Jan 2024 15:08:22 +0000 (16:08 +0100)
committerGitHub <noreply@github.com>
Sat, 20 Jan 2024 15:08:22 +0000 (08:08 -0700)
* Raise `WebSocketDisconnect` when `WebSocket.send()` excepts `IOError`

* Restrict the IOError

starlette/websockets.py
tests/test_websockets.py

index a34bc13391c9bf0f8546a6d186481f2958779c4b..084d9309443f0fbf7958fd95bd2dc78e84bdd75d 100644 (file)
@@ -82,7 +82,11 @@ class WebSocket(HTTPConnection):
                 )
             if message_type == "websocket.close":
                 self.application_state = WebSocketState.DISCONNECTED
-            await self._send(message)
+            try:
+                await self._send(message)
+            except IOError:
+                self.application_state = WebSocketState.DISCONNECTED
+                raise WebSocketDisconnect(code=1006)
         else:
             raise RuntimeError('Cannot call "send" once a close message has been sent.')
 
index 283dcfc78aa964b7eb7f6ea95b2af5cdcd415732..247477404c682fa73023e8594c4c8c5da87c90a6 100644 (file)
@@ -255,6 +255,28 @@ def test_client_close(test_client_factory: Callable[..., TestClient]):
     assert close_reason == "Going Away"
 
 
+@pytest.mark.anyio
+async def test_client_disconnect_on_send():
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.send_text("Hello, world!")
+
+    async def receive() -> Message:
+        return {"type": "websocket.connect"}
+
+    async def send(message: Message) -> None:
+        if message["type"] == "websocket.accept":
+            return
+        # Simulate the exception the server would send to the application when the
+        # client disconnects.
+        raise IOError
+
+    with pytest.raises(WebSocketDisconnect) as ctx:
+        await app({"type": "websocket", "path": "/"}, receive, send)
+    assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE
+
+
 def test_application_close(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)