]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Convert ASGI 2.0 apps in tests to ASGI 3 (#1476)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Mon, 7 Feb 2022 14:58:30 +0000 (08:58 -0600)
committerGitHub <noreply@github.com>
Mon, 7 Feb 2022 14:58:30 +0000 (08:58 -0600)
tests/test_exceptions.py
tests/test_websockets.py

index 80307a521a2ab37a7238565fc79293606118ee97..50f677467411fcacdace572d62a33442c67bc41a 100644 (file)
@@ -99,11 +99,19 @@ def test_handled_exc_after_response(test_client_factory, client):
 
 
 def test_force_500_response(test_client_factory):
-    def app(scope):
+    # use a sentinal variable to make sure we actually
+    # make it into the endpoint and don't get a 500
+    # from an incorrect ASGI app signature or something
+    called = False
+
+    async def app(scope, receive, send):
+        nonlocal called
+        called = True
         raise RuntimeError()
 
     force_500_client = test_client_factory(app, raise_server_exceptions=False)
     response = force_500_client.get("/")
+    assert called
     assert response.status_code == 500
     assert response.text == ""
 
index e3a52762a0dc7747830f2fcf7808f66897dd53f5..f3970967eaebda35ac319637b25562c2fd6a7e9d 100644 (file)
@@ -2,18 +2,16 @@ import anyio
 import pytest
 
 from starlette import status
+from starlette.types import Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
 
 
 def test_websocket_url(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.send_json({"url": str(websocket.url)})
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.send_json({"url": str(websocket.url)})
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/123?a=abc") as websocket:
@@ -22,15 +20,12 @@ def test_websocket_url(test_client_factory):
 
 
 def test_websocket_binary_json(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            message = await websocket.receive_json(mode="binary")
-            await websocket.send_json(message, mode="binary")
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        message = await websocket.receive_json(mode="binary")
+        await websocket.send_json(message, mode="binary")
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/123?a=abc") as websocket:
@@ -40,15 +35,12 @@ def test_websocket_binary_json(test_client_factory):
 
 
 def test_websocket_query_params(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            query_params = dict(websocket.query_params)
-            await websocket.accept()
-            await websocket.send_json({"params": query_params})
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        query_params = dict(websocket.query_params)
+        await websocket.accept()
+        await websocket.send_json({"params": query_params})
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/?a=abc&b=456") as websocket:
@@ -57,15 +49,12 @@ def test_websocket_query_params(test_client_factory):
 
 
 def test_websocket_headers(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            headers = dict(websocket.headers)
-            await websocket.accept()
-            await websocket.send_json({"headers": headers})
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        headers = dict(websocket.headers)
+        await websocket.accept()
+        await websocket.send_json({"headers": headers})
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -83,14 +72,11 @@ def test_websocket_headers(test_client_factory):
 
 
 def test_websocket_port(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.send_json({"port": websocket.url.port})
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.send_json({"port": websocket.url.port})
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket:
@@ -99,15 +85,12 @@ def test_websocket_port(test_client_factory):
 
 
 def test_websocket_send_and_receive_text(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            data = await websocket.receive_text()
-            await websocket.send_text("Message was: " + data)
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        data = await websocket.receive_text()
+        await websocket.send_text("Message was: " + data)
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -117,15 +100,12 @@ def test_websocket_send_and_receive_text(test_client_factory):
 
 
 def test_websocket_send_and_receive_bytes(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            data = await websocket.receive_bytes()
-            await websocket.send_bytes(b"Message was: " + data)
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        data = await websocket.receive_bytes()
+        await websocket.send_bytes(b"Message was: " + data)
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -135,15 +115,12 @@ def test_websocket_send_and_receive_bytes(test_client_factory):
 
 
 def test_websocket_send_and_receive_json(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            data = await websocket.receive_json()
-            await websocket.send_json({"message": data})
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        data = await websocket.receive_json()
+        await websocket.send_json({"message": data})
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -153,14 +130,11 @@ def test_websocket_send_and_receive_json(test_client_factory):
 
 
 def test_websocket_iter_text(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            async for data in websocket.iter_text():
-                await websocket.send_text("Message was: " + data)
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        async for data in websocket.iter_text():
+            await websocket.send_text("Message was: " + data)
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -170,14 +144,11 @@ def test_websocket_iter_text(test_client_factory):
 
 
 def test_websocket_iter_bytes(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            async for data in websocket.iter_bytes():
-                await websocket.send_bytes(b"Message was: " + data)
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        async for data in websocket.iter_bytes():
+            await websocket.send_bytes(b"Message was: " + data)
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -187,14 +158,11 @@ def test_websocket_iter_bytes(test_client_factory):
 
 
 def test_websocket_iter_json(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            async for data in websocket.iter_json():
-                await websocket.send_json({"message": data})
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        async for data in websocket.iter_json():
+            await websocket.send_json({"message": data})
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -204,28 +172,25 @@ def test_websocket_iter_json(test_client_factory):
 
 
 def test_websocket_concurrency_pattern(test_client_factory):
-    def app(scope):
-        stream_send, stream_receive = anyio.create_memory_object_stream()
-
-        async def reader(websocket):
-            async with stream_send:
-                async for data in websocket.iter_json():
-                    await stream_send.send(data)
+    stream_send, stream_receive = anyio.create_memory_object_stream()
 
-        async def writer(websocket):
-            async with stream_receive:
-                async for message in stream_receive:
-                    await websocket.send_json(message)
+    async def reader(websocket):
+        async with stream_send:
+            async for data in websocket.iter_json():
+                await stream_send.send(data)
 
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            async with anyio.create_task_group() as task_group:
-                task_group.start_soon(reader, websocket)
-                await writer(websocket)
-            await websocket.close()
+    async def writer(websocket):
+        async with stream_receive:
+            async for message in stream_receive:
+                await websocket.send_json(message)
 
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        async with anyio.create_task_group() as task_group:
+            task_group.start_soon(reader, websocket)
+            await writer(websocket)
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -237,17 +202,14 @@ def test_websocket_concurrency_pattern(test_client_factory):
 def test_client_close(test_client_factory):
     close_code = None
 
-    def app(scope):
-        async def asgi(receive, send):
-            nonlocal close_code
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            try:
-                await websocket.receive_text()
-            except WebSocketDisconnect as exc:
-                close_code = exc.code
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        nonlocal close_code
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        try:
+            await websocket.receive_text()
+        except WebSocketDisconnect as exc:
+            close_code = exc.code
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -256,13 +218,10 @@ def test_client_close(test_client_factory):
 
 
 def test_application_close(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.close(status.WS_1001_GOING_AWAY)
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.close(status.WS_1001_GOING_AWAY)
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -272,12 +231,9 @@ def test_application_close(test_client_factory):
 
 
 def test_rejected_connection(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.close(status.WS_1001_GOING_AWAY)
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.close(status.WS_1001_GOING_AWAY)
 
     client = test_client_factory(app)
     with pytest.raises(WebSocketDisconnect) as exc:
@@ -287,14 +243,11 @@ def test_rejected_connection(test_client_factory):
 
 
 def test_subprotocol(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            assert websocket["subprotocols"] == ["soap", "wamp"]
-            await websocket.accept(subprotocol="wamp")
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        assert websocket["subprotocols"] == ["soap", "wamp"]
+        await websocket.accept(subprotocol="wamp")
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket:
@@ -302,13 +255,10 @@ def test_subprotocol(test_client_factory):
 
 
 def test_additional_headers(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept(headers=[(b"additional", b"header")])
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept(headers=[(b"additional", b"header")])
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -316,13 +266,10 @@ def test_additional_headers(test_client_factory):
 
 
 def test_no_additional_headers(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.close()
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -330,11 +277,8 @@ def test_no_additional_headers(test_client_factory):
 
 
 def test_websocket_exception(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            assert False
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        assert False
 
     client = test_client_factory(app)
     with pytest.raises(AssertionError):
@@ -343,14 +287,11 @@ def test_websocket_exception(test_client_factory):
 
 
 def test_duplicate_close(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.close()
-            await websocket.close()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.close()
+        await websocket.close()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -359,15 +300,12 @@ def test_duplicate_close(test_client_factory):
 
 
 def test_duplicate_disconnect(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            message = await websocket.receive()
-            assert message["type"] == "websocket.disconnect"
-            message = await websocket.receive()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        message = await websocket.receive()
+        assert message["type"] == "websocket.disconnect"
+        message = await websocket.receive()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -408,13 +346,10 @@ def test_websocket_scope_interface():
 
 
 def test_websocket_close_reason(test_client_factory) -> None:
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
 
     client = test_client_factory(app)
     with client.websocket_connect("/") as websocket:
@@ -425,13 +360,10 @@ def test_websocket_close_reason(test_client_factory) -> None:
 
 
 def test_send_json_invalid_mode(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.send_json({}, mode="invalid")
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.send_json({}, mode="invalid")
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -440,13 +372,10 @@ def test_send_json_invalid_mode(test_client_factory):
 
 
 def test_receive_json_invalid_mode(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.receive_json(mode="invalid")
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.receive_json(mode="invalid")
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -455,12 +384,9 @@ def test_receive_json_invalid_mode(test_client_factory):
 
 
 def test_receive_text_before_accept(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.receive_text()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.receive_text()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -469,12 +395,9 @@ def test_receive_text_before_accept(test_client_factory):
 
 
 def test_receive_bytes_before_accept(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.receive_bytes()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.receive_bytes()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -483,12 +406,9 @@ def test_receive_bytes_before_accept(test_client_factory):
 
 
 def test_receive_json_before_accept(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.receive_json()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.receive_json()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -497,12 +417,9 @@ def test_receive_json_before_accept(test_client_factory):
 
 
 def test_send_before_accept(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.send({"type": "websocket.send"})
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.send({"type": "websocket.send"})
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -511,13 +428,10 @@ def test_send_before_accept(test_client_factory):
 
 
 def test_send_wrong_message_type(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.send({"type": "websocket.accept"})
-            await websocket.send({"type": "websocket.accept"})
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.send({"type": "websocket.accept"})
+        await websocket.send({"type": "websocket.accept"})
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -526,14 +440,11 @@ def test_send_wrong_message_type(test_client_factory):
 
 
 def test_receive_before_accept(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            websocket.client_state = WebSocketState.CONNECTING
-            await websocket.receive()
-
-        return asgi
+    async def app(scope: Scope, receive: Receive, send: Send) -> None:
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        websocket.client_state = WebSocketState.CONNECTING
+        await websocket.receive()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):
@@ -542,13 +453,10 @@ def test_receive_before_accept(test_client_factory):
 
 
 def test_receive_wrong_message_type(test_client_factory):
-    def app(scope):
-        async def asgi(receive, send):
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            await websocket.receive()
-
-        return asgi
+    async def app(scope, receive, send):
+        websocket = WebSocket(scope, receive=receive, send=send)
+        await websocket.accept()
+        await websocket.receive()
 
     client = test_client_factory(app)
     with pytest.raises(RuntimeError):