From: Amin Alaee Date: Fri, 4 Feb 2022 10:51:04 +0000 (+0100) Subject: Replace WebSocket assertions with RuntimeError (#1472) X-Git-Tag: 0.19.0~41 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=424351cb231c67798a65c091b0b7d42790f5e444;p=thirdparty%2Fstarlette.git Replace WebSocket assertions with RuntimeError (#1472) Co-authored-by: Tom Christie Co-authored-by: Marcelo Trylesinski --- diff --git a/starlette/websockets.py b/starlette/websockets.py index da740604..03ed1997 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -34,13 +34,21 @@ class WebSocket(HTTPConnection): if self.client_state == WebSocketState.CONNECTING: message = await self._receive() message_type = message["type"] - assert message_type == "websocket.connect" + if message_type != "websocket.connect": + raise RuntimeError( + 'Expected ASGI message "websocket.connect", ' + f"but got {message_type!r}" + ) self.client_state = WebSocketState.CONNECTED return message elif self.client_state == WebSocketState.CONNECTED: message = await self._receive() message_type = message["type"] - assert message_type in {"websocket.receive", "websocket.disconnect"} + if message_type not in {"websocket.receive", "websocket.disconnect"}: + raise RuntimeError( + 'Expected ASGI message "websocket.receive" or ' + f'"websocket.disconnect", but got {message_type!r}' + ) if message_type == "websocket.disconnect": self.client_state = WebSocketState.DISCONNECTED return message @@ -55,7 +63,11 @@ class WebSocket(HTTPConnection): """ if self.application_state == WebSocketState.CONNECTING: message_type = message["type"] - assert message_type in {"websocket.accept", "websocket.close"} + if message_type not in {"websocket.accept", "websocket.close"}: + raise RuntimeError( + 'Expected ASGI message "websocket.connect", ' + f"but got {message_type!r}" + ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED else: @@ -63,7 +75,11 @@ class WebSocket(HTTPConnection): await self._send(message) elif self.application_state == WebSocketState.CONNECTED: message_type = message["type"] - assert message_type in {"websocket.send", "websocket.close"} + if message_type not in {"websocket.send", "websocket.close"}: + raise RuntimeError( + 'Expected ASGI message "websocket.send" or "websocket.close", ' + f"but got {message_type!r}" + ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED await self._send(message) @@ -89,20 +105,30 @@ class WebSocket(HTTPConnection): raise WebSocketDisconnect(message["code"]) async def receive_text(self) -> str: - assert self.application_state == WebSocketState.CONNECTED + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) message = await self.receive() self._raise_on_disconnect(message) return message["text"] async def receive_bytes(self) -> bytes: - assert self.application_state == WebSocketState.CONNECTED + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) message = await self.receive() self._raise_on_disconnect(message) return message["bytes"] async def receive_json(self, mode: str = "text") -> typing.Any: - assert mode in ["text", "binary"] - assert self.application_state == WebSocketState.CONNECTED + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) message = await self.receive() self._raise_on_disconnect(message) @@ -140,7 +166,8 @@ class WebSocket(HTTPConnection): await self.send({"type": "websocket.send", "bytes": data}) async def send_json(self, data: typing.Any, mode: str = "text") -> None: - assert mode in ["text", "binary"] + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') text = json.dumps(data) if mode == "text": await self.send({"type": "websocket.send", "text": text}) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index b11685cb..e3a52762 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -2,7 +2,7 @@ import anyio import pytest from starlette import status -from starlette.websockets import WebSocket, WebSocketDisconnect +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState def test_websocket_url(test_client_factory): @@ -422,3 +422,135 @@ def test_websocket_close_reason(test_client_factory) -> None: websocket.receive_text() assert exc.value.code == status.WS_1001_GOING_AWAY assert exc.value.reason == "Going Away" + + +def test_send_json_invalid_mode(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_json({}, mode="invalid") + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_json_invalid_mode(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.receive_json(mode="invalid") + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_text_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.receive_text() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_bytes_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.receive_bytes() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_json_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.receive_json() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_send_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.send({"type": "websocket.send"}) + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_send_wrong_message_type(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.send({"type": "websocket.accept"}) + await websocket.send({"type": "websocket.accept"}) + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + websocket.client_state = WebSocketState.CONNECTING + await websocket.receive() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/") as websocket: + websocket.send({"type": "websocket.send"}) + + +def test_receive_wrong_message_type(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.receive() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/") as websocket: + websocket.send({"type": "websocket.connect"})