from __future__ import annotations
import contextlib
-import enum
import inspect
import io
import json
import math
-import queue
import sys
import typing
from concurrent.futures import Future
"""
-class _Eof(enum.Enum):
- EOF = enum.auto()
-
-
-EOF: typing.Final = _Eof.EOF
-Eof = typing.Literal[_Eof.EOF]
-
-
class WebSocketTestSession:
def __init__(
self,
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
- self._receive_queue: queue.Queue[Message] = queue.Queue()
- self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None
def __enter__(self) -> WebSocketTestSession:
self.exit_stack = stack.pop_all()
return self
- def __exit__(self, *args: typing.Any) -> None:
- self.exit_stack.close()
-
- while True:
- message = self._send_queue.get()
- if message is EOF:
- break
- if isinstance(message, BaseException):
- raise message # pragma: no cover (defensive, should be impossible)
+ def __exit__(self, *args: typing.Any) -> bool | None:
+ return self.exit_stack.__exit__(*args)
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
- try:
- with anyio.CancelScope() as cs:
- task_status.started(cs)
- await self.app(self.scope, self._asgi_receive, self._asgi_send)
- except BaseException as exc:
- self._send_queue.put(exc)
- raise
- finally:
- self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+
-
- async def _asgi_receive(self) -> Message:
- while self._receive_queue.empty():
- self._queue_event = anyio.Event()
- await self._queue_event.wait()
- return self._receive_queue.get()
+ send_tx, send_rx = anyio.create_memory_object_stream[Message](math.inf)
+ receive_tx, receive_rx = anyio.create_memory_object_stream[Message](math.inf)
+ with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
+ self._receive_tx = receive_tx
+ self._send_rx = send_rx
+ task_status.started(cs)
+ await self.app(self.scope, receive_rx.receive, send_tx.send)
- async def _asgi_send(self, message: Message) -> None:
- self._send_queue.put(message)
+ # wait for cs.cancel to be called before closing streams
+ await anyio.sleep_forever()
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
def send(self, message: Message) -> None:
- self._receive_queue.put(message)
- if hasattr(self, "_queue_event"):
- self.portal.start_task_soon(self._queue_event.set)
+ self.portal.call(self._receive_tx.send, message)
def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
def receive(self) -> Message:
- message = self._send_queue.get()
- assert message is not EOF
- if isinstance(message, BaseException):
- raise message
- return message
+ return self.portal.call(self._send_rx.receive)
def receive_text(self) -> str:
message = self.receive()