import pytest
from starlette import status
+from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
def test_websocket_url(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.send_json({"url": str(websocket.url)})
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.send_json({"url": str(websocket.url)})
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/123?a=abc") as websocket:
def test_websocket_binary_json(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- message = await websocket.receive_json(mode="binary")
- await websocket.send_json(message, mode="binary")
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ message = await websocket.receive_json(mode="binary")
+ await websocket.send_json(message, mode="binary")
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/123?a=abc") as websocket:
def test_websocket_query_params(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- query_params = dict(websocket.query_params)
- await websocket.accept()
- await websocket.send_json({"params": query_params})
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ query_params = dict(websocket.query_params)
+ await websocket.accept()
+ await websocket.send_json({"params": query_params})
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/?a=abc&b=456") as websocket:
def test_websocket_headers(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- headers = dict(websocket.headers)
- await websocket.accept()
- await websocket.send_json({"headers": headers})
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ headers = dict(websocket.headers)
+ await websocket.accept()
+ await websocket.send_json({"headers": headers})
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_port(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.send_json({"port": websocket.url.port})
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.send_json({"port": websocket.url.port})
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket:
def test_websocket_send_and_receive_text(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- data = await websocket.receive_text()
- await websocket.send_text("Message was: " + data)
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ data = await websocket.receive_text()
+ await websocket.send_text("Message was: " + data)
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_send_and_receive_bytes(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- data = await websocket.receive_bytes()
- await websocket.send_bytes(b"Message was: " + data)
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ data = await websocket.receive_bytes()
+ await websocket.send_bytes(b"Message was: " + data)
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_send_and_receive_json(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- data = await websocket.receive_json()
- await websocket.send_json({"message": data})
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ data = await websocket.receive_json()
+ await websocket.send_json({"message": data})
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_iter_text(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- async for data in websocket.iter_text():
- await websocket.send_text("Message was: " + data)
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ async for data in websocket.iter_text():
+ await websocket.send_text("Message was: " + data)
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_iter_bytes(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- async for data in websocket.iter_bytes():
- await websocket.send_bytes(b"Message was: " + data)
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ async for data in websocket.iter_bytes():
+ await websocket.send_bytes(b"Message was: " + data)
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_iter_json(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- async for data in websocket.iter_json():
- await websocket.send_json({"message": data})
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ async for data in websocket.iter_json():
+ await websocket.send_json({"message": data})
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_concurrency_pattern(test_client_factory):
- def app(scope):
- stream_send, stream_receive = anyio.create_memory_object_stream()
-
- async def reader(websocket):
- async with stream_send:
- async for data in websocket.iter_json():
- await stream_send.send(data)
+ stream_send, stream_receive = anyio.create_memory_object_stream()
- async def writer(websocket):
- async with stream_receive:
- async for message in stream_receive:
- await websocket.send_json(message)
+ async def reader(websocket):
+ async with stream_send:
+ async for data in websocket.iter_json():
+ await stream_send.send(data)
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- async with anyio.create_task_group() as task_group:
- task_group.start_soon(reader, websocket)
- await writer(websocket)
- await websocket.close()
+ async def writer(websocket):
+ async with stream_receive:
+ async for message in stream_receive:
+ await websocket.send_json(message)
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(reader, websocket)
+ await writer(websocket)
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_client_close(test_client_factory):
close_code = None
- def app(scope):
- async def asgi(receive, send):
- nonlocal close_code
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- try:
- await websocket.receive_text()
- except WebSocketDisconnect as exc:
- close_code = exc.code
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ nonlocal close_code
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ try:
+ await websocket.receive_text()
+ except WebSocketDisconnect as exc:
+ close_code = exc.code
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_application_close(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.close(status.WS_1001_GOING_AWAY)
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.close(status.WS_1001_GOING_AWAY)
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_rejected_connection(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.close(status.WS_1001_GOING_AWAY)
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.close(status.WS_1001_GOING_AWAY)
client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
def test_subprotocol(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- assert websocket["subprotocols"] == ["soap", "wamp"]
- await websocket.accept(subprotocol="wamp")
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ assert websocket["subprotocols"] == ["soap", "wamp"]
+ await websocket.accept(subprotocol="wamp")
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket:
def test_additional_headers(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept(headers=[(b"additional", b"header")])
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept(headers=[(b"additional", b"header")])
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_no_additional_headers(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.close()
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_websocket_exception(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- assert False
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ assert False
client = test_client_factory(app)
with pytest.raises(AssertionError):
def test_duplicate_close(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.close()
- await websocket.close()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.close()
+ await websocket.close()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_duplicate_disconnect(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- message = await websocket.receive()
- assert message["type"] == "websocket.disconnect"
- message = await websocket.receive()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ message = await websocket.receive()
+ assert message["type"] == "websocket.disconnect"
+ message = await websocket.receive()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_websocket_close_reason(test_client_factory) -> None:
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
def test_send_json_invalid_mode(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.send_json({}, mode="invalid")
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.send_json({}, mode="invalid")
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_receive_json_invalid_mode(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.receive_json(mode="invalid")
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.receive_json(mode="invalid")
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_receive_text_before_accept(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.receive_text()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.receive_text()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_receive_bytes_before_accept(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.receive_bytes()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.receive_bytes()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_receive_json_before_accept(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.receive_json()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.receive_json()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_send_before_accept(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.send({"type": "websocket.send"})
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.send({"type": "websocket.send"})
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_send_wrong_message_type(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.send({"type": "websocket.accept"})
- await websocket.send({"type": "websocket.accept"})
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.send({"type": "websocket.accept"})
+ await websocket.send({"type": "websocket.accept"})
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_receive_before_accept(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- websocket.client_state = WebSocketState.CONNECTING
- await websocket.receive()
-
- return asgi
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ websocket.client_state = WebSocketState.CONNECTING
+ await websocket.receive()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
def test_receive_wrong_message_type(test_client_factory):
- def app(scope):
- async def asgi(receive, send):
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- await websocket.receive()
-
- return asgi
+ async def app(scope, receive, send):
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await websocket.receive()
client = test_client_factory(app)
with pytest.raises(RuntimeError):