From: Marcelo Trylesinski Date: Fri, 12 Jan 2024 10:06:48 +0000 (+0100) Subject: Add type hint on `test_websockets.py` (#2411) X-Git-Tag: 0.36.0~9 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=8d7630bda107ebbc58a4130b605ee6481067100e;p=thirdparty%2Fstarlette.git Add type hint on `test_websockets.py` (#2411) * Add type hint on `test_websockets.py` * Add type ignore on mock_receive --- diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 41ab82e1..283dcfc7 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -7,11 +7,11 @@ from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette import status from starlette.testclient import TestClient -from starlette.types import Receive, Scope, Send +from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState -def test_websocket_url(test_client_factory): +def test_websocket_url(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -24,7 +24,7 @@ def test_websocket_url(test_client_factory): assert data == {"url": "ws://testserver/123?a=abc"} -def test_websocket_binary_json(test_client_factory): +def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -57,7 +57,7 @@ def test_websocket_ensure_unicode_on_send_json( assert data == '{"test":"数据"}' -def test_websocket_query_params(test_client_factory): +def test_websocket_query_params(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) query_params = dict(websocket.query_params) @@ -75,7 +75,7 @@ def test_websocket_query_params(test_client_factory): any(module in sys.modules for module in ("brotli", "brotlicffi")), reason='urllib3 includes "br" to the "accept-encoding" headers.', ) -def test_websocket_headers(test_client_factory): +def test_websocket_headers(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) headers = dict(websocket.headers) @@ -98,7 +98,7 @@ def test_websocket_headers(test_client_factory): assert data == {"headers": expected_headers} -def test_websocket_port(test_client_factory): +def test_websocket_port(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -111,7 +111,9 @@ def test_websocket_port(test_client_factory): assert data == {"port": 123} -def test_websocket_send_and_receive_text(test_client_factory): +def test_websocket_send_and_receive_text( + test_client_factory: Callable[..., TestClient], +): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -126,7 +128,9 @@ def test_websocket_send_and_receive_text(test_client_factory): assert data == "Message was: Hello, world!" -def test_websocket_send_and_receive_bytes(test_client_factory): +def test_websocket_send_and_receive_bytes( + test_client_factory: Callable[..., TestClient], +): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -141,7 +145,9 @@ def test_websocket_send_and_receive_bytes(test_client_factory): assert data == b"Message was: Hello, world!" -def test_websocket_send_and_receive_json(test_client_factory): +def test_websocket_send_and_receive_json( + test_client_factory: Callable[..., TestClient], +): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -156,7 +162,7 @@ def test_websocket_send_and_receive_json(test_client_factory): assert data == {"message": {"hello": "world"}} -def test_websocket_iter_text(test_client_factory): +def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -170,7 +176,7 @@ def test_websocket_iter_text(test_client_factory): assert data == "Message was: Hello, world!" -def test_websocket_iter_bytes(test_client_factory): +def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -184,7 +190,7 @@ def test_websocket_iter_bytes(test_client_factory): assert data == b"Message was: Hello, world!" -def test_websocket_iter_json(test_client_factory): +def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -198,17 +204,17 @@ def test_websocket_iter_json(test_client_factory): assert data == {"message": {"hello": "world"}} -def test_websocket_concurrency_pattern(test_client_factory): +def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]): stream_send: ObjectSendStream[MutableMapping[str, Any]] stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] stream_send, stream_receive = anyio.create_memory_object_stream() - async def reader(websocket): + async def reader(websocket: WebSocket): async with stream_send: async for data in websocket.iter_json(): await stream_send.send(data) - async def writer(websocket): + async def writer(websocket: WebSocket): async with stream_receive: async for message in stream_receive: await websocket.send_json(message) @@ -249,7 +255,7 @@ def test_client_close(test_client_factory: Callable[..., TestClient]): assert close_reason == "Going Away" -def test_application_close(test_client_factory): +def test_application_close(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -262,7 +268,7 @@ def test_application_close(test_client_factory): assert exc.value.code == status.WS_1001_GOING_AWAY -def test_rejected_connection(test_client_factory): +def test_rejected_connection(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.close(status.WS_1001_GOING_AWAY) @@ -270,11 +276,11 @@ def test_rejected_connection(test_client_factory): client = test_client_factory(app) with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/"): - pass # pragma: nocover + pass # pragma: no cover assert exc.value.code == status.WS_1001_GOING_AWAY -def test_subprotocol(test_client_factory): +def test_subprotocol(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) assert websocket["subprotocols"] == ["soap", "wamp"] @@ -286,7 +292,7 @@ def test_subprotocol(test_client_factory): assert websocket.accepted_subprotocol == "wamp" -def test_additional_headers(test_client_factory): +def test_additional_headers(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept(headers=[(b"additional", b"header")]) @@ -297,7 +303,7 @@ def test_additional_headers(test_client_factory): assert websocket.extra_headers == [(b"additional", b"header")] -def test_no_additional_headers(test_client_factory): +def test_no_additional_headers(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -308,17 +314,17 @@ def test_no_additional_headers(test_client_factory): assert websocket.extra_headers == [] -def test_websocket_exception(test_client_factory): +def test_websocket_exception(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: assert False client = test_client_factory(app) with pytest.raises(AssertionError): with client.websocket_connect("/123?a=abc"): - pass # pragma: nocover + pass # pragma: no cover -def test_duplicate_close(test_client_factory): +def test_duplicate_close(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -328,10 +334,10 @@ def test_duplicate_close(test_client_factory): client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass # pragma: nocover + pass # pragma: no cover -def test_duplicate_disconnect(test_client_factory): +def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -351,11 +357,11 @@ def test_websocket_scope_interface(): interface. """ - async def mock_receive(): - pass # pragma: no cover + async def mock_receive() -> Message: # type: ignore + ... # pragma: no cover - async def mock_send(message): - pass # pragma: no cover + async def mock_send(message: Message) -> None: + ... # pragma: no cover websocket = WebSocket( {"type": "websocket", "path": "/abc/", "headers": []}, @@ -377,7 +383,7 @@ def test_websocket_scope_interface(): assert {websocket} == {websocket} -def test_websocket_close_reason(test_client_factory) -> None: +def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -391,7 +397,7 @@ def test_websocket_close_reason(test_client_factory) -> None: assert exc.value.reason == "Going Away" -def test_send_json_invalid_mode(test_client_factory): +def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -400,10 +406,10 @@ def test_send_json_invalid_mode(test_client_factory): client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass # pragma: nocover + pass # pragma: no cover -def test_receive_json_invalid_mode(test_client_factory): +def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -415,7 +421,7 @@ def test_receive_json_invalid_mode(test_client_factory): pass # pragma: nocover -def test_receive_text_before_accept(test_client_factory): +def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_text() @@ -426,7 +432,7 @@ def test_receive_text_before_accept(test_client_factory): pass # pragma: nocover -def test_receive_bytes_before_accept(test_client_factory): +def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_bytes() @@ -437,7 +443,7 @@ def test_receive_bytes_before_accept(test_client_factory): pass # pragma: nocover -def test_receive_json_before_accept(test_client_factory): +def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_json() @@ -445,10 +451,10 @@ def test_receive_json_before_accept(test_client_factory): client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass # pragma: nocover + pass # pragma: no cover -def test_send_before_accept(test_client_factory): +def test_send_before_accept(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.send({"type": "websocket.send"}) @@ -459,7 +465,7 @@ def test_send_before_accept(test_client_factory): pass # pragma: nocover -def test_send_wrong_message_type(test_client_factory): +def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.send({"type": "websocket.accept"}) @@ -468,10 +474,10 @@ def test_send_wrong_message_type(test_client_factory): client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass # pragma: nocover + pass # pragma: no cover -def test_receive_before_accept(test_client_factory): +def test_receive_before_accept(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -484,8 +490,8 @@ def test_receive_before_accept(test_client_factory): websocket.send({"type": "websocket.send"}) -def test_receive_wrong_message_type(test_client_factory): - async def app(scope, receive, send): +def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.receive()