]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_websockets.py` (#2494)
authorScirlat Danut <danut.scirlat@gmail.com>
Wed, 7 Feb 2024 19:56:21 +0000 (21:56 +0200)
committerGitHub <noreply@github.com>
Wed, 7 Feb 2024 19:56:21 +0000 (19:56 +0000)
Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/test_websockets.py

index c8bfc02aa09f5ce344d34f65b9c96cb0711ad2fa..c4b6c16bdbbd012ac7cbaa1f6a77cb09dc638740 100644 (file)
@@ -11,8 +11,10 @@ from starlette.testclient import TestClient, WebSocketDenialResponse
 from starlette.types import Message, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
 
+TestClientFactory = Callable[..., TestClient]
 
-def test_websocket_url(test_client_factory: Callable[..., TestClient]):
+
+def test_websocket_url(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -25,7 +27,7 @@ def test_websocket_url(test_client_factory: Callable[..., TestClient]):
         assert data == {"url": "ws://testserver/123?a=abc"}
 
 
-def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]):
+def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -41,8 +43,8 @@ def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]):
 
 
 def test_websocket_ensure_unicode_on_send_json(
-    test_client_factory: Callable[..., TestClient],
-):
+    test_client_factory: TestClientFactory,
+) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
 
@@ -58,7 +60,7 @@ def test_websocket_ensure_unicode_on_send_json(
         assert data == '{"test":"数据"}'
 
 
-def test_websocket_query_params(test_client_factory: Callable[..., TestClient]):
+def test_websocket_query_params(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         query_params = dict(websocket.query_params)
@@ -76,7 +78,7 @@ def test_websocket_query_params(test_client_factory: Callable[..., TestClient]):
     any(module in sys.modules for module in ("brotli", "brotlicffi")),
     reason='urllib3 includes "br" to the "accept-encoding" headers.',
 )
-def test_websocket_headers(test_client_factory: Callable[..., TestClient]):
+def test_websocket_headers(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         headers = dict(websocket.headers)
@@ -99,7 +101,7 @@ def test_websocket_headers(test_client_factory: Callable[..., TestClient]):
         assert data == {"headers": expected_headers}
 
 
-def test_websocket_port(test_client_factory: Callable[..., TestClient]):
+def test_websocket_port(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -113,8 +115,8 @@ def test_websocket_port(test_client_factory: Callable[..., TestClient]):
 
 
 def test_websocket_send_and_receive_text(
-    test_client_factory: Callable[..., TestClient],
-):
+    test_client_factory: TestClientFactory,
+) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -130,8 +132,8 @@ def test_websocket_send_and_receive_text(
 
 
 def test_websocket_send_and_receive_bytes(
-    test_client_factory: Callable[..., TestClient],
-):
+    test_client_factory: TestClientFactory,
+) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -147,8 +149,8 @@ def test_websocket_send_and_receive_bytes(
 
 
 def test_websocket_send_and_receive_json(
-    test_client_factory: Callable[..., TestClient],
-):
+    test_client_factory: TestClientFactory,
+) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -163,7 +165,7 @@ def test_websocket_send_and_receive_json(
         assert data == {"message": {"hello": "world"}}
 
 
-def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]):
+def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -177,7 +179,7 @@ def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]):
         assert data == "Message was: Hello, world!"
 
 
-def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]):
+def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -191,7 +193,7 @@ def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]):
         assert data == b"Message was: Hello, world!"
 
 
-def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]):
+def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -205,17 +207,17 @@ def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]):
         assert data == {"message": {"hello": "world"}}
 
 
-def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]):
+def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None:
     stream_send: ObjectSendStream[MutableMapping[str, Any]]
     stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
     stream_send, stream_receive = anyio.create_memory_object_stream()
 
-    async def reader(websocket: WebSocket):
+    async def reader(websocket: WebSocket) -> None:
         async with stream_send:
             async for data in websocket.iter_json():
                 await stream_send.send(data)
 
-    async def writer(websocket: WebSocket):
+    async def writer(websocket: WebSocket) -> None:
         async with stream_receive:
             async for message in stream_receive:
                 await websocket.send_json(message)
@@ -235,7 +237,7 @@ def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestCl
         assert data == {"hello": "world"}
 
 
-def test_client_close(test_client_factory: Callable[..., TestClient]):
+def test_client_close(test_client_factory: TestClientFactory) -> None:
     close_code = None
     close_reason = None
 
@@ -257,7 +259,7 @@ def test_client_close(test_client_factory: Callable[..., TestClient]):
 
 
 @pytest.mark.anyio
-async def test_client_disconnect_on_send():
+async def test_client_disconnect_on_send() -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -278,7 +280,7 @@ async def test_client_disconnect_on_send():
     assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE
 
 
-def test_application_close(test_client_factory: Callable[..., TestClient]):
+def test_application_close(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -291,7 +293,7 @@ def test_application_close(test_client_factory: Callable[..., TestClient]):
         assert exc.value.code == status.WS_1001_GOING_AWAY
 
 
-def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
+def test_rejected_connection(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         msg = await websocket.receive()
@@ -305,7 +307,7 @@ def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
     assert exc.value.code == status.WS_1001_GOING_AWAY
 
 
-def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
+def test_send_denial_response(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         msg = await websocket.receive()
@@ -321,7 +323,7 @@ def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
     assert exc.value.content == b"foo"
 
 
-def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
+def test_send_response_multi(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         msg = await websocket.receive()
@@ -356,7 +358,7 @@ def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
     assert exc.value.headers["foo"] == "bar"
 
 
-def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
+def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         del scope["extensions"]["websocket.http.response"]
         websocket = WebSocket(scope, receive=receive, send=send)
@@ -377,7 +379,7 @@ def test_send_response_unsupported(test_client_factory: Callable[..., TestClient
     assert exc.value.code == status.WS_1000_NORMAL_CLOSURE
 
 
-def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
+def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         msg = await websocket.receive()
@@ -410,7 +412,7 @@ def test_send_response_duplicate_start(test_client_factory: Callable[..., TestCl
             pass  # pragma: no cover
 
 
-def test_subprotocol(test_client_factory: Callable[..., TestClient]):
+def test_subprotocol(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         assert websocket["subprotocols"] == ["soap", "wamp"]
@@ -422,7 +424,7 @@ def test_subprotocol(test_client_factory: Callable[..., TestClient]):
         assert websocket.accepted_subprotocol == "wamp"
 
 
-def test_additional_headers(test_client_factory: Callable[..., TestClient]):
+def test_additional_headers(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept(headers=[(b"additional", b"header")])
@@ -433,7 +435,7 @@ def test_additional_headers(test_client_factory: Callable[..., TestClient]):
         assert websocket.extra_headers == [(b"additional", b"header")]
 
 
-def test_no_additional_headers(test_client_factory: Callable[..., TestClient]):
+def test_no_additional_headers(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -444,7 +446,7 @@ def test_no_additional_headers(test_client_factory: Callable[..., TestClient]):
         assert websocket.extra_headers == []
 
 
-def test_websocket_exception(test_client_factory: Callable[..., TestClient]):
+def test_websocket_exception(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         assert False
 
@@ -454,7 +456,7 @@ def test_websocket_exception(test_client_factory: Callable[..., TestClient]):
             pass  # pragma: no cover
 
 
-def test_duplicate_close(test_client_factory: Callable[..., TestClient]):
+def test_duplicate_close(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -467,7 +469,7 @@ def test_duplicate_close(test_client_factory: Callable[..., TestClient]):
             pass  # pragma: no cover
 
 
-def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]):
+def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -481,7 +483,7 @@ def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]):
             websocket.close()
 
 
-def test_websocket_scope_interface():
+def test_websocket_scope_interface() -> None:
     """
     A WebSocket can be instantiated with a scope, and presents a `Mapping`
     interface.
@@ -513,7 +515,7 @@ def test_websocket_scope_interface():
     assert {websocket} == {websocket}
 
 
-def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None:
+def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -527,7 +529,7 @@ def test_websocket_close_reason(test_client_factory: Callable[..., TestClient])
         assert exc.value.reason == "Going Away"
 
 
-def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
+def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -539,7 +541,7 @@ def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
             pass  # pragma: no cover
 
 
-def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
+def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -551,7 +553,7 @@ def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient
             pass  # pragma: nocover
 
 
-def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.receive_text()
@@ -562,7 +564,7 @@ def test_receive_text_before_accept(test_client_factory: Callable[..., TestClien
             pass  # pragma: nocover
 
 
-def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.receive_bytes()
@@ -573,7 +575,7 @@ def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClie
             pass  # pragma: nocover
 
 
-def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.receive_json()
@@ -584,7 +586,7 @@ def test_receive_json_before_accept(test_client_factory: Callable[..., TestClien
             pass  # pragma: no cover
 
 
-def test_send_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_send_before_accept(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.send({"type": "websocket.send"})
@@ -595,7 +597,7 @@ def test_send_before_accept(test_client_factory: Callable[..., TestClient]):
             pass  # pragma: nocover
 
 
-def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]):
+def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.send({"type": "websocket.accept"})
@@ -607,7 +609,7 @@ def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient])
             pass  # pragma: no cover
 
 
-def test_receive_before_accept(test_client_factory: Callable[..., TestClient]):
+def test_receive_before_accept(test_client_factory: TestClientFactory) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -620,8 +622,8 @@ def test_receive_before_accept(test_client_factory: Callable[..., TestClient]):
             websocket.send({"type": "websocket.send"})
 
 
-def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]):
-    async def app(scope: Scope, receive: Receive, send: Send):
+def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None:
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
         await websocket.receive()