From: Kristján Valur Jónsson Date: Sun, 4 Feb 2024 20:16:10 +0000 (+0000) Subject: Support the WebSocket Denial Response ASGI extension (#2041) X-Git-Tag: 0.37.0~3 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=93e74a4d2f171bc48c66133d04b071c28ea63562;p=thirdparty%2Fstarlette.git Support the WebSocket Denial Response ASGI extension (#2041) * supply asgi_extensions to TestClient * Add WebSocket.send_response() * Add response support for WebSocket testclient * fix test for filesystem line-endings * lintint * support websocket.http.response extension by default * Improve coverate * Apply suggestions from code review Co-authored-by: Marcelo Trylesinski * Undo unrelated change * fix incorrect error message * Update starlette/websockets.py Co-authored-by: Marcelo Trylesinski * formatting * Re-introduce close-code and close-reason to WebSocketReject * Make sure the "websocket.connect" message is received in tests * Deliver a websocket.disconnect message to the app even if it closes/rejects itself. * Add test for filling out missing `websocket.disconnect` code * Add rejection headers. Expand tests. * Fix types, headers in message are `bytes` tuples. * Minimal WebSocket Denial Response implementation * Revert "Minimal WebSocket Denial Response implementation" This reverts commit 7af10ddcfa5423c18953cf5d1317cb5aa30a014c. * Rename to send_denial_response and update documentation * Remove the app_disconnect_msg. This can be added later in a separate PR * Remove status code 1005 from this PR * Assume that the application has tested for the extension before sending websocket.http.response.start * Rename WebSocketReject to WebSocketDenialResponse * Remove code and status from WebSocketDenialResponse. Just send a regular WebSocketDisconnect even when connection is rejected with close() * Raise an exception if attempting to send a http response and server does not support it. * WebSocketDenialClose and WebSocketDenialResponse These are both instances of WebSocketDenial. * Update starlette/testclient.py Co-authored-by: Marcelo Trylesinski * Revert "WebSocketDenialClose and WebSocketDenialResponse" This reverts commit 71b76e3f1c87064fe8458ff9d4ad0b242cbf15e7. * Rename parameters, member variables * Use httpx.Response as the base for WebSocketDenialResponse. * Apply suggestions from code review Co-authored-by: Marcelo Trylesinski * Update sanity check message * Remove un-needed function * Expand error message test regex * Add type hings to test methods * Add doc string to test. * Fix mypy complaining about mismatching parent methods. * nitpick & remove test * Simplify the documentation * Update starlette/testclient.py * Update starlette/testclient.py * Remove an unnecessary test * there is no special "close because of rejection" in the testclient anymore. --------- Co-authored-by: Marcelo Trylesinski --- diff --git a/docs/websockets.md b/docs/websockets.md index 6cf4e08e..92c53cc0 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -109,3 +109,17 @@ correctly updated. * `await websocket.send(message)` * `await websocket.receive()` + +### Send Denial Response + +If you call `websocket.close()` before calling `websocket.accept()` then +the server will automatically send a HTTP 403 error to the client. + +If you want to send a different error response, you can use the +`websocket.send_denial_response()` method. This will send the response +and then close the connection. + +* `await websocket.send_denial_response(response)` + +This requires the ASGI server to support the WebSocket Denial Response +extension. If it is not supported a `RuntimeError` will be raised. diff --git a/starlette/responses.py b/starlette/responses.py index e613e98b..15292f0e 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -148,14 +148,15 @@ class Response: ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + prefix = "websocket." if scope["type"] == "websocket" else "" await send( { - "type": "http.response.start", + "type": prefix + "http.response.start", "status": self.status_code, "headers": self.raw_headers, } ) - await send({"type": "http.response.body", "body": self.body}) + await send({"type": prefix + "http.response.body", "body": self.body}) if self.background is not None: await self.background() diff --git a/starlette/testclient.py b/starlette/testclient.py index 9e3ece4d..90eb53e3 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -78,6 +78,16 @@ class _Upgrade(Exception): self.session = session +class WebSocketDenialResponse( # type: ignore[misc] + httpx.Response, + WebSocketDisconnect, +): + """ + A special case of `WebSocketDisconnect`, raised in the `TestClient` if the + `WebSocket` is closed before being accepted with a `send_denial_response()`. + """ + + class WebSocketTestSession: def __init__( self, @@ -159,7 +169,22 @@ class WebSocketTestSession: def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": raise WebSocketDisconnect( - message.get("code", 1000), message.get("reason", "") + code=message.get("code", 1000), reason=message.get("reason", "") + ) + elif message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + raise WebSocketDenialResponse( + status_code=status_code, + headers=headers, + content=b"".join(body), ) def send(self, message: Message) -> None: @@ -277,6 +302,7 @@ class _TestClientTransport(httpx.BaseTransport): "server": [host, port], "subprotocols": subprotocols, "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) diff --git a/starlette/websockets.py b/starlette/websockets.py index 850fbf11..955063fa 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -5,6 +5,7 @@ import json import typing from starlette.requests import HTTPConnection +from starlette.responses import Response from starlette.types import Message, Receive, Scope, Send @@ -12,6 +13,7 @@ class WebSocketState(enum.Enum): CONNECTING = 0 CONNECTED = 1 DISCONNECTED = 2 + RESPONSE = 3 class WebSocketDisconnect(Exception): @@ -65,13 +67,20 @@ class WebSocket(HTTPConnection): """ if self.application_state == WebSocketState.CONNECTING: message_type = message["type"] - if message_type not in {"websocket.accept", "websocket.close"}: + if message_type not in { + "websocket.accept", + "websocket.close", + "websocket.http.response.start", + }: raise RuntimeError( - 'Expected ASGI message "websocket.accept" or ' - f'"websocket.close", but got {message_type!r}' + 'Expected ASGI message "websocket.accept",' + '"websocket.close" or "websocket.http.response.start",' + f"but got {message_type!r}" ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED + elif message_type == "websocket.http.response.start": + self.application_state = WebSocketState.RESPONSE else: self.application_state = WebSocketState.CONNECTED await self._send(message) @@ -89,6 +98,16 @@ class WebSocket(HTTPConnection): except IOError: self.application_state = WebSocketState.DISCONNECTED raise WebSocketDisconnect(code=1006) + elif self.application_state == WebSocketState.RESPONSE: + message_type = message["type"] + if message_type != "websocket.http.response.body": + raise RuntimeError( + 'Expected ASGI message "websocket.http.response.body", ' + f"but got {message_type!r}" + ) + if not message.get("more_body", False): + self.application_state = WebSocketState.DISCONNECTED + await self._send(message) else: raise RuntimeError('Cannot call "send" once a close message has been sent.') @@ -185,6 +204,14 @@ class WebSocket(HTTPConnection): {"type": "websocket.close", "code": code, "reason": reason or ""} ) + async def send_denial_response(self, response: Response) -> None: + if "websocket.http.response" in self.scope.get("extensions", {}): + await response(self.scope, self.receive, self.send) + else: + raise RuntimeError( + "The server doesn't support the Websocket Denial Response extension." + ) + class WebSocketClose: def __init__(self, code: int = 1000, reason: str | None = None) -> None: diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 24747740..c8bfc02a 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -6,7 +6,8 @@ import pytest from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette import status -from starlette.testclient import TestClient +from starlette.responses import Response +from starlette.testclient import TestClient, WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState @@ -293,6 +294,8 @@ def test_application_close(test_client_factory: Callable[..., TestClient]): def test_rejected_connection(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} await websocket.close(status.WS_1001_GOING_AWAY) client = test_client_factory(app) @@ -302,6 +305,111 @@ def test_rejected_connection(test_client_factory: Callable[..., TestClient]): assert exc.value.code == status.WS_1001_GOING_AWAY +def test_send_denial_response(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} + response = Response(status_code=404, content="foo") + await websocket.send_denial_response(response) + + client = test_client_factory(app) + with pytest.raises(WebSocketDenialResponse) as exc: + with client.websocket_connect("/"): + pass # pragma: no cover + assert exc.value.status_code == 404 + assert exc.value.content == b"foo" + + +def test_send_response_multi(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} + await websocket.send( + { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")], + } + ) + await websocket.send( + { + "type": "websocket.http.response.body", + "body": b"hard", + "more_body": True, + } + ) + await websocket.send( + { + "type": "websocket.http.response.body", + "body": b"body", + } + ) + + client = test_client_factory(app) + with pytest.raises(WebSocketDenialResponse) as exc: + with client.websocket_connect("/"): + pass # pragma: no cover + assert exc.value.status_code == 404 + assert exc.value.content == b"hardbody" + assert exc.value.headers["foo"] == "bar" + + +def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + del scope["extensions"]["websocket.http.response"] + websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} + response = Response(status_code=404, content="foo") + with pytest.raises( + RuntimeError, + match="The server doesn't support the Websocket Denial Response extension.", + ): + await websocket.send_denial_response(response) + await websocket.close() + + client = test_client_factory(app) + with pytest.raises(WebSocketDisconnect) as exc: + with client.websocket_connect("/"): + pass # pragma: no cover + assert exc.value.code == status.WS_1000_NORMAL_CLOSURE + + +def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} + response = Response(status_code=404, content="foo") + await websocket.send( + { + "type": "websocket.http.response.start", + "status": response.status_code, + "headers": response.raw_headers, + } + ) + await websocket.send( + { + "type": "websocket.http.response.start", + "status": response.status_code, + "headers": response.raw_headers, + } + ) + + client = test_client_factory(app) + with pytest.raises( + RuntimeError, + match=( + 'Expected ASGI message "websocket.http.response.body", but got ' + "'websocket.http.response.start'" + ), + ): + with client.websocket_connect("/"): + pass # pragma: no cover + + def test_subprotocol(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send)