]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
ensure TestClient requests run in the same EventLoop as lifespan (#1213)
authorThomas Grainger <tagrain@gmail.com>
Sat, 3 Jul 2021 16:39:25 +0000 (17:39 +0100)
committerGitHub <noreply@github.com>
Sat, 3 Jul 2021 16:39:25 +0000 (17:39 +0100)
* ensure TestClient requests run in the same EventLoop as lifespan

* for lifespan task verification, use native task identity rather than anyio.abc.TaskInfo equality

https://github.com/agronholm/anyio/issues/324

* remove redundant pragma: no cover

* it's now a loop_id not a threading.ident

* replace Protocol with plain Callable TypeAlias

* use lifespan_context to actually open a task group

trio should complain if used incorrectly here.

* assign self.portal once, schedule reset immediately after assignment

* inline apps into their tests

* make task/loop trackers nonlocals

starlette/testclient.py
tests/test_testclient.py

index 33bb410d02b09f773ac124ed8ff4522083f70257..7aa59fb9e672629c3d53a7e8726852e0b2e8604d 100644 (file)
@@ -12,7 +12,7 @@ import typing
 from concurrent.futures import Future
 from urllib.parse import unquote, urljoin, urlsplit
 
-import anyio
+import anyio.abc
 import requests
 from anyio.streams.stapled import StapledObjectStream
 
@@ -24,6 +24,12 @@ if sys.version_info >= (3, 8):  # pragma: no cover
 else:  # pragma: no cover
     from typing_extensions import TypedDict
 
+
+_PortalFactoryType = typing.Callable[
+    [], typing.ContextManager[anyio.abc.BlockingPortal]
+]
+
+
 # Annotations for `Session.request()`
 Cookies = typing.Union[
     typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
@@ -106,14 +112,14 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
     def __init__(
         self,
         app: ASGI3App,
-        async_backend: _AsyncBackend,
+        portal_factory: _PortalFactoryType,
         raise_server_exceptions: bool = True,
         root_path: str = "",
     ) -> None:
         self.app = app
         self.raise_server_exceptions = raise_server_exceptions
         self.root_path = root_path
-        self.async_backend = async_backend
+        self.portal_factory = portal_factory
 
     def send(
         self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
@@ -162,7 +168,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
                 "server": [host, port],
                 "subprotocols": subprotocols,
             }
-            session = WebSocketTestSession(self.app, scope, self.async_backend)
+            session = WebSocketTestSession(self.app, scope, self.portal_factory)
             raise _Upgrade(session)
 
         scope = {
@@ -252,7 +258,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
                 context = message["context"]
 
         try:
-            with anyio.start_blocking_portal(**self.async_backend) as portal:
+            with self.portal_factory() as portal:
                 response_complete = portal.call(anyio.Event)
                 portal.call(self.app, scope, receive, send)
         except BaseException as exc:
@@ -285,20 +291,18 @@ class WebSocketTestSession:
         self,
         app: ASGI3App,
         scope: Scope,
-        async_backend: _AsyncBackend,
+        portal_factory: _PortalFactoryType,
     ) -> None:
         self.app = app
         self.scope = scope
         self.accepted_subprotocol = None
-        self.async_backend = async_backend
+        self.portal_factory = portal_factory
         self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
         self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
 
     def __enter__(self) -> "WebSocketTestSession":
         self.exit_stack = contextlib.ExitStack()
-        self.portal = self.exit_stack.enter_context(
-            anyio.start_blocking_portal(**self.async_backend)
-        )
+        self.portal = self.exit_stack.enter_context(self.portal_factory())
 
         try:
             _: "Future[None]" = self.portal.start_task_soon(self._run)
@@ -396,6 +400,7 @@ class WebSocketTestSession:
 class TestClient(requests.Session):
     __test__ = False  # For pytest to not discover this up.
     task: "Future[None]"
+    portal: typing.Optional[anyio.abc.BlockingPortal] = None
 
     def __init__(
         self,
@@ -418,7 +423,7 @@ class TestClient(requests.Session):
             asgi_app = _WrapASGI2(app)  #  type: ignore
         adapter = _ASGIAdapter(
             asgi_app,
-            self.async_backend,
+            portal_factory=self._portal_factory,
             raise_server_exceptions=raise_server_exceptions,
             root_path=root_path,
         )
@@ -430,6 +435,16 @@ class TestClient(requests.Session):
         self.app = asgi_app
         self.base_url = base_url
 
+    @contextlib.contextmanager
+    def _portal_factory(
+        self,
+    ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
+        if self.portal is not None:
+            yield self.portal
+        else:
+            with anyio.start_blocking_portal(**self.async_backend) as portal:
+                yield portal
+
     def request(  # type: ignore
         self,
         method: str,
@@ -490,29 +505,34 @@ class TestClient(requests.Session):
         return session
 
     def __enter__(self) -> "TestClient":
-        self.exit_stack = contextlib.ExitStack()
-        self.portal = self.exit_stack.enter_context(
-            anyio.start_blocking_portal(**self.async_backend)
-        )
-        self.stream_send = StapledObjectStream(
-            *anyio.create_memory_object_stream(math.inf)
-        )
-        self.stream_receive = StapledObjectStream(
-            *anyio.create_memory_object_stream(math.inf)
-        )
-        try:
-            self.task = self.portal.start_task_soon(self.lifespan)
-            self.portal.call(self.wait_startup)
-        except Exception:
-            self.exit_stack.close()
-            raise
+        with contextlib.ExitStack() as stack:
+            self.portal = portal = stack.enter_context(
+                anyio.start_blocking_portal(**self.async_backend)
+            )
+
+            @stack.callback
+            def reset_portal() -> None:
+                self.portal = None
+
+            self.stream_send = StapledObjectStream(
+                *anyio.create_memory_object_stream(math.inf)
+            )
+            self.stream_receive = StapledObjectStream(
+                *anyio.create_memory_object_stream(math.inf)
+            )
+            self.task = portal.start_task_soon(self.lifespan)
+            portal.call(self.wait_startup)
+
+            @stack.callback
+            def wait_shutdown() -> None:
+                portal.call(self.wait_shutdown)
+
+            self.exit_stack = stack.pop_all()
+
         return self
 
     def __exit__(self, *args: typing.Any) -> None:
-        try:
-            self.portal.call(self.wait_shutdown)
-        finally:
-            self.exit_stack.close()
+        self.exit_stack.close()
 
     async def lifespan(self) -> None:
         scope = {"type": "lifespan"}
index fd96f69a7ef7404da9fe104ad773d8e8844e095a..57ea1c3dbfd8ec2f9c3ca9d7e2485a01176440e8 100644 (file)
@@ -1,11 +1,22 @@
+import asyncio
+import itertools
+import sys
+
 import anyio
 import pytest
+import sniffio
+import trio.lowlevel
 
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.responses import JSONResponse
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
+if sys.version_info >= (3, 7):
+    from asyncio import current_task as asyncio_current_task  # pragma: no cover
+else:
+    asyncio_current_task = asyncio.Task.current_task  # pragma: no cover
+
 mock_service = Starlette()
 
 
@@ -14,16 +25,19 @@ def mock_service_endpoint(request):
     return JSONResponse({"mock": "example"})
 
 
-def create_app(test_client_factory):
-    app = Starlette()
-
-    @app.route("/")
-    def homepage(request):
-        client = test_client_factory(mock_service)
-        response = client.get("/")
-        return JSONResponse(response.json())
+def current_task():
+    # anyio's TaskInfo comparisons are invalid after their associated native
+    # task object is GC'd https://github.com/agronholm/anyio/issues/324
+    asynclib_name = sniffio.current_async_library()
+    if asynclib_name == "trio":
+        return trio.lowlevel.current_task()
 
-    return app
+    if asynclib_name == "asyncio":
+        task = asyncio_current_task()
+        if task is None:
+            raise RuntimeError("must be called from a running task")  # pragma: no cover
+        return task
+    raise RuntimeError(f"unsupported asynclib={asynclib_name}")  # pragma: no cover
 
 
 startup_error_app = Starlette()
@@ -41,14 +55,93 @@ def test_use_testclient_in_endpoint(test_client_factory):
     This is useful if we need to mock out other services,
     during tests or in development.
     """
-    client = test_client_factory(create_app(test_client_factory))
+
+    app = Starlette()
+
+    @app.route("/")
+    def homepage(request):
+        client = test_client_factory(mock_service)
+        response = client.get("/")
+        return JSONResponse(response.json())
+
+    client = test_client_factory(app)
     response = client.get("/")
     assert response.json() == {"mock": "example"}
 
 
-def test_use_testclient_as_contextmanager(test_client_factory):
-    with test_client_factory(create_app(test_client_factory)):
-        pass
+def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
+    """
+    This test asserts a number of properties that are important for an
+    app level task_group
+    """
+    counter = itertools.count()
+    identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")
+
+    def get_identity():
+        try:
+            return identity_runvar.get()
+        except LookupError:
+            token = next(counter)
+            identity_runvar.set(token)
+            return token
+
+    startup_task = object()
+    startup_loop = None
+    shutdown_task = object()
+    shutdown_loop = None
+
+    async def lifespan_context(app):
+        nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
+
+        startup_task = current_task()
+        startup_loop = get_identity()
+        async with anyio.create_task_group() as app.task_group:
+            yield
+        shutdown_task = current_task()
+        shutdown_loop = get_identity()
+
+    app = Starlette(lifespan=lifespan_context)
+
+    @app.route("/loop_id")
+    async def loop_id(request):
+        return JSONResponse(get_identity())
+
+    client = test_client_factory(app)
+
+    with client:
+        # within a TestClient context every async request runs in the same thread
+        assert client.get("/loop_id").json() == 0
+        assert client.get("/loop_id").json() == 0
+
+    # that thread is also the same as the lifespan thread
+    assert startup_loop == 0
+    assert shutdown_loop == 0
+
+    # lifespan events run in the same task, this is important because a task
+    # group must be entered and exited in the same task.
+    assert startup_task is shutdown_task
+
+    # outside the TestClient context, new requests continue to spawn in new
+    # eventloops in new threads
+    assert client.get("/loop_id").json() == 1
+    assert client.get("/loop_id").json() == 2
+
+    first_task = startup_task
+
+    with client:
+        # the TestClient context can be re-used, starting a new lifespan task
+        # in a new thread
+        assert client.get("/loop_id").json() == 3
+        assert client.get("/loop_id").json() == 3
+
+    assert startup_loop == 3
+    assert shutdown_loop == 3
+
+    # lifespan events still run in the same task, with the context but...
+    assert startup_task is shutdown_task
+
+    # ... the second TestClient context creates a new lifespan task.
+    assert first_task is not startup_task
 
 
 def test_error_on_startup(test_client_factory):