]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Cancel `WebSocketTestSession` on close (#2427)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sat, 20 Jan 2024 13:59:47 +0000 (14:59 +0100)
committerGitHub <noreply@github.com>
Sat, 20 Jan 2024 13:59:47 +0000 (06:59 -0700)
* Cancel `WebSocketTestSession` on close

* Undo some noise

* Fix test

* Undo pyproject

* Undo anyio bump

* Undo changes on test_authentication

* Always call cancel scope

starlette/testclient.py
tests/test_testclient.py

index 149fbb07ac7c0b8b89745593b1cf58907294e6c4..2cccb15d13a9609a9dbbf058cce420cf676b513f 100644 (file)
@@ -1,9 +1,12 @@
+from __future__ import annotations
+
 import contextlib
 import inspect
 import io
 import json
 import math
 import queue
+import sys
 import typing
 import warnings
 from concurrent.futures import Future
@@ -11,6 +14,7 @@ from types import GeneratorType
 from urllib.parse import unquote, urljoin
 
 import anyio
+import anyio.abc
 import anyio.from_thread
 from anyio.abc import ObjectReceiveStream, ObjectSendStream
 from anyio.streams.stapled import StapledObjectStream
@@ -19,6 +23,11 @@ from starlette._utils import is_async_callable
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 from starlette.websockets import WebSocketDisconnect
 
+if sys.version_info >= (3, 10):  # pragma: no cover
+    from typing import TypeGuard
+else:  # pragma: no cover
+    from typing_extensions import TypeGuard
+
 try:
     import httpx
 except ModuleNotFoundError:  # pragma: no cover
@@ -39,7 +48,7 @@ ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
 _RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
 
 
-def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
+def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> TypeGuard[ASGI3App]:
     if inspect.isclass(app):
         return hasattr(app, "__await__")
     return is_async_callable(app)
@@ -64,7 +73,7 @@ class _AsyncBackend(typing.TypedDict):
 
 
 class _Upgrade(Exception):
-    def __init__(self, session: "WebSocketTestSession") -> None:
+    def __init__(self, session: WebSocketTestSession) -> None:
         self.session = session
 
 
@@ -79,16 +88,17 @@ class WebSocketTestSession:
         self.scope = scope
         self.accepted_subprotocol = None
         self.portal_factory = portal_factory
-        self._receive_queue: "queue.Queue[Message]" = queue.Queue()
-        self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue()
+        self._receive_queue: queue.Queue[Message] = queue.Queue()
+        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
         self.extra_headers = None
 
-    def __enter__(self) -> "WebSocketTestSession":
+    def __enter__(self) -> WebSocketTestSession:
         self.exit_stack = contextlib.ExitStack()
         self.portal = self.exit_stack.enter_context(self.portal_factory())
+        self.should_close = anyio.Event()
 
         try:
-            _: "Future[None]" = self.portal.start_task_soon(self._run)
+            _: Future[None] = self.portal.start_task_soon(self._run)
             self.send({"type": "websocket.connect"})
             message = self.receive()
             self._raise_on_close(message)
@@ -99,10 +109,14 @@ class WebSocketTestSession:
         self.extra_headers = message.get("headers", None)
         return self
 
+    async def _notify_close(self) -> None:
+        self.should_close.set()
+
     def __exit__(self, *args: typing.Any) -> None:
         try:
             self.close(1000)
         finally:
+            self.portal.start_task_soon(self._notify_close)
             self.exit_stack.close()
         while not self._send_queue.empty():
             message = self._send_queue.get()
@@ -113,14 +127,22 @@ class WebSocketTestSession:
         """
         The sub-thread in which the websocket session runs.
         """
-        scope = self.scope
-        receive = self._asgi_receive
-        send = self._asgi_send
-        try:
-            await self.app(scope, receive, send)
-        except BaseException as exc:
-            self._send_queue.put(exc)
-            raise
+
+        async def run_app(tg: anyio.abc.TaskGroup) -> None:
+            try:
+                await self.app(self.scope, self._asgi_receive, self._asgi_send)
+            except anyio.get_cancelled_exc_class():
+                ...
+            except BaseException as exc:
+                self._send_queue.put(exc)
+                raise
+            finally:
+                tg.cancel_scope.cancel()
+
+        async with anyio.create_task_group() as tg:
+            tg.start_soon(run_app, tg)
+            await self.should_close.wait()
+            tg.cancel_scope.cancel()
 
     async def _asgi_receive(self) -> Message:
         while self._receive_queue.empty():
@@ -153,7 +175,7 @@ class WebSocketTestSession:
         else:
             self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
 
-    def close(self, code: int = 1000, reason: typing.Union[str, None] = None) -> None:
+    def close(self, code: int = 1000, reason: str | None = None) -> None:
         self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
 
     def receive(self) -> Message:
@@ -172,8 +194,9 @@ class WebSocketTestSession:
         self._raise_on_close(message)
         return typing.cast(bytes, message["bytes"])
 
-    def receive_json(self, mode: str = "text") -> typing.Any:
-        assert mode in ["text", "binary"]
+    def receive_json(
+        self, mode: typing.Literal["text", "binary"] = "text"
+    ) -> typing.Any:
         message = self.receive()
         self._raise_on_close(message)
         if mode == "text":
@@ -191,7 +214,7 @@ class _TestClientTransport(httpx.BaseTransport):
         raise_server_exceptions: bool = True,
         root_path: str = "",
         *,
-        app_state: typing.Dict[str, typing.Any],
+        app_state: dict[str, typing.Any],
     ) -> None:
         self.app = app
         self.raise_server_exceptions = raise_server_exceptions
@@ -217,7 +240,7 @@ class _TestClientTransport(httpx.BaseTransport):
 
         # Include the 'host' header.
         if "host" in request.headers:
-            headers: typing.List[typing.Tuple[bytes, bytes]] = []
+            headers: list[tuple[bytes, bytes]] = []
         elif port == default_port:  # pragma: no cover
             headers = [(b"host", host.encode())]
         else:  # pragma: no cover
@@ -229,7 +252,7 @@ class _TestClientTransport(httpx.BaseTransport):
             for key, value in request.headers.multi_items()
         ]
 
-        scope: typing.Dict[str, typing.Any]
+        scope: dict[str, typing.Any]
 
         if scheme in {"ws", "wss"}:
             subprotocol = request.headers.get("sec-websocket-protocol", None)
@@ -272,7 +295,7 @@ class _TestClientTransport(httpx.BaseTransport):
         request_complete = False
         response_started = False
         response_complete: anyio.Event
-        raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
+        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
         template = None
         context = None
 
@@ -363,8 +386,8 @@ class _TestClientTransport(httpx.BaseTransport):
 
 class TestClient(httpx.Client):
     __test__ = False
-    task: "Future[None]"
-    portal: typing.Optional[anyio.abc.BlockingPortal] = None
+    task: Future[None]
+    portal: anyio.abc.BlockingPortal | None = None
 
     def __init__(
         self,
@@ -372,17 +395,16 @@ class TestClient(httpx.Client):
         base_url: str = "http://testserver",
         raise_server_exceptions: bool = True,
         root_path: str = "",
-        backend: str = "asyncio",
-        backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
-        cookies: httpx._types.CookieTypes = None,
-        headers: typing.Dict[str, str] = None,
+        backend: typing.Literal["asyncio", "trio"] = "asyncio",
+        backend_options: typing.Dict[str, typing.Any] | None = None,
+        cookies: httpx._types.CookieTypes | None = None,
+        headers: typing.Dict[str, str] | None = None,
         follow_redirects: bool = True,
     ) -> None:
         self.async_backend = _AsyncBackend(
             backend=backend, backend_options=backend_options or {}
         )
         if _is_asgi3(app):
-            app = typing.cast(ASGI3App, app)
             asgi_app = app
         else:
             app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
@@ -419,13 +441,11 @@ class TestClient(httpx.Client):
                 yield portal
 
     def _choose_redirect_arg(
-        self,
-        follow_redirects: typing.Optional[bool],
-        allow_redirects: typing.Optional[bool],
-    ) -> typing.Union[bool, httpx._client.UseClientDefault]:
-        redirect: typing.Union[
-            bool, httpx._client.UseClientDefault
-        ] = httpx._client.USE_CLIENT_DEFAULT
+        self, follow_redirects: bool | None, allow_redirects: bool | None
+    ) -> bool | httpx._client.UseClientDefault:
+        redirect: bool | httpx._client.UseClientDefault = (
+            httpx._client.USE_CLIENT_DEFAULT
+        )
         if allow_redirects is not None:
             message = (
                 "The `allow_redirects` argument is deprecated. "
@@ -709,7 +729,10 @@ class TestClient(httpx.Client):
         )
 
     def websocket_connect(
-        self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
+        self,
+        url: str,
+        subprotocols: typing.Sequence[str] | None = None,
+        **kwargs: typing.Any,
     ) -> "WebSocketTestSession":
         url = urljoin("ws://testserver", url)
         headers = kwargs.get("headers", {})
index 38c6f0872a652c66f458a25c11192020cfba336e..eeccf25a7f029fb96665ccde41c93cac4afaf8e5 100644 (file)
@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager
 from typing import Callable
 
 import anyio
+import anyio.lowlevel
 import pytest
 import sniffio
 import trio.lowlevel
@@ -15,18 +16,15 @@ from starlette.requests import Request
 from starlette.responses import JSONResponse, RedirectResponse, Response
 from starlette.routing import Route
 from starlette.testclient import TestClient
+from starlette.types import ASGIApp, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
 
-def mock_service_endpoint(request):
+def mock_service_endpoint(request: Request):
     return JSONResponse({"mock": "example"})
 
 
-mock_service = Starlette(
-    routes=[
-        Route("/", endpoint=mock_service_endpoint),
-    ]
-)
+mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)])
 
 
 def current_task():
@@ -48,7 +46,7 @@ def startup():
     raise RuntimeError()
 
 
-def test_use_testclient_in_endpoint(test_client_factory):
+def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]):
     """
     We should be able to use the test client within applications.
 
@@ -56,7 +54,7 @@ def test_use_testclient_in_endpoint(test_client_factory):
     during tests or in development.
     """
 
-    def homepage(request):
+    def homepage(request: Request):
         client = test_client_factory(mock_service)
         response = client.get("/")
         return JSONResponse(response.json())
@@ -87,7 +85,9 @@ def test_testclient_headers_behavior():
     assert client.headers.get("Authentication") == "Bearer 123"
 
 
-def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
+def test_use_testclient_as_contextmanager(
+    test_client_factory: Callable[..., TestClient], anyio_backend_name: str
+):
     """
     This test asserts a number of properties that are important for an
     app level task_group
@@ -109,17 +109,17 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam
     shutdown_loop = None
 
     @asynccontextmanager
-    async def lifespan_context(app):
+    async def lifespan_context(app: Starlette):
         nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
 
         startup_task = current_task()
         startup_loop = get_identity()
-        async with anyio.create_task_group() as app.task_group:
+        async with anyio.create_task_group():
             yield
         shutdown_task = current_task()
         shutdown_loop = get_identity()
 
-    async def loop_id(request):
+    async def loop_id(request: Request):
         return JSONResponse(get_identity())
 
     app = Starlette(
@@ -165,7 +165,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam
     assert first_task is not startup_task
 
 
-def test_error_on_startup(test_client_factory):
+def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
     with pytest.deprecated_call(
         match="The on_startup and on_shutdown parameters are deprecated"
     ):
@@ -176,15 +176,15 @@ def test_error_on_startup(test_client_factory):
             pass  # pragma: no cover
 
 
-def test_exception_in_middleware(test_client_factory):
+def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]):
     class MiddlewareException(Exception):
         pass
 
     class BrokenMiddleware:
-        def __init__(self, app):
+        def __init__(self, app: ASGIApp):
             self.app = app
 
-        async def __call__(self, scope, receive, send):
+        async def __call__(self, scope: Scope, receive: Receive, send: Send):
             raise MiddlewareException()
 
     broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
@@ -194,9 +194,9 @@ def test_exception_in_middleware(test_client_factory):
             pass  # pragma: no cover
 
 
-def test_testclient_asgi2(test_client_factory):
-    def app(scope):
-        async def inner(receive, send):
+def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
+    def app(scope: Scope):
+        async def inner(receive: Receive, send: Send):
             await send(
                 {
                     "type": "http.response.start",
@@ -213,8 +213,8 @@ def test_testclient_asgi2(test_client_factory):
     assert response.text == "Hello, world!"
 
 
-def test_testclient_asgi3(test_client_factory):
-    async def app(scope, receive, send):
+def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send):
         await send(
             {
                 "type": "http.response.start",
@@ -229,12 +229,12 @@ def test_testclient_asgi3(test_client_factory):
     assert response.text == "Hello, world!"
 
 
-def test_websocket_blocking_receive(test_client_factory):
-    def app(scope):
-        async def respond(websocket):
+def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]):
+    def app(scope: Scope):
+        async def respond(websocket: WebSocket):
             await websocket.send_json({"message": "test"})
 
-        async def asgi(receive, send):
+        async def asgi(receive: Receive, send: Send):
             websocket = WebSocket(scope, receive=receive, send=send)
             await websocket.accept()
             async with anyio.create_task_group() as task_group:
@@ -254,9 +254,25 @@ def test_websocket_blocking_receive(test_client_factory):
         assert data == {"message": "test"}
 
 
+def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
+    def app(scope: Scope):
+        async def asgi(receive: Receive, send: Send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            while True:
+                await anyio.sleep(0.1)
+
+        return asgi
+
+    client = test_client_factory(app)
+    with client.websocket_connect("/") as websocket:
+        ...
+    assert websocket.should_close.is_set()
+
+
 @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
-def test_query_params(test_client_factory, param: str):
-    def homepage(request):
+def test_query_params(test_client_factory: Callable[..., TestClient], param: str):
+    def homepage(request: Request):
         return Response(request.query_params["param"])
 
     app = Starlette(routes=[Route("/", endpoint=homepage)])
@@ -284,7 +300,9 @@ def test_query_params(test_client_factory, param: str):
         ("example.com", False),
     ],
 )
-def test_domain_restricted_cookies(test_client_factory, domain, ok):
+def test_domain_restricted_cookies(
+    test_client_factory: Callable[..., TestClient], domain: str, ok: bool
+):
     """
     Test that test client discards domain restricted cookies which do not match the
     base_url of the testclient (`http://testserver` by default).
@@ -294,7 +312,7 @@ def test_domain_restricted_cookies(test_client_factory, domain, ok):
     in accordance with RFC 2965.
     """
 
-    async def app(scope, receive, send):
+    async def app(scope: Scope, receive: Receive, send: Send):
         response = Response("Hello, world!", media_type="text/plain")
         response.set_cookie(
             "mycookie",
@@ -310,8 +328,8 @@ def test_domain_restricted_cookies(test_client_factory, domain, ok):
     assert cookie_set == ok
 
 
-def test_forward_follow_redirects(test_client_factory):
-    async def app(scope, receive, send):
+def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send):
         if "/ok" in scope["path"]:
             response = Response("ok")
         else:
@@ -323,8 +341,8 @@ def test_forward_follow_redirects(test_client_factory):
     assert response.status_code == 200
 
 
-def test_forward_nofollow_redirects(test_client_factory):
-    async def app(scope, receive, send):
+def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]):
+    async def app(scope: Scope, receive: Receive, send: Send):
         response = RedirectResponse("/ok")
         await response(scope, receive, send)