]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Replace WebSocket assertions with RuntimeError (#1472)
authorAmin Alaee <mohammadamin.alaee@gmail.com>
Fri, 4 Feb 2022 10:51:04 +0000 (11:51 +0100)
committerGitHub <noreply@github.com>
Fri, 4 Feb 2022 10:51:04 +0000 (11:51 +0100)
Co-authored-by: Tom Christie <tom@tomchristie.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/websockets.py
tests/test_websockets.py

index da7406047bdecf1c2ddadd1d7bd3f2f2320c51f0..03ed19972aada03a2accd8eff70bc631184aa409 100644 (file)
@@ -34,13 +34,21 @@ class WebSocket(HTTPConnection):
         if self.client_state == WebSocketState.CONNECTING:
             message = await self._receive()
             message_type = message["type"]
-            assert message_type == "websocket.connect"
+            if message_type != "websocket.connect":
+                raise RuntimeError(
+                    'Expected ASGI message "websocket.connect", '
+                    f"but got {message_type!r}"
+                )
             self.client_state = WebSocketState.CONNECTED
             return message
         elif self.client_state == WebSocketState.CONNECTED:
             message = await self._receive()
             message_type = message["type"]
-            assert message_type in {"websocket.receive", "websocket.disconnect"}
+            if message_type not in {"websocket.receive", "websocket.disconnect"}:
+                raise RuntimeError(
+                    'Expected ASGI message "websocket.receive" or '
+                    f'"websocket.disconnect", but got {message_type!r}'
+                )
             if message_type == "websocket.disconnect":
                 self.client_state = WebSocketState.DISCONNECTED
             return message
@@ -55,7 +63,11 @@ class WebSocket(HTTPConnection):
         """
         if self.application_state == WebSocketState.CONNECTING:
             message_type = message["type"]
-            assert message_type in {"websocket.accept", "websocket.close"}
+            if message_type not in {"websocket.accept", "websocket.close"}:
+                raise RuntimeError(
+                    'Expected ASGI message "websocket.connect", '
+                    f"but got {message_type!r}"
+                )
             if message_type == "websocket.close":
                 self.application_state = WebSocketState.DISCONNECTED
             else:
@@ -63,7 +75,11 @@ class WebSocket(HTTPConnection):
             await self._send(message)
         elif self.application_state == WebSocketState.CONNECTED:
             message_type = message["type"]
-            assert message_type in {"websocket.send", "websocket.close"}
+            if message_type not in {"websocket.send", "websocket.close"}:
+                raise RuntimeError(
+                    'Expected ASGI message "websocket.send" or "websocket.close", '
+                    f"but got {message_type!r}"
+                )
             if message_type == "websocket.close":
                 self.application_state = WebSocketState.DISCONNECTED
             await self._send(message)
@@ -89,20 +105,30 @@ class WebSocket(HTTPConnection):
             raise WebSocketDisconnect(message["code"])
 
     async def receive_text(self) -> str:
-        assert self.application_state == WebSocketState.CONNECTED
+        if self.application_state != WebSocketState.CONNECTED:
+            raise RuntimeError(
+                'WebSocket is not connected. Need to call "accept" first.'
+            )
         message = await self.receive()
         self._raise_on_disconnect(message)
         return message["text"]
 
     async def receive_bytes(self) -> bytes:
-        assert self.application_state == WebSocketState.CONNECTED
+        if self.application_state != WebSocketState.CONNECTED:
+            raise RuntimeError(
+                'WebSocket is not connected. Need to call "accept" first.'
+            )
         message = await self.receive()
         self._raise_on_disconnect(message)
         return message["bytes"]
 
     async def receive_json(self, mode: str = "text") -> typing.Any:
-        assert mode in ["text", "binary"]
-        assert self.application_state == WebSocketState.CONNECTED
+        if mode not in {"text", "binary"}:
+            raise RuntimeError('The "mode" argument should be "text" or "binary".')
+        if self.application_state != WebSocketState.CONNECTED:
+            raise RuntimeError(
+                'WebSocket is not connected. Need to call "accept" first.'
+            )
         message = await self.receive()
         self._raise_on_disconnect(message)
 
@@ -140,7 +166,8 @@ class WebSocket(HTTPConnection):
         await self.send({"type": "websocket.send", "bytes": data})
 
     async def send_json(self, data: typing.Any, mode: str = "text") -> None:
-        assert mode in ["text", "binary"]
+        if mode not in {"text", "binary"}:
+            raise RuntimeError('The "mode" argument should be "text" or "binary".')
         text = json.dumps(data)
         if mode == "text":
             await self.send({"type": "websocket.send", "text": text})
index b11685cbc3ba526c27435448dcc4d143c80e1be6..e3a52762a0dc7747830f2fcf7808f66897dd53f5 100644 (file)
@@ -2,7 +2,7 @@ import anyio
 import pytest
 
 from starlette import status
-from starlette.websockets import WebSocket, WebSocketDisconnect
+from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
 
 
 def test_websocket_url(test_client_factory):
@@ -422,3 +422,135 @@ def test_websocket_close_reason(test_client_factory) -> None:
             websocket.receive_text()
         assert exc.value.code == status.WS_1001_GOING_AWAY
         assert exc.value.reason == "Going Away"
+
+
+def test_send_json_invalid_mode(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            await websocket.send_json({}, mode="invalid")
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_receive_json_invalid_mode(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            await websocket.receive_json(mode="invalid")
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_receive_text_before_accept(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.receive_text()
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_receive_bytes_before_accept(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.receive_bytes()
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_receive_json_before_accept(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.receive_json()
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_send_before_accept(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.send({"type": "websocket.send"})
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_send_wrong_message_type(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.send({"type": "websocket.accept"})
+            await websocket.send({"type": "websocket.accept"})
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/"):
+            pass  # pragma: nocover
+
+
+def test_receive_before_accept(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            websocket.client_state = WebSocketState.CONNECTING
+            await websocket.receive()
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/") as websocket:
+            websocket.send({"type": "websocket.send"})
+
+
+def test_receive_wrong_message_type(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            await websocket.receive()
+
+        return asgi
+
+    client = test_client_factory(app)
+    with pytest.raises(RuntimeError):
+        with client.websocket_connect("/") as websocket:
+            websocket.send({"type": "websocket.connect"})