self.session = session
+class WebSocketDenialResponse( # type: ignore[misc]
+ httpx.Response,
+ WebSocketDisconnect,
+):
+ """
+ A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
+ `WebSocket` is closed before being accepted with a `send_denial_response()`.
+ """
+
+
class WebSocketTestSession:
def __init__(
self,
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(
- message.get("code", 1000), message.get("reason", "")
+ code=message.get("code", 1000), reason=message.get("reason", "")
+ )
+ elif message["type"] == "websocket.http.response.start":
+ status_code: int = message["status"]
+ headers: list[tuple[bytes, bytes]] = message["headers"]
+ body: list[bytes] = []
+ while True:
+ message = self.receive()
+ assert message["type"] == "websocket.http.response.body"
+ body.append(message["body"])
+ if not message.get("more_body", False):
+ break
+ raise WebSocketDenialResponse(
+ status_code=status_code,
+ headers=headers,
+ content=b"".join(body),
)
def send(self, message: Message) -> None:
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
+ "extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
import typing
from starlette.requests import HTTPConnection
+from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
+ RESPONSE = 3
class WebSocketDisconnect(Exception):
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
- if message_type not in {"websocket.accept", "websocket.close"}:
+ if message_type not in {
+ "websocket.accept",
+ "websocket.close",
+ "websocket.http.response.start",
+ }:
raise RuntimeError(
- 'Expected ASGI message "websocket.accept" or '
- f'"websocket.close", but got {message_type!r}'
+ 'Expected ASGI message "websocket.accept",'
+ '"websocket.close" or "websocket.http.response.start",'
+ f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
+ elif message_type == "websocket.http.response.start":
+ self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
except IOError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
+ elif self.application_state == WebSocketState.RESPONSE:
+ message_type = message["type"]
+ if message_type != "websocket.http.response.body":
+ raise RuntimeError(
+ 'Expected ASGI message "websocket.http.response.body", '
+ f"but got {message_type!r}"
+ )
+ if not message.get("more_body", False):
+ self.application_state = WebSocketState.DISCONNECTED
+ await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
{"type": "websocket.close", "code": code, "reason": reason or ""}
)
+ async def send_denial_response(self, response: Response) -> None:
+ if "websocket.http.response" in self.scope.get("extensions", {}):
+ await response(self.scope, self.receive, self.send)
+ else:
+ raise RuntimeError(
+ "The server doesn't support the Websocket Denial Response extension."
+ )
+
class WebSocketClose:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette import status
-from starlette.testclient import TestClient
+from starlette.responses import Response
+from starlette.testclient import TestClient, WebSocketDenialResponse
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
+ msg = await websocket.receive()
+ assert msg == {"type": "websocket.connect"}
await websocket.close(status.WS_1001_GOING_AWAY)
client = test_client_factory(app)
assert exc.value.code == status.WS_1001_GOING_AWAY
+def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ msg = await websocket.receive()
+ assert msg == {"type": "websocket.connect"}
+ response = Response(status_code=404, content="foo")
+ await websocket.send_denial_response(response)
+
+ client = test_client_factory(app)
+ with pytest.raises(WebSocketDenialResponse) as exc:
+ with client.websocket_connect("/"):
+ pass # pragma: no cover
+ assert exc.value.status_code == 404
+ assert exc.value.content == b"foo"
+
+
+def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ msg = await websocket.receive()
+ assert msg == {"type": "websocket.connect"}
+ await websocket.send(
+ {
+ "type": "websocket.http.response.start",
+ "status": 404,
+ "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
+ }
+ )
+ await websocket.send(
+ {
+ "type": "websocket.http.response.body",
+ "body": b"hard",
+ "more_body": True,
+ }
+ )
+ await websocket.send(
+ {
+ "type": "websocket.http.response.body",
+ "body": b"body",
+ }
+ )
+
+ client = test_client_factory(app)
+ with pytest.raises(WebSocketDenialResponse) as exc:
+ with client.websocket_connect("/"):
+ pass # pragma: no cover
+ assert exc.value.status_code == 404
+ assert exc.value.content == b"hardbody"
+ assert exc.value.headers["foo"] == "bar"
+
+
+def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ del scope["extensions"]["websocket.http.response"]
+ websocket = WebSocket(scope, receive=receive, send=send)
+ msg = await websocket.receive()
+ assert msg == {"type": "websocket.connect"}
+ response = Response(status_code=404, content="foo")
+ with pytest.raises(
+ RuntimeError,
+ match="The server doesn't support the Websocket Denial Response extension.",
+ ):
+ await websocket.send_denial_response(response)
+ await websocket.close()
+
+ client = test_client_factory(app)
+ with pytest.raises(WebSocketDisconnect) as exc:
+ with client.websocket_connect("/"):
+ pass # pragma: no cover
+ assert exc.value.code == status.WS_1000_NORMAL_CLOSURE
+
+
+def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ msg = await websocket.receive()
+ assert msg == {"type": "websocket.connect"}
+ response = Response(status_code=404, content="foo")
+ await websocket.send(
+ {
+ "type": "websocket.http.response.start",
+ "status": response.status_code,
+ "headers": response.raw_headers,
+ }
+ )
+ await websocket.send(
+ {
+ "type": "websocket.http.response.start",
+ "status": response.status_code,
+ "headers": response.raw_headers,
+ }
+ )
+
+ client = test_client_factory(app)
+ with pytest.raises(
+ RuntimeError,
+ match=(
+ 'Expected ASGI message "websocket.http.response.body", but got '
+ "'websocket.http.response.start'"
+ ),
+ ):
+ with client.websocket_connect("/"):
+ pass # pragma: no cover
+
+
def test_subprotocol(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)