]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
collect errors more reliably from websocket test client (#2814)
authorThomas Grainger <tagrain@gmail.com>
Sun, 29 Dec 2024 11:00:04 +0000 (11:00 +0000)
committerGitHub <noreply@github.com>
Sun, 29 Dec 2024 11:00:04 +0000 (12:00 +0100)
starlette/testclient.py
tests/test_testclient.py

index 8e908d36f7cc7ec7f7d6d1dc913fb8cd1c1f4053..a14f646d4a0190c08eb38f733a83353acbba0bd7 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import contextlib
+import enum
 import inspect
 import io
 import json
@@ -9,7 +10,6 @@ import queue
 import sys
 import typing
 from concurrent.futures import Future
-from functools import cached_property
 from types import GeneratorType
 from urllib.parse import unquote, urljoin
 
@@ -85,6 +85,14 @@ class WebSocketDenialResponse(  # type: ignore[misc]
     """
 
 
+class _Eof(enum.Enum):
+    EOF = enum.auto()
+
+
+EOF: typing.Final = _Eof.EOF
+Eof = typing.Literal[_Eof.EOF]
+
+
 class WebSocketTestSession:
     def __init__(
         self,
@@ -97,63 +105,47 @@ class WebSocketTestSession:
         self.accepted_subprotocol = None
         self.portal_factory = portal_factory
         self._receive_queue: queue.Queue[Message] = queue.Queue()
-        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
+        self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
         self.extra_headers = None
 
     def __enter__(self) -> WebSocketTestSession:
-        self.exit_stack = contextlib.ExitStack()
-        self.portal = self.exit_stack.enter_context(self.portal_factory())
-
-        try:
-            _: Future[None] = self.portal.start_task_soon(self._run)
+        with contextlib.ExitStack() as stack:
+            self.portal = portal = stack.enter_context(self.portal_factory())
+            fut, cs = portal.start_task(self._run)
+            stack.callback(fut.result)
+            stack.callback(portal.call, cs.cancel)
             self.send({"type": "websocket.connect"})
             message = self.receive()
             self._raise_on_close(message)
-        except Exception:
-            self.exit_stack.close()
-            raise
-        self.accepted_subprotocol = message.get("subprotocol", None)
-        self.extra_headers = message.get("headers", None)
-        return self
-
-    @cached_property
-    def should_close(self) -> anyio.Event:
-        return anyio.Event()
-
-    async def _notify_close(self) -> None:
-        self.should_close.set()
+            self.accepted_subprotocol = message.get("subprotocol", None)
+            self.extra_headers = message.get("headers", None)
+            stack.callback(self.close, 1000)
+            self.exit_stack = stack.pop_all()
+            return self
 
     def __exit__(self, *args: typing.Any) -> None:
-        try:
-            self.close(1000)
-        finally:
-            self.portal.start_task_soon(self._notify_close)
-            self.exit_stack.close()
-        while not self._send_queue.empty():
+        self.exit_stack.close()
+
+        while True:
             message = self._send_queue.get()
+            if message is EOF:
+                break
             if isinstance(message, BaseException):
-                raise message
+                raise message  # pragma: no cover (defensive, should be impossible)
 
-    async def _run(self) -> None:
+    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
         """
         The sub-thread in which the websocket session runs.
         """
-
-        async def run_app(tg: anyio.abc.TaskGroup) -> None:
-            try:
+        try:
+            with anyio.CancelScope() as cs:
+                task_status.started(cs)
                 await self.app(self.scope, self._asgi_receive, self._asgi_send)
-            except anyio.get_cancelled_exc_class():
-                ...
-            except BaseException as exc:
-                self._send_queue.put(exc)
-                raise
-            finally:
-                tg.cancel_scope.cancel()
-
-        async with anyio.create_task_group() as tg:
-            tg.start_soon(run_app, tg)
-            await self.should_close.wait()
-            tg.cancel_scope.cancel()
+        except BaseException as exc:
+            self._send_queue.put(exc)
+            raise
+        finally:
+            self._send_queue.put(EOF)  # TODO: use self._send_queue.shutdown() on 3.13+
 
     async def _asgi_receive(self) -> Message:
         while self._receive_queue.empty():
@@ -202,6 +194,7 @@ class WebSocketTestSession:
 
     def receive(self) -> Message:
         message = self._send_queue.get()
+        assert message is not EOF
         if isinstance(message, BaseException):
             raise message
         return message
index 58ab6f6f23123e80703c7135c4412f5b331df198..478dbca4692cfb97f7757d746bd81ed0581a747c 100644 (file)
@@ -255,19 +255,25 @@ def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> N
 
 
 def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
+    cancelled = False
+
     def app(scope: Scope) -> ASGIInstance:
         async def asgi(receive: Receive, send: Send) -> None:
-            websocket = WebSocket(scope, receive=receive, send=send)
-            await websocket.accept()
-            while True:
-                await anyio.sleep(0.1)
+            nonlocal cancelled
+            try:
+                websocket = WebSocket(scope, receive=receive, send=send)
+                await websocket.accept()
+                await anyio.sleep_forever()
+            except anyio.get_cancelled_exc_class():
+                cancelled = True
+                raise
 
         return asgi
 
     client = test_client_factory(app)  # type: ignore
-    with client.websocket_connect("/") as websocket:
+    with client.websocket_connect("/"):
         ...
-    assert websocket.should_close.is_set()
+    assert cancelled
 
 
 def test_client(test_client_factory: TestClientFactory) -> None: