]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
use a pair of memory object streams instead of two queues (#2829)
authorThomas Grainger <tagrain@gmail.com>
Sun, 29 Dec 2024 12:13:31 +0000 (12:13 +0000)
committerGitHub <noreply@github.com>
Sun, 29 Dec 2024 12:13:31 +0000 (12:13 +0000)
starlette/testclient.py

index a14f646d4a0190c08eb38f733a83353acbba0bd7..9a0abbd7b506d2d13644715c6438306576f90763 100644 (file)
@@ -1,12 +1,10 @@
 from __future__ import annotations
 
 import contextlib
-import enum
 import inspect
 import io
 import json
 import math
-import queue
 import sys
 import typing
 from concurrent.futures import Future
@@ -85,14 +83,6 @@ class WebSocketDenialResponse(  # type: ignore[misc]
     """
 
 
-class _Eof(enum.Enum):
-    EOF = enum.auto()
-
-
-EOF: typing.Final = _Eof.EOF
-Eof = typing.Literal[_Eof.EOF]
-
-
 class WebSocketTestSession:
     def __init__(
         self,
@@ -104,8 +94,6 @@ class WebSocketTestSession:
         self.scope = scope
         self.accepted_subprotocol = None
         self.portal_factory = portal_factory
-        self._receive_queue: queue.Queue[Message] = queue.Queue()
-        self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
         self.extra_headers = None
 
     def __enter__(self) -> WebSocketTestSession:
@@ -123,38 +111,23 @@ class WebSocketTestSession:
             self.exit_stack = stack.pop_all()
             return self
 
-    def __exit__(self, *args: typing.Any) -> None:
-        self.exit_stack.close()
-
-        while True:
-            message = self._send_queue.get()
-            if message is EOF:
-                break
-            if isinstance(message, BaseException):
-                raise message  # pragma: no cover (defensive, should be impossible)
+    def __exit__(self, *args: typing.Any) -> bool | None:
+        return self.exit_stack.__exit__(*args)
 
     async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
         """
         The sub-thread in which the websocket session runs.
         """
-        try:
-            with anyio.CancelScope() as cs:
-                task_status.started(cs)
-                await self.app(self.scope, self._asgi_receive, self._asgi_send)
-        except BaseException as exc:
-            self._send_queue.put(exc)
-            raise
-        finally:
-            self._send_queue.put(EOF)  # TODO: use self._send_queue.shutdown() on 3.13+
-
-    async def _asgi_receive(self) -> Message:
-        while self._receive_queue.empty():
-            self._queue_event = anyio.Event()
-            await self._queue_event.wait()
-        return self._receive_queue.get()
+        send_tx, send_rx = anyio.create_memory_object_stream[Message](math.inf)
+        receive_tx, receive_rx = anyio.create_memory_object_stream[Message](math.inf)
+        with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
+            self._receive_tx = receive_tx
+            self._send_rx = send_rx
+            task_status.started(cs)
+            await self.app(self.scope, receive_rx.receive, send_tx.send)
 
-    async def _asgi_send(self, message: Message) -> None:
-        self._send_queue.put(message)
+            # wait for cs.cancel to be called before closing streams
+            await anyio.sleep_forever()
 
     def _raise_on_close(self, message: Message) -> None:
         if message["type"] == "websocket.close":
@@ -172,9 +145,7 @@ class WebSocketTestSession:
             raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
 
     def send(self, message: Message) -> None:
-        self._receive_queue.put(message)
-        if hasattr(self, "_queue_event"):
-            self.portal.start_task_soon(self._queue_event.set)
+        self.portal.call(self._receive_tx.send, message)
 
     def send_text(self, data: str) -> None:
         self.send({"type": "websocket.receive", "text": data})
@@ -193,11 +164,7 @@ class WebSocketTestSession:
         self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
 
     def receive(self) -> Message:
-        message = self._send_queue.get()
-        assert message is not EOF
-        if isinstance(message, BaseException):
-            raise message
-        return message
+        return self.portal.call(self._send_rx.receive)
 
     def receive_text(self) -> str:
         message = self.receive()