if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
- assert message_type == "websocket.connect"
+ if message_type != "websocket.connect":
+ raise RuntimeError(
+ 'Expected ASGI message "websocket.connect", '
+ f"but got {message_type!r}"
+ )
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
message = await self._receive()
message_type = message["type"]
- assert message_type in {"websocket.receive", "websocket.disconnect"}
+ if message_type not in {"websocket.receive", "websocket.disconnect"}:
+ raise RuntimeError(
+ 'Expected ASGI message "websocket.receive" or '
+ f'"websocket.disconnect", but got {message_type!r}'
+ )
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
- assert message_type in {"websocket.accept", "websocket.close"}
+ if message_type not in {"websocket.accept", "websocket.close"}:
+ raise RuntimeError(
+ 'Expected ASGI message "websocket.connect", '
+ f"but got {message_type!r}"
+ )
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
else:
await self._send(message)
elif self.application_state == WebSocketState.CONNECTED:
message_type = message["type"]
- assert message_type in {"websocket.send", "websocket.close"}
+ if message_type not in {"websocket.send", "websocket.close"}:
+ raise RuntimeError(
+ 'Expected ASGI message "websocket.send" or "websocket.close", '
+ f"but got {message_type!r}"
+ )
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
raise WebSocketDisconnect(message["code"])
async def receive_text(self) -> str:
- assert self.application_state == WebSocketState.CONNECTED
+ if self.application_state != WebSocketState.CONNECTED:
+ raise RuntimeError(
+ 'WebSocket is not connected. Need to call "accept" first.'
+ )
message = await self.receive()
self._raise_on_disconnect(message)
return message["text"]
async def receive_bytes(self) -> bytes:
- assert self.application_state == WebSocketState.CONNECTED
+ if self.application_state != WebSocketState.CONNECTED:
+ raise RuntimeError(
+ 'WebSocket is not connected. Need to call "accept" first.'
+ )
message = await self.receive()
self._raise_on_disconnect(message)
return message["bytes"]
async def receive_json(self, mode: str = "text") -> typing.Any:
- assert mode in ["text", "binary"]
- assert self.application_state == WebSocketState.CONNECTED
+ if mode not in {"text", "binary"}:
+ raise RuntimeError('The "mode" argument should be "text" or "binary".')
+ if self.application_state != WebSocketState.CONNECTED:
+ raise RuntimeError(
+ 'WebSocket is not connected. Need to call "accept" first.'
+ )
message = await self.receive()
self._raise_on_disconnect(message)
await self.send({"type": "websocket.send", "bytes": data})
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
- assert mode in ["text", "binary"]
+ if mode not in {"text", "binary"}:
+ raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
import pytest
from starlette import status
-from starlette.websockets import WebSocket, WebSocketDisconnect
+from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
def test_websocket_url(test_client_factory):
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Going Away"
+
+
+def test_send_json_invalid_mode(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.send_json({}, mode="invalid")
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_receive_json_invalid_mode(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.receive_json(mode="invalid")
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_receive_text_before_accept(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.receive_text()
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_receive_bytes_before_accept(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.receive_bytes()
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_receive_json_before_accept(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.receive_json()
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_send_before_accept(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.send({"type": "websocket.send"})
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_send_wrong_message_type(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.send({"type": "websocket.accept"})
+ await websocket.send({"type": "websocket.accept"})
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
+
+
+def test_receive_before_accept(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ websocket.client_state = WebSocketState.CONNECTING
+ await websocket.receive()
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/") as websocket:
+ websocket.send({"type": "websocket.send"})
+
+
+def test_receive_wrong_message_type(test_client_factory):
+ def app(scope):
+ async def asgi(receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.receive()
+
+ return asgi
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError):
+ with client.websocket_connect("/") as websocket:
+ websocket.send({"type": "websocket.connect"})