from __future__ import annotations
import contextlib
+import enum
import inspect
import io
import json
import sys
import typing
from concurrent.futures import Future
-from functools import cached_property
from types import GeneratorType
from urllib.parse import unquote, urljoin
"""
+class _Eof(enum.Enum):
+ EOF = enum.auto()
+
+
+EOF: typing.Final = _Eof.EOF
+Eof = typing.Literal[_Eof.EOF]
+
+
class WebSocketTestSession:
def __init__(
self,
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
- self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
+ self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None
def __enter__(self) -> WebSocketTestSession:
- self.exit_stack = contextlib.ExitStack()
- self.portal = self.exit_stack.enter_context(self.portal_factory())
-
- try:
- _: Future[None] = self.portal.start_task_soon(self._run)
+ with contextlib.ExitStack() as stack:
+ self.portal = portal = stack.enter_context(self.portal_factory())
+ fut, cs = portal.start_task(self._run)
+ stack.callback(fut.result)
+ stack.callback(portal.call, cs.cancel)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
- except Exception:
- self.exit_stack.close()
- raise
- self.accepted_subprotocol = message.get("subprotocol", None)
- self.extra_headers = message.get("headers", None)
- return self
-
- @cached_property
- def should_close(self) -> anyio.Event:
- return anyio.Event()
-
- async def _notify_close(self) -> None:
- self.should_close.set()
+ self.accepted_subprotocol = message.get("subprotocol", None)
+ self.extra_headers = message.get("headers", None)
+ stack.callback(self.close, 1000)
+ self.exit_stack = stack.pop_all()
+ return self
def __exit__(self, *args: typing.Any) -> None:
- try:
- self.close(1000)
- finally:
- self.portal.start_task_soon(self._notify_close)
- self.exit_stack.close()
- while not self._send_queue.empty():
+ self.exit_stack.close()
+
+ while True:
message = self._send_queue.get()
+ if message is EOF:
+ break
if isinstance(message, BaseException):
- raise message
+ raise message # pragma: no cover (defensive, should be impossible)
- async def _run(self) -> None:
+ async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""
-
- async def run_app(tg: anyio.abc.TaskGroup) -> None:
- try:
+ try:
+ with anyio.CancelScope() as cs:
+ task_status.started(cs)
await self.app(self.scope, self._asgi_receive, self._asgi_send)
- except anyio.get_cancelled_exc_class():
- ...
- except BaseException as exc:
- self._send_queue.put(exc)
- raise
- finally:
- tg.cancel_scope.cancel()
-
- async with anyio.create_task_group() as tg:
- tg.start_soon(run_app, tg)
- await self.should_close.wait()
- tg.cancel_scope.cancel()
+ except BaseException as exc:
+ self._send_queue.put(exc)
+ raise
+ finally:
+ self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
def receive(self) -> Message:
message = self._send_queue.get()
+ assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
+ cancelled = False
+
def app(scope: Scope) -> ASGIInstance:
async def asgi(receive: Receive, send: Send) -> None:
- websocket = WebSocket(scope, receive=receive, send=send)
- await websocket.accept()
- while True:
- await anyio.sleep(0.1)
+ nonlocal cancelled
+ try:
+ websocket = WebSocket(scope, receive=receive, send=send)
+ await websocket.accept()
+ await anyio.sleep_forever()
+ except anyio.get_cancelled_exc_class():
+ cancelled = True
+ raise
return asgi
client = test_client_factory(app) # type: ignore
- with client.websocket_connect("/") as websocket:
+ with client.websocket_connect("/"):
...
- assert websocket.should_close.is_set()
+ assert cancelled
def test_client(test_client_factory: TestClientFactory) -> None: