From: Marcelo Trylesinski Date: Sat, 20 Jan 2024 13:59:47 +0000 (+0100) Subject: Cancel `WebSocketTestSession` on close (#2427) X-Git-Tag: 0.36.0~4 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=3ae161ed8f94a1e4871c4aa11581868065b175b5;p=thirdparty%2Fstarlette.git Cancel `WebSocketTestSession` on close (#2427) * Cancel `WebSocketTestSession` on close * Undo some noise * Fix test * Undo pyproject * Undo anyio bump * Undo changes on test_authentication * Always call cancel scope --- diff --git a/starlette/testclient.py b/starlette/testclient.py index 149fbb07..2cccb15d 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import contextlib import inspect import io import json import math import queue +import sys import typing import warnings from concurrent.futures import Future @@ -11,6 +14,7 @@ from types import GeneratorType from urllib.parse import unquote, urljoin import anyio +import anyio.abc import anyio.from_thread from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream @@ -19,6 +23,11 @@ from starlette._utils import is_async_callable from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: # pragma: no cover + from typing_extensions import TypeGuard + try: import httpx except ModuleNotFoundError: # pragma: no cover @@ -39,7 +48,7 @@ ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] _RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]] -def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool: +def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> TypeGuard[ASGI3App]: if inspect.isclass(app): return hasattr(app, "__await__") return is_async_callable(app) @@ -64,7 +73,7 @@ class _AsyncBackend(typing.TypedDict): class _Upgrade(Exception): - def __init__(self, session: "WebSocketTestSession") -> None: + def __init__(self, session: WebSocketTestSession) -> None: self.session = session @@ -79,16 +88,17 @@ class WebSocketTestSession: self.scope = scope self.accepted_subprotocol = None self.portal_factory = portal_factory - self._receive_queue: "queue.Queue[Message]" = queue.Queue() - self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue() + self._receive_queue: queue.Queue[Message] = queue.Queue() + self._send_queue: queue.Queue[Message | BaseException] = queue.Queue() self.extra_headers = None - def __enter__(self) -> "WebSocketTestSession": + def __enter__(self) -> WebSocketTestSession: self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context(self.portal_factory()) + self.should_close = anyio.Event() try: - _: "Future[None]" = self.portal.start_task_soon(self._run) + _: Future[None] = self.portal.start_task_soon(self._run) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) @@ -99,10 +109,14 @@ class WebSocketTestSession: self.extra_headers = message.get("headers", None) return self + async def _notify_close(self) -> None: + self.should_close.set() + def __exit__(self, *args: typing.Any) -> None: try: self.close(1000) finally: + self.portal.start_task_soon(self._notify_close) self.exit_stack.close() while not self._send_queue.empty(): message = self._send_queue.get() @@ -113,14 +127,22 @@ class WebSocketTestSession: """ The sub-thread in which the websocket session runs. """ - scope = self.scope - receive = self._asgi_receive - send = self._asgi_send - try: - await self.app(scope, receive, send) - except BaseException as exc: - self._send_queue.put(exc) - raise + + async def run_app(tg: anyio.abc.TaskGroup) -> None: + try: + await self.app(self.scope, self._asgi_receive, self._asgi_send) + except anyio.get_cancelled_exc_class(): + ... + except BaseException as exc: + self._send_queue.put(exc) + raise + finally: + tg.cancel_scope.cancel() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_app, tg) + await self.should_close.wait() + tg.cancel_scope.cancel() async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): @@ -153,7 +175,7 @@ class WebSocketTestSession: else: self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) - def close(self, code: int = 1000, reason: typing.Union[str, None] = None) -> None: + def close(self, code: int = 1000, reason: str | None = None) -> None: self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) def receive(self) -> Message: @@ -172,8 +194,9 @@ class WebSocketTestSession: self._raise_on_close(message) return typing.cast(bytes, message["bytes"]) - def receive_json(self, mode: str = "text") -> typing.Any: - assert mode in ["text", "binary"] + def receive_json( + self, mode: typing.Literal["text", "binary"] = "text" + ) -> typing.Any: message = self.receive() self._raise_on_close(message) if mode == "text": @@ -191,7 +214,7 @@ class _TestClientTransport(httpx.BaseTransport): raise_server_exceptions: bool = True, root_path: str = "", *, - app_state: typing.Dict[str, typing.Any], + app_state: dict[str, typing.Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions @@ -217,7 +240,7 @@ class _TestClientTransport(httpx.BaseTransport): # Include the 'host' header. if "host" in request.headers: - headers: typing.List[typing.Tuple[bytes, bytes]] = [] + headers: list[tuple[bytes, bytes]] = [] elif port == default_port: # pragma: no cover headers = [(b"host", host.encode())] else: # pragma: no cover @@ -229,7 +252,7 @@ class _TestClientTransport(httpx.BaseTransport): for key, value in request.headers.multi_items() ] - scope: typing.Dict[str, typing.Any] + scope: dict[str, typing.Any] if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) @@ -272,7 +295,7 @@ class _TestClientTransport(httpx.BaseTransport): request_complete = False response_started = False response_complete: anyio.Event - raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()} + raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} template = None context = None @@ -363,8 +386,8 @@ class _TestClientTransport(httpx.BaseTransport): class TestClient(httpx.Client): __test__ = False - task: "Future[None]" - portal: typing.Optional[anyio.abc.BlockingPortal] = None + task: Future[None] + portal: anyio.abc.BlockingPortal | None = None def __init__( self, @@ -372,17 +395,16 @@ class TestClient(httpx.Client): base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - backend: str = "asyncio", - backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, - cookies: httpx._types.CookieTypes = None, - headers: typing.Dict[str, str] = None, + backend: typing.Literal["asyncio", "trio"] = "asyncio", + backend_options: typing.Dict[str, typing.Any] | None = None, + cookies: httpx._types.CookieTypes | None = None, + headers: typing.Dict[str, str] | None = None, follow_redirects: bool = True, ) -> None: self.async_backend = _AsyncBackend( backend=backend, backend_options=backend_options or {} ) if _is_asgi3(app): - app = typing.cast(ASGI3App, app) asgi_app = app else: app = typing.cast(ASGI2App, app) # type: ignore[assignment] @@ -419,13 +441,11 @@ class TestClient(httpx.Client): yield portal def _choose_redirect_arg( - self, - follow_redirects: typing.Optional[bool], - allow_redirects: typing.Optional[bool], - ) -> typing.Union[bool, httpx._client.UseClientDefault]: - redirect: typing.Union[ - bool, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT + self, follow_redirects: bool | None, allow_redirects: bool | None + ) -> bool | httpx._client.UseClientDefault: + redirect: bool | httpx._client.UseClientDefault = ( + httpx._client.USE_CLIENT_DEFAULT + ) if allow_redirects is not None: message = ( "The `allow_redirects` argument is deprecated. " @@ -709,7 +729,10 @@ class TestClient(httpx.Client): ) def websocket_connect( - self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any + self, + url: str, + subprotocols: typing.Sequence[str] | None = None, + **kwargs: typing.Any, ) -> "WebSocketTestSession": url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 38c6f087..eeccf25a 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -5,6 +5,7 @@ from contextlib import asynccontextmanager from typing import Callable import anyio +import anyio.lowlevel import pytest import sniffio import trio.lowlevel @@ -15,18 +16,15 @@ from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse, Response from starlette.routing import Route from starlette.testclient import TestClient +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect -def mock_service_endpoint(request): +def mock_service_endpoint(request: Request): return JSONResponse({"mock": "example"}) -mock_service = Starlette( - routes=[ - Route("/", endpoint=mock_service_endpoint), - ] -) +mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)]) def current_task(): @@ -48,7 +46,7 @@ def startup(): raise RuntimeError() -def test_use_testclient_in_endpoint(test_client_factory): +def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]): """ We should be able to use the test client within applications. @@ -56,7 +54,7 @@ def test_use_testclient_in_endpoint(test_client_factory): during tests or in development. """ - def homepage(request): + def homepage(request: Request): client = test_client_factory(mock_service) response = client.get("/") return JSONResponse(response.json()) @@ -87,7 +85,9 @@ def test_testclient_headers_behavior(): assert client.headers.get("Authentication") == "Bearer 123" -def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): +def test_use_testclient_as_contextmanager( + test_client_factory: Callable[..., TestClient], anyio_backend_name: str +): """ This test asserts a number of properties that are important for an app level task_group @@ -109,17 +109,17 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam shutdown_loop = None @asynccontextmanager - async def lifespan_context(app): + async def lifespan_context(app: Starlette): nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop startup_task = current_task() startup_loop = get_identity() - async with anyio.create_task_group() as app.task_group: + async with anyio.create_task_group(): yield shutdown_task = current_task() shutdown_loop = get_identity() - async def loop_id(request): + async def loop_id(request: Request): return JSONResponse(get_identity()) app = Starlette( @@ -165,7 +165,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam assert first_task is not startup_task -def test_error_on_startup(test_client_factory): +def test_error_on_startup(test_client_factory: Callable[..., TestClient]): with pytest.deprecated_call( match="The on_startup and on_shutdown parameters are deprecated" ): @@ -176,15 +176,15 @@ def test_error_on_startup(test_client_factory): pass # pragma: no cover -def test_exception_in_middleware(test_client_factory): +def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]): class MiddlewareException(Exception): pass class BrokenMiddleware: - def __init__(self, app): + def __init__(self, app: ASGIApp): self.app = app - async def __call__(self, scope, receive, send): + async def __call__(self, scope: Scope, receive: Receive, send: Send): raise MiddlewareException() broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) @@ -194,9 +194,9 @@ def test_exception_in_middleware(test_client_factory): pass # pragma: no cover -def test_testclient_asgi2(test_client_factory): - def app(scope): - async def inner(receive, send): +def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]): + def app(scope: Scope): + async def inner(receive: Receive, send: Send): await send( { "type": "http.response.start", @@ -213,8 +213,8 @@ def test_testclient_asgi2(test_client_factory): assert response.text == "Hello, world!" -def test_testclient_asgi3(test_client_factory): - async def app(scope, receive, send): +def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send): await send( { "type": "http.response.start", @@ -229,12 +229,12 @@ def test_testclient_asgi3(test_client_factory): assert response.text == "Hello, world!" -def test_websocket_blocking_receive(test_client_factory): - def app(scope): - async def respond(websocket): +def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]): + def app(scope: Scope): + async def respond(websocket: WebSocket): await websocket.send_json({"message": "test"}) - async def asgi(receive, send): + async def asgi(receive: Receive, send: Send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() async with anyio.create_task_group() as task_group: @@ -254,9 +254,25 @@ def test_websocket_blocking_receive(test_client_factory): assert data == {"message": "test"} +def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]): + def app(scope: Scope): + async def asgi(receive: Receive, send: Send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + while True: + await anyio.sleep(0.1) + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + ... + assert websocket.should_close.is_set() + + @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) -def test_query_params(test_client_factory, param: str): - def homepage(request): +def test_query_params(test_client_factory: Callable[..., TestClient], param: str): + def homepage(request: Request): return Response(request.query_params["param"]) app = Starlette(routes=[Route("/", endpoint=homepage)]) @@ -284,7 +300,9 @@ def test_query_params(test_client_factory, param: str): ("example.com", False), ], ) -def test_domain_restricted_cookies(test_client_factory, domain, ok): +def test_domain_restricted_cookies( + test_client_factory: Callable[..., TestClient], domain: str, ok: bool +): """ Test that test client discards domain restricted cookies which do not match the base_url of the testclient (`http://testserver` by default). @@ -294,7 +312,7 @@ def test_domain_restricted_cookies(test_client_factory, domain, ok): in accordance with RFC 2965. """ - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send): response = Response("Hello, world!", media_type="text/plain") response.set_cookie( "mycookie", @@ -310,8 +328,8 @@ def test_domain_restricted_cookies(test_client_factory, domain, ok): assert cookie_set == ok -def test_forward_follow_redirects(test_client_factory): - async def app(scope, receive, send): +def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send): if "/ok" in scope["path"]: response = Response("ok") else: @@ -323,8 +341,8 @@ def test_forward_follow_redirects(test_client_factory): assert response.status_code == 200 -def test_forward_nofollow_redirects(test_client_factory): - async def app(scope, receive, send): +def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]): + async def app(scope: Scope, receive: Receive, send: Send): response = RedirectResponse("/ok") await response(scope, receive, send)