from starlette import status
from starlette.testclient import TestClient
-from starlette.types import Receive, Scope, Send
+from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
-def test_websocket_url(test_client_factory):
+def test_websocket_url(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"url": "ws://testserver/123?a=abc"}
-def test_websocket_binary_json(test_client_factory):
+def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == '{"test":"数据"}'
-def test_websocket_query_params(test_client_factory):
+def test_websocket_query_params(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
query_params = dict(websocket.query_params)
any(module in sys.modules for module in ("brotli", "brotlicffi")),
reason='urllib3 includes "br" to the "accept-encoding" headers.',
)
-def test_websocket_headers(test_client_factory):
+def test_websocket_headers(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
headers = dict(websocket.headers)
assert data == {"headers": expected_headers}
-def test_websocket_port(test_client_factory):
+def test_websocket_port(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"port": 123}
-def test_websocket_send_and_receive_text(test_client_factory):
+def test_websocket_send_and_receive_text(
+ test_client_factory: Callable[..., TestClient],
+):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == "Message was: Hello, world!"
-def test_websocket_send_and_receive_bytes(test_client_factory):
+def test_websocket_send_and_receive_bytes(
+ test_client_factory: Callable[..., TestClient],
+):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == b"Message was: Hello, world!"
-def test_websocket_send_and_receive_json(test_client_factory):
+def test_websocket_send_and_receive_json(
+ test_client_factory: Callable[..., TestClient],
+):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"message": {"hello": "world"}}
-def test_websocket_iter_text(test_client_factory):
+def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == "Message was: Hello, world!"
-def test_websocket_iter_bytes(test_client_factory):
+def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == b"Message was: Hello, world!"
-def test_websocket_iter_json(test_client_factory):
+def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert data == {"message": {"hello": "world"}}
-def test_websocket_concurrency_pattern(test_client_factory):
+def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]):
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
stream_send, stream_receive = anyio.create_memory_object_stream()
- async def reader(websocket):
+ async def reader(websocket: WebSocket):
async with stream_send:
async for data in websocket.iter_json():
await stream_send.send(data)
- async def writer(websocket):
+ async def writer(websocket: WebSocket):
async with stream_receive:
async for message in stream_receive:
await websocket.send_json(message)
assert close_reason == "Going Away"
-def test_application_close(test_client_factory):
+def test_application_close(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert exc.value.code == status.WS_1001_GOING_AWAY
-def test_rejected_connection(test_client_factory):
+def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.close(status.WS_1001_GOING_AWAY)
client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
with client.websocket_connect("/"):
- pass # pragma: nocover
+ pass # pragma: no cover
assert exc.value.code == status.WS_1001_GOING_AWAY
-def test_subprotocol(test_client_factory):
+def test_subprotocol(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
assert websocket["subprotocols"] == ["soap", "wamp"]
assert websocket.accepted_subprotocol == "wamp"
-def test_additional_headers(test_client_factory):
+def test_additional_headers(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept(headers=[(b"additional", b"header")])
assert websocket.extra_headers == [(b"additional", b"header")]
-def test_no_additional_headers(test_client_factory):
+def test_no_additional_headers(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert websocket.extra_headers == []
-def test_websocket_exception(test_client_factory):
+def test_websocket_exception(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert False
client = test_client_factory(app)
with pytest.raises(AssertionError):
with client.websocket_connect("/123?a=abc"):
- pass # pragma: nocover
+ pass # pragma: no cover
-def test_duplicate_close(test_client_factory):
+def test_duplicate_close(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
- pass # pragma: nocover
+ pass # pragma: no cover
-def test_duplicate_disconnect(test_client_factory):
+def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
interface.
"""
- async def mock_receive():
- pass # pragma: no cover
+ async def mock_receive() -> Message: # type: ignore
+ ... # pragma: no cover
- async def mock_send(message):
- pass # pragma: no cover
+ async def mock_send(message: Message) -> None:
+ ... # pragma: no cover
websocket = WebSocket(
{"type": "websocket", "path": "/abc/", "headers": []},
assert {websocket} == {websocket}
-def test_websocket_close_reason(test_client_factory) -> None:
+def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
assert exc.value.reason == "Going Away"
-def test_send_json_invalid_mode(test_client_factory):
+def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
- pass # pragma: nocover
+ pass # pragma: no cover
-def test_receive_json_invalid_mode(test_client_factory):
+def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
pass # pragma: nocover
-def test_receive_text_before_accept(test_client_factory):
+def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_text()
pass # pragma: nocover
-def test_receive_bytes_before_accept(test_client_factory):
+def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_bytes()
pass # pragma: nocover
-def test_receive_json_before_accept(test_client_factory):
+def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_json()
client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
- pass # pragma: nocover
+ pass # pragma: no cover
-def test_send_before_accept(test_client_factory):
+def test_send_before_accept(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.send"})
pass # pragma: nocover
-def test_send_wrong_message_type(test_client_factory):
+def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.accept"})
client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
- pass # pragma: nocover
+ pass # pragma: no cover
-def test_receive_before_accept(test_client_factory):
+def test_receive_before_accept(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
websocket.send({"type": "websocket.send"})
-def test_receive_wrong_message_type(test_client_factory):
- async def app(scope, receive, send):
+def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]):
+ async def app(scope: Scope, receive: Receive, send: Send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.receive()