from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
+TestClientFactory = Callable[..., TestClient]
-def test_websocket_url(test_client_factory: Callable[..., TestClient]):
+
+def test_websocket_url(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"url": "ws://testserver/123?a=abc"}
-def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]):
+def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
def test_websocket_ensure_unicode_on_send_json(
- test_client_factory: Callable[..., TestClient],
-):
+ test_client_factory: TestClientFactory,
+) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
assert data == '{"test":"数据"}'
-def test_websocket_query_params(test_client_factory: Callable[..., TestClient]):
+def test_websocket_query_params(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
query_params = dict(websocket.query_params)
any(module in sys.modules for module in ("brotli", "brotlicffi")),
reason='urllib3 includes "br" to the "accept-encoding" headers.',
)
-def test_websocket_headers(test_client_factory: Callable[..., TestClient]):
+def test_websocket_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
headers = dict(websocket.headers)
assert data == {"headers": expected_headers}
-def test_websocket_port(test_client_factory: Callable[..., TestClient]):
+def test_websocket_port(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
def test_websocket_send_and_receive_text(
- test_client_factory: Callable[..., TestClient],
-):
+ test_client_factory: TestClientFactory,
+) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
def test_websocket_send_and_receive_bytes(
- test_client_factory: Callable[..., TestClient],
-):
+ test_client_factory: TestClientFactory,
+) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
def test_websocket_send_and_receive_json(
- test_client_factory: Callable[..., TestClient],
-):
+ test_client_factory: TestClientFactory,
+) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"message": {"hello": "world"}}
-def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]):
+def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == "Message was: Hello, world!"
-def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]):
+def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == b"Message was: Hello, world!"
-def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]):
+def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"message": {"hello": "world"}}
-def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]):
+def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None:
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
stream_send, stream_receive = anyio.create_memory_object_stream()
- async def reader(websocket: WebSocket):
+ async def reader(websocket: WebSocket) -> None:
async with stream_send:
async for data in websocket.iter_json():
await stream_send.send(data)
- async def writer(websocket: WebSocket):
+ async def writer(websocket: WebSocket) -> None:
async with stream_receive:
async for message in stream_receive:
await websocket.send_json(message)
assert data == {"hello": "world"}
-def test_client_close(test_client_factory: Callable[..., TestClient]):
+def test_client_close(test_client_factory: TestClientFactory) -> None:
close_code = None
close_reason = None
@pytest.mark.anyio
-async def test_client_disconnect_on_send():
+async def test_client_disconnect_on_send() -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE
-def test_application_close(test_client_factory: Callable[..., TestClient]):
+def test_application_close(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert exc.value.code == status.WS_1001_GOING_AWAY
-def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
+def test_rejected_connection(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert exc.value.code == status.WS_1001_GOING_AWAY
-def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
+def test_send_denial_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert exc.value.content == b"foo"
-def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
+def test_send_response_multi(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert exc.value.headers["foo"] == "bar"
-def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
+def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
del scope["extensions"]["websocket.http.response"]
websocket = WebSocket(scope, receive=receive, send=send)
assert exc.value.code == status.WS_1000_NORMAL_CLOSURE
-def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
+def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
pass # pragma: no cover
-def test_subprotocol(test_client_factory: Callable[..., TestClient]):
+def test_subprotocol(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
assert websocket["subprotocols"] == ["soap", "wamp"]
assert websocket.accepted_subprotocol == "wamp"
-def test_additional_headers(test_client_factory: Callable[..., TestClient]):
+def test_additional_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept(headers=[(b"additional", b"header")])
assert websocket.extra_headers == [(b"additional", b"header")]
-def test_no_additional_headers(test_client_factory: Callable[..., TestClient]):
+def test_no_additional_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert websocket.extra_headers == []
-def test_websocket_exception(test_client_factory: Callable[..., TestClient]):
+def test_websocket_exception(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert False
pass # pragma: no cover
-def test_duplicate_close(test_client_factory: Callable[..., TestClient]):
+def test_duplicate_close(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
pass # pragma: no cover
-def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]):
+def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
websocket.close()
-def test_websocket_scope_interface():
+def test_websocket_scope_interface() -> None:
"""
A WebSocket can be instantiated with a scope, and presents a `Mapping`
interface.
assert {websocket} == {websocket}
-def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None:
+def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert exc.value.reason == "Going Away"
-def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
+def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
pass # pragma: no cover
-def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
+def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
pass # pragma: nocover
-def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_text()
pass # pragma: nocover
-def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_bytes()
pass # pragma: nocover
-def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_json()
pass # pragma: no cover
-def test_send_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_send_before_accept(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.send"})
pass # pragma: nocover
-def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]):
+def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.accept"})
pass # pragma: no cover
-def test_receive_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_before_accept(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
websocket.send({"type": "websocket.send"})
-def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]):
- async def app(scope: Scope, receive: Receive, send: Send):
+def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.receive()