]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hint on `test_websockets.py` (#2411)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Fri, 12 Jan 2024 10:06:48 +0000 (11:06 +0100)
committerGitHub <noreply@github.com>
Fri, 12 Jan 2024 10:06:48 +0000 (03:06 -0700)
* Add type hint on `test_websockets.py`

* Add type ignore on mock_receive

tests/test_websockets.py

index 41ab82e18001bf579aec6cf366f670e46842590c..283dcfc78aa964b7eb7f6ea95b2af5cdcd415732 100644 (file)
@@ -7,11 +7,11 @@ from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette import status
 from starlette.testclient import TestClient
-from starlette.types import Receive, Scope, Send
+from starlette.types import Message, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
 
 
-def test_websocket_url(test_client_factory):
+def test_websocket_url(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -24,7 +24,7 @@ def test_websocket_url(test_client_factory):
         assert data == {"url": "ws://testserver/123?a=abc"}
 
 
-def test_websocket_binary_json(test_client_factory):
+def test_websocket_binary_json(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -57,7 +57,7 @@ def test_websocket_ensure_unicode_on_send_json(
         assert data == '{"test":"数据"}'
 
 
-def test_websocket_query_params(test_client_factory):
+def test_websocket_query_params(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         query_params = dict(websocket.query_params)
@@ -75,7 +75,7 @@ def test_websocket_query_params(test_client_factory):
     any(module in sys.modules for module in ("brotli", "brotlicffi")),
     reason='urllib3 includes "br" to the "accept-encoding" headers.',
 )
-def test_websocket_headers(test_client_factory):
+def test_websocket_headers(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         headers = dict(websocket.headers)
@@ -98,7 +98,7 @@ def test_websocket_headers(test_client_factory):
         assert data == {"headers": expected_headers}
 
 
-def test_websocket_port(test_client_factory):
+def test_websocket_port(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -111,7 +111,9 @@ def test_websocket_port(test_client_factory):
         assert data == {"port": 123}
 
 
-def test_websocket_send_and_receive_text(test_client_factory):
+def test_websocket_send_and_receive_text(
+    test_client_factory: Callable[..., TestClient],
+):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -126,7 +128,9 @@ def test_websocket_send_and_receive_text(test_client_factory):
         assert data == "Message was: Hello, world!"
 
 
-def test_websocket_send_and_receive_bytes(test_client_factory):
+def test_websocket_send_and_receive_bytes(
+    test_client_factory: Callable[..., TestClient],
+):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -141,7 +145,9 @@ def test_websocket_send_and_receive_bytes(test_client_factory):
         assert data == b"Message was: Hello, world!"
 
 
-def test_websocket_send_and_receive_json(test_client_factory):
+def test_websocket_send_and_receive_json(
+    test_client_factory: Callable[..., TestClient],
+):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -156,7 +162,7 @@ def test_websocket_send_and_receive_json(test_client_factory):
         assert data == {"message": {"hello": "world"}}
 
 
-def test_websocket_iter_text(test_client_factory):
+def test_websocket_iter_text(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -170,7 +176,7 @@ def test_websocket_iter_text(test_client_factory):
         assert data == "Message was: Hello, world!"
 
 
-def test_websocket_iter_bytes(test_client_factory):
+def test_websocket_iter_bytes(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -184,7 +190,7 @@ def test_websocket_iter_bytes(test_client_factory):
         assert data == b"Message was: Hello, world!"
 
 
-def test_websocket_iter_json(test_client_factory):
+def test_websocket_iter_json(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -198,17 +204,17 @@ def test_websocket_iter_json(test_client_factory):
         assert data == {"message": {"hello": "world"}}
 
 
-def test_websocket_concurrency_pattern(test_client_factory):
+def test_websocket_concurrency_pattern(test_client_factory: Callable[..., TestClient]):
     stream_send: ObjectSendStream[MutableMapping[str, Any]]
     stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
     stream_send, stream_receive = anyio.create_memory_object_stream()
 
-    async def reader(websocket):
+    async def reader(websocket: WebSocket):
         async with stream_send:
             async for data in websocket.iter_json():
                 await stream_send.send(data)
 
-    async def writer(websocket):
+    async def writer(websocket: WebSocket):
         async with stream_receive:
             async for message in stream_receive:
                 await websocket.send_json(message)
@@ -249,7 +255,7 @@ def test_client_close(test_client_factory: Callable[..., TestClient]):
     assert close_reason == "Going Away"
 
 
-def test_application_close(test_client_factory):
+def test_application_close(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -262,7 +268,7 @@ def test_application_close(test_client_factory):
         assert exc.value.code == status.WS_1001_GOING_AWAY
 
 
-def test_rejected_connection(test_client_factory):
+def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.close(status.WS_1001_GOING_AWAY)
@@ -270,11 +276,11 @@ def test_rejected_connection(test_client_factory):
     client = test_client_factory(app)
     with pytest.raises(WebSocketDisconnect) as exc:
         with client.websocket_connect("/"):
-            pass  # pragma: nocover
+            pass  # pragma: no cover
     assert exc.value.code == status.WS_1001_GOING_AWAY
 
 
-def test_subprotocol(test_client_factory):
+def test_subprotocol(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         assert websocket["subprotocols"] == ["soap", "wamp"]
@@ -286,7 +292,7 @@ def test_subprotocol(test_client_factory):
         assert websocket.accepted_subprotocol == "wamp"
 
 
-def test_additional_headers(test_client_factory):
+def test_additional_headers(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept(headers=[(b"additional", b"header")])
@@ -297,7 +303,7 @@ def test_additional_headers(test_client_factory):
         assert websocket.extra_headers == [(b"additional", b"header")]
 
 
-def test_no_additional_headers(test_client_factory):
+def test_no_additional_headers(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -308,17 +314,17 @@ def test_no_additional_headers(test_client_factory):
         assert websocket.extra_headers == []
 
 
-def test_websocket_exception(test_client_factory):
+def test_websocket_exception(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         assert False
 
     client = test_client_factory(app)
     with pytest.raises(AssertionError):
         with client.websocket_connect("/123?a=abc"):
-            pass  # pragma: nocover
+            pass  # pragma: no cover
 
 
-def test_duplicate_close(test_client_factory):
+def test_duplicate_close(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -328,10 +334,10 @@ def test_duplicate_close(test_client_factory):
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
         with client.websocket_connect("/"):
-            pass  # pragma: nocover
+            pass  # pragma: no cover
 
 
-def test_duplicate_disconnect(test_client_factory):
+def test_duplicate_disconnect(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -351,11 +357,11 @@ def test_websocket_scope_interface():
     interface.
     """
 
-    async def mock_receive():
-        pass  # pragma: no cover
+    async def mock_receive() -> Message:  # type: ignore
+        ...  # pragma: no cover
 
-    async def mock_send(message):
-        pass  # pragma: no cover
+    async def mock_send(message: Message) -> None:
+        ...  # pragma: no cover
 
     websocket = WebSocket(
         {"type": "websocket", "path": "/abc/", "headers": []},
@@ -377,7 +383,7 @@ def test_websocket_scope_interface():
     assert {websocket} == {websocket}
 
 
-def test_websocket_close_reason(test_client_factory) -> None:
+def test_websocket_close_reason(test_client_factory: Callable[..., TestClient]) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -391,7 +397,7 @@ def test_websocket_close_reason(test_client_factory) -> None:
         assert exc.value.reason == "Going Away"
 
 
-def test_send_json_invalid_mode(test_client_factory):
+def test_send_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -400,10 +406,10 @@ def test_send_json_invalid_mode(test_client_factory):
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
         with client.websocket_connect("/"):
-            pass  # pragma: nocover
+            pass  # pragma: no cover
 
 
-def test_receive_json_invalid_mode(test_client_factory):
+def test_receive_json_invalid_mode(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -415,7 +421,7 @@ def test_receive_json_invalid_mode(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_receive_text_before_accept(test_client_factory):
+def test_receive_text_before_accept(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.receive_text()
@@ -426,7 +432,7 @@ def test_receive_text_before_accept(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_receive_bytes_before_accept(test_client_factory):
+def test_receive_bytes_before_accept(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.receive_bytes()
@@ -437,7 +443,7 @@ def test_receive_bytes_before_accept(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_receive_json_before_accept(test_client_factory):
+def test_receive_json_before_accept(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.receive_json()
@@ -445,10 +451,10 @@ def test_receive_json_before_accept(test_client_factory):
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
         with client.websocket_connect("/"):
-            pass  # pragma: nocover
+            pass  # pragma: no cover
 
 
-def test_send_before_accept(test_client_factory):
+def test_send_before_accept(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.send({"type": "websocket.send"})
@@ -459,7 +465,7 @@ def test_send_before_accept(test_client_factory):
             pass  # pragma: nocover
 
 
-def test_send_wrong_message_type(test_client_factory):
+def test_send_wrong_message_type(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.send({"type": "websocket.accept"})
@@ -468,10 +474,10 @@ def test_send_wrong_message_type(test_client_factory):
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
         with client.websocket_connect("/"):
-            pass  # pragma: nocover
+            pass  # pragma: no cover
 
 
-def test_receive_before_accept(test_client_factory):
+def test_receive_before_accept(test_client_factory: Callable[..., TestClient]):
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
@@ -484,8 +490,8 @@ def test_receive_before_accept(test_client_factory):
             websocket.send({"type": "websocket.send"})
 
 
-def test_receive_wrong_message_type(test_client_factory):
-    async def app(scope, receive, send):
+def test_receive_wrong_message_type(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send):
         websocket = WebSocket(scope, receive=receive, send=send)
         await websocket.accept()
         await websocket.receive()