From: Scirlat Danut Date: Wed, 7 Feb 2024 19:56:21 +0000 (+0200) Subject: Add type hints to `test_websockets.py` (#2494) X-Git-Tag: 0.37.1~7 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=179d934b5537039d60d0866300ad6815a74c7812;p=thirdparty%2Fstarlette.git Add type hints to `test_websockets.py` (#2494) Co-authored-by: Scirlat Danut Co-authored-by: Marcelo Trylesinski --- diff --git a/tests/test_websockets.py b/tests/test_websockets.py index c8bfc02a..c4b6c16b 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -11,8 +11,10 @@ from starlette.testclient import TestClient, WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState +TestClientFactory = Callable[..., TestClient] -def test_websocket_url(test_client_factory: Callable[..., TestClient]): + +def test_websocket_url(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -25,7 +27,7 @@ def test_websocket_url(test_client_factory: Callable[..., TestClient]): assert data == {"url": "ws://testserver/123?a=abc"} -def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]): +def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -41,8 +43,8 @@ def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]): def test_websocket_ensure_unicode_on_send_json( - test_client_factory: Callable[..., TestClient], -): + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) @@ -58,7 +60,7 @@ def test_websocket_ensure_unicode_on_send_json( assert data == '{"test":"数据"}' -def test_websocket_query_params(test_client_factory: Callable[..., TestClient]): +def test_websocket_query_params(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) query_params = dict(websocket.query_params) @@ -76,7 +78,7 @@ def test_websocket_query_params(test_client_factory: Callable[..., TestClient]): any(module in sys.modules for module in ("brotli", "brotlicffi")), reason='urllib3 includes "br" to the "accept-encoding" headers.', ) -def test_websocket_headers(test_client_factory: Callable[..., TestClient]): +def test_websocket_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) headers = dict(websocket.headers) @@ -99,7 +101,7 @@ def test_websocket_headers(test_client_factory: Callable[..., TestClient]): assert data == {"headers": expected_headers} -def test_websocket_port(test_client_factory: Callable[..., TestClient]): +def test_websocket_port(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -113,8 +115,8 @@ def test_websocket_port(test_client_factory: Callable[..., TestClient]): def test_websocket_send_and_receive_text( - test_client_factory: Callable[..., TestClient], -): + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -130,8 +132,8 @@ def test_websocket_send_and_receive_text( def test_websocket_send_and_receive_bytes( - test_client_factory: Callable[..., TestClient], -): + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -147,8 +149,8 @@ def test_websocket_send_and_receive_bytes( def test_websocket_send_and_receive_json( - test_client_factory: Callable[..., TestClient], -): + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -163,7 +165,7 @@ def test_websocket_send_and_receive_json( assert data == {"message": {"hello": "world"}} -def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]): +def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -177,7 +179,7 @@ def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]): assert data == "Message was: Hello, world!" -def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]): +def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -191,7 +193,7 @@ def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]): assert data == b"Message was: Hello, world!" -def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]): +def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -205,17 +207,17 @@ def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]): assert data == {"message": {"hello": "world"}} -def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]): +def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None: stream_send: ObjectSendStream[MutableMapping[str, Any]] stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] stream_send, stream_receive = anyio.create_memory_object_stream() - async def reader(websocket: WebSocket): + async def reader(websocket: WebSocket) -> None: async with stream_send: async for data in websocket.iter_json(): await stream_send.send(data) - async def writer(websocket: WebSocket): + async def writer(websocket: WebSocket) -> None: async with stream_receive: async for message in stream_receive: await websocket.send_json(message) @@ -235,7 +237,7 @@ def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestCl assert data == {"hello": "world"} -def test_client_close(test_client_factory: Callable[..., TestClient]): +def test_client_close(test_client_factory: TestClientFactory) -> None: close_code = None close_reason = None @@ -257,7 +259,7 @@ def test_client_close(test_client_factory: Callable[..., TestClient]): @pytest.mark.anyio -async def test_client_disconnect_on_send(): +async def test_client_disconnect_on_send() -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -278,7 +280,7 @@ async def test_client_disconnect_on_send(): assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE -def test_application_close(test_client_factory: Callable[..., TestClient]): +def test_application_close(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -291,7 +293,7 @@ def test_application_close(test_client_factory: Callable[..., TestClient]): assert exc.value.code == status.WS_1001_GOING_AWAY -def test_rejected_connection(test_client_factory: Callable[..., TestClient]): +def test_rejected_connection(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -305,7 +307,7 @@ def test_rejected_connection(test_client_factory: Callable[..., TestClient]): assert exc.value.code == status.WS_1001_GOING_AWAY -def test_send_denial_response(test_client_factory: Callable[..., TestClient]): +def test_send_denial_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -321,7 +323,7 @@ def test_send_denial_response(test_client_factory: Callable[..., TestClient]): assert exc.value.content == b"foo" -def test_send_response_multi(test_client_factory: Callable[..., TestClient]): +def test_send_response_multi(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -356,7 +358,7 @@ def test_send_response_multi(test_client_factory: Callable[..., TestClient]): assert exc.value.headers["foo"] == "bar" -def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]): +def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: del scope["extensions"]["websocket.http.response"] websocket = WebSocket(scope, receive=receive, send=send) @@ -377,7 +379,7 @@ def test_send_response_unsupported(test_client_factory: Callable[..., TestClient assert exc.value.code == status.WS_1000_NORMAL_CLOSURE -def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]): +def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -410,7 +412,7 @@ def test_send_response_duplicate_start(test_client_factory: Callable[..., TestCl pass # pragma: no cover -def test_subprotocol(test_client_factory: Callable[..., TestClient]): +def test_subprotocol(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) assert websocket["subprotocols"] == ["soap", "wamp"] @@ -422,7 +424,7 @@ def test_subprotocol(test_client_factory: Callable[..., TestClient]): assert websocket.accepted_subprotocol == "wamp" -def test_additional_headers(test_client_factory: Callable[..., TestClient]): +def test_additional_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept(headers=[(b"additional", b"header")]) @@ -433,7 +435,7 @@ def test_additional_headers(test_client_factory: Callable[..., TestClient]): assert websocket.extra_headers == [(b"additional", b"header")] -def test_no_additional_headers(test_client_factory: Callable[..., TestClient]): +def test_no_additional_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -444,7 +446,7 @@ def test_no_additional_headers(test_client_factory: Callable[..., TestClient]): assert websocket.extra_headers == [] -def test_websocket_exception(test_client_factory: Callable[..., TestClient]): +def test_websocket_exception(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: assert False @@ -454,7 +456,7 @@ def test_websocket_exception(test_client_factory: Callable[..., TestClient]): pass # pragma: no cover -def test_duplicate_close(test_client_factory: Callable[..., TestClient]): +def test_duplicate_close(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -467,7 +469,7 @@ def test_duplicate_close(test_client_factory: Callable[..., TestClient]): pass # pragma: no cover -def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]): +def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -481,7 +483,7 @@ def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]): websocket.close() -def test_websocket_scope_interface(): +def test_websocket_scope_interface() -> None: """ A WebSocket can be instantiated with a scope, and presents a `Mapping` interface. @@ -513,7 +515,7 @@ def test_websocket_scope_interface(): assert {websocket} == {websocket} -def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None: +def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -527,7 +529,7 @@ def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) assert exc.value.reason == "Going Away" -def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]): +def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -539,7 +541,7 @@ def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]): pass # pragma: no cover -def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]): +def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -551,7 +553,7 @@ def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient pass # pragma: nocover -def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]): +def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_text() @@ -562,7 +564,7 @@ def test_receive_text_before_accept(test_client_factory: Callable[..., TestClien pass # pragma: nocover -def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]): +def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_bytes() @@ -573,7 +575,7 @@ def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClie pass # pragma: nocover -def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]): +def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_json() @@ -584,7 +586,7 @@ def test_receive_json_before_accept(test_client_factory: Callable[..., TestClien pass # pragma: no cover -def test_send_before_accept(test_client_factory: Callable[..., TestClient]): +def test_send_before_accept(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.send({"type": "websocket.send"}) @@ -595,7 +597,7 @@ def test_send_before_accept(test_client_factory: Callable[..., TestClient]): pass # pragma: nocover -def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]): +def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.send({"type": "websocket.accept"}) @@ -607,7 +609,7 @@ def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]) pass # pragma: no cover -def test_receive_before_accept(test_client_factory: Callable[..., TestClient]): +def test_receive_before_accept(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() @@ -620,8 +622,8 @@ def test_receive_before_accept(test_client_factory: Callable[..., TestClient]): websocket.send({"type": "websocket.send"}) -def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]): - async def app(scope: Scope, receive: Receive, send: Send): +def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.receive()