]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Support the WebSocket Denial Response ASGI extension (#2041)
authorKristján Valur Jónsson <sweskman@gmail.com>
Sun, 4 Feb 2024 20:16:10 +0000 (20:16 +0000)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 20:16:10 +0000 (20:16 +0000)
* supply asgi_extensions to TestClient

* Add WebSocket.send_response()

* Add response support for WebSocket testclient

* fix test for filesystem line-endings

* lintint

* support websocket.http.response extension by default

* Improve coverate

* Apply suggestions from code review

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
* Undo unrelated change

* fix incorrect error message

* Update starlette/websockets.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
* formatting

* Re-introduce close-code and close-reason to WebSocketReject

* Make sure the "websocket.connect" message is received in tests

* Deliver a websocket.disconnect message to the app even if it closes/rejects itself.

* Add test for filling out missing `websocket.disconnect` code

* Add rejection headers.  Expand tests.

* Fix types, headers in message are `bytes` tuples.

* Minimal WebSocket Denial Response implementation

* Revert "Minimal WebSocket Denial Response implementation"

This reverts commit 7af10ddcfa5423c18953cf5d1317cb5aa30a014c.

* Rename to send_denial_response and update documentation

* Remove the app_disconnect_msg.  This can be added later in a separate PR

* Remove status code 1005 from this PR

* Assume that the application has tested for the extension before sending websocket.http.response.start

* Rename WebSocketReject to WebSocketDenialResponse

* Remove code and status from WebSocketDenialResponse.
Just send a regular WebSocketDisconnect even when connection is rejected with close()

* Raise an exception if attempting to send a http response and server does not support it.

* WebSocketDenialClose and WebSocketDenialResponse
These are both instances of WebSocketDenial.

* Update starlette/testclient.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
* Revert "WebSocketDenialClose and WebSocketDenialResponse"

This reverts commit 71b76e3f1c87064fe8458ff9d4ad0b242cbf15e7.

* Rename parameters, member variables

* Use httpx.Response as the base for WebSocketDenialResponse.

* Apply suggestions from code review

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
* Update sanity check message

* Remove un-needed function

* Expand error message test regex

* Add type hings to test methods

* Add doc string to test.

* Fix mypy complaining about mismatching parent methods.

* nitpick & remove test

* Simplify the documentation

* Update starlette/testclient.py

* Update starlette/testclient.py

* Remove an unnecessary test

* there is no special "close because of rejection" in the testclient anymore.

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
docs/websockets.md
starlette/responses.py
starlette/testclient.py
starlette/websockets.py
tests/test_websockets.py

index 6cf4e08e0ed7f9b7e0a1399e38247c110871a111..92c53cc0d401d5c811b31f0976aac52c8d35b3c3 100644 (file)
@@ -109,3 +109,17 @@ correctly updated.
 
 * `await websocket.send(message)`
 * `await websocket.receive()`
+
+### Send Denial Response
+
+If you call `websocket.close()` before calling `websocket.accept()` then
+the server will automatically send a HTTP 403 error to the client.
+
+If you want to send a different error response, you can use the
+`websocket.send_denial_response()` method. This will send the response
+and then close the connection.
+
+* `await websocket.send_denial_response(response)`
+
+This requires the ASGI server to support the WebSocket Denial Response
+extension. If it is not supported a `RuntimeError` will be raised.
index e613e98b223a45e94dd7e807a695a591f3e72e5a..15292f0e7075dd67093587bda3062e160f76517e 100644 (file)
@@ -148,14 +148,15 @@ class Response:
         )
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        prefix = "websocket." if scope["type"] == "websocket" else ""
         await send(
             {
-                "type": "http.response.start",
+                "type": prefix + "http.response.start",
                 "status": self.status_code,
                 "headers": self.raw_headers,
             }
         )
-        await send({"type": "http.response.body", "body": self.body})
+        await send({"type": prefix + "http.response.body", "body": self.body})
 
         if self.background is not None:
             await self.background()
index 9e3ece4d9dbe80b20812dfe9dcb6b9eaffa84801..90eb53e3d43ec4ffd2a3fa5e84b1705dc29ece0a 100644 (file)
@@ -78,6 +78,16 @@ class _Upgrade(Exception):
         self.session = session
 
 
+class WebSocketDenialResponse(  # type: ignore[misc]
+    httpx.Response,
+    WebSocketDisconnect,
+):
+    """
+    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
+    `WebSocket` is closed before being accepted with a `send_denial_response()`.
+    """
+
+
 class WebSocketTestSession:
     def __init__(
         self,
@@ -159,7 +169,22 @@ class WebSocketTestSession:
     def _raise_on_close(self, message: Message) -> None:
         if message["type"] == "websocket.close":
             raise WebSocketDisconnect(
-                message.get("code", 1000), message.get("reason", "")
+                code=message.get("code", 1000), reason=message.get("reason", "")
+            )
+        elif message["type"] == "websocket.http.response.start":
+            status_code: int = message["status"]
+            headers: list[tuple[bytes, bytes]] = message["headers"]
+            body: list[bytes] = []
+            while True:
+                message = self.receive()
+                assert message["type"] == "websocket.http.response.body"
+                body.append(message["body"])
+                if not message.get("more_body", False):
+                    break
+            raise WebSocketDenialResponse(
+                status_code=status_code,
+                headers=headers,
+                content=b"".join(body),
             )
 
     def send(self, message: Message) -> None:
@@ -277,6 +302,7 @@ class _TestClientTransport(httpx.BaseTransport):
                 "server": [host, port],
                 "subprotocols": subprotocols,
                 "state": self.app_state.copy(),
+                "extensions": {"websocket.http.response": {}},
             }
             session = WebSocketTestSession(self.app, scope, self.portal_factory)
             raise _Upgrade(session)
index 850fbf1154033ec58de2775dfdd5d1f0f87e18c4..955063fa179d5bb859e81f256a15540d35a6d43c 100644 (file)
@@ -5,6 +5,7 @@ import json
 import typing
 
 from starlette.requests import HTTPConnection
+from starlette.responses import Response
 from starlette.types import Message, Receive, Scope, Send
 
 
@@ -12,6 +13,7 @@ class WebSocketState(enum.Enum):
     CONNECTING = 0
     CONNECTED = 1
     DISCONNECTED = 2
+    RESPONSE = 3
 
 
 class WebSocketDisconnect(Exception):
@@ -65,13 +67,20 @@ class WebSocket(HTTPConnection):
         """
         if self.application_state == WebSocketState.CONNECTING:
             message_type = message["type"]
-            if message_type not in {"websocket.accept", "websocket.close"}:
+            if message_type not in {
+                "websocket.accept",
+                "websocket.close",
+                "websocket.http.response.start",
+            }:
                 raise RuntimeError(
-                    'Expected ASGI message "websocket.accept" or '
-                    f'"websocket.close", but got {message_type!r}'
+                    'Expected ASGI message "websocket.accept",'
+                    '"websocket.close" or "websocket.http.response.start",'
+                    f"but got {message_type!r}"
                 )
             if message_type == "websocket.close":
                 self.application_state = WebSocketState.DISCONNECTED
+            elif message_type == "websocket.http.response.start":
+                self.application_state = WebSocketState.RESPONSE
             else:
                 self.application_state = WebSocketState.CONNECTED
             await self._send(message)
@@ -89,6 +98,16 @@ class WebSocket(HTTPConnection):
             except IOError:
                 self.application_state = WebSocketState.DISCONNECTED
                 raise WebSocketDisconnect(code=1006)
+        elif self.application_state == WebSocketState.RESPONSE:
+            message_type = message["type"]
+            if message_type != "websocket.http.response.body":
+                raise RuntimeError(
+                    'Expected ASGI message "websocket.http.response.body", '
+                    f"but got {message_type!r}"
+                )
+            if not message.get("more_body", False):
+                self.application_state = WebSocketState.DISCONNECTED
+            await self._send(message)
         else:
             raise RuntimeError('Cannot call "send" once a close message has been sent.')
 
@@ -185,6 +204,14 @@ class WebSocket(HTTPConnection):
             {"type": "websocket.close", "code": code, "reason": reason or ""}
         )
 
+    async def send_denial_response(self, response: Response) -> None:
+        if "websocket.http.response" in self.scope.get("extensions", {}):
+            await response(self.scope, self.receive, self.send)
+        else:
+            raise RuntimeError(
+                "The server doesn't support the Websocket Denial Response extension."
+            )
+
 
 class WebSocketClose:
     def __init__(self, code: int = 1000, reason: str | None = None) -> None:
index 247477404c682fa73023e8594c4c8c5da87c90a6..c8bfc02aa09f5ce344d34f65b9c96cb0711ad2fa 100644 (file)
@@ -6,7 +6,8 @@ import pytest
 from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette import status
-from starlette.testclient import TestClient
+from starlette.responses import Response
+from starlette.testclient import TestClient, WebSocketDenialResponse
 from starlette.types import Message, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
 
@@ -293,6 +294,8 @@ def test_application_close(test_client_factory: Callable[..., TestClient]):
 def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
+        msg = await websocket.receive()
+        assert msg == {"type": "websocket.connect"}
         await websocket.close(status.WS_1001_GOING_AWAY)
 
     client = test_client_factory(app)
@@ -302,6 +305,111 @@ def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
     assert exc.value.code == status.WS_1001_GOING_AWAY
 
 
+def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        msg = await websocket.receive()
+        assert msg == {"type": "websocket.connect"}
+        response = Response(status_code=404, content="foo")
+        await websocket.send_denial_response(response)
+
+    client = test_client_factory(app)
+    with pytest.raises(WebSocketDenialResponse) as exc:
+        with client.websocket_connect("/"):
+            pass  # pragma: no cover
+    assert exc.value.status_code == 404
+    assert exc.value.content == b"foo"
+
+
+def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        msg = await websocket.receive()
+        assert msg == {"type": "websocket.connect"}
+        await websocket.send(
+            {
+                "type": "websocket.http.response.start",
+                "status": 404,
+                "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
+            }
+        )
+        await websocket.send(
+            {
+                "type": "websocket.http.response.body",
+                "body": b"hard",
+                "more_body": True,
+            }
+        )
+        await websocket.send(
+            {
+                "type": "websocket.http.response.body",
+                "body": b"body",
+            }
+        )
+
+    client = test_client_factory(app)
+    with pytest.raises(WebSocketDenialResponse) as exc:
+        with client.websocket_connect("/"):
+            pass  # pragma: no cover
+    assert exc.value.status_code == 404
+    assert exc.value.content == b"hardbody"
+    assert exc.value.headers["foo"] == "bar"
+
+
+def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        del scope["extensions"]["websocket.http.response"]
+        websocket = WebSocket(scope, receive=receive, send=send)
+        msg = await websocket.receive()
+        assert msg == {"type": "websocket.connect"}
+        response = Response(status_code=404, content="foo")
+        with pytest.raises(
+            RuntimeError,
+            match="The server doesn't support the Websocket Denial Response extension.",
+        ):
+            await websocket.send_denial_response(response)
+        await websocket.close()
+
+    client = test_client_factory(app)
+    with pytest.raises(WebSocketDisconnect) as exc:
+        with client.websocket_connect("/"):
+            pass  # pragma: no cover
+    assert exc.value.code == status.WS_1000_NORMAL_CLOSURE
+
+
+def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        msg = await websocket.receive()
+        assert msg == {"type": "websocket.connect"}
+        response = Response(status_code=404, content="foo")
+        await websocket.send(
+            {
+                "type": "websocket.http.response.start",
+                "status": response.status_code,
+                "headers": response.raw_headers,
+            }
+        )
+        await websocket.send(
+            {
+                "type": "websocket.http.response.start",
+                "status": response.status_code,
+                "headers": response.raw_headers,
+            }
+        )
+
+    client = test_client_factory(app)
+    with pytest.raises(
+        RuntimeError,
+        match=(
+            'Expected ASGI message "websocket.http.response.body", but got '
+            "'websocket.http.response.start'"
+        ),
+    ):
+        with client.websocket_connect("/"):
+            pass  # pragma: no cover
+
+
 def test_subprotocol(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)