]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Raise exception from background task on BaseHTTPMiddleware (#2812)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sat, 22 Feb 2025 13:41:46 +0000 (10:41 -0300)
committerGitHub <noreply@github.com>
Sat, 22 Feb 2025 13:41:46 +0000 (10:41 -0300)
Co-authored-by: Thomas Grainger <tagrain@gmail.com>
starlette/middleware/base.py
starlette/responses.py
tests/middleware/test_base.py

index f146984b3428c5217a9e62aa001ff87f3ae92d2e..2a59337e524c7eee2b4e92e2499fb3d2d7979bef 100644 (file)
@@ -103,10 +103,9 @@ class BaseHTTPMiddleware:
         request = _CachedRequest(scope, receive)
         wrapped_receive = request.wrapped_receive
         response_sent = anyio.Event()
+        app_exc: Exception | None = None
 
         async def call_next(request: Request) -> Response:
-            app_exc: Exception | None = None
-
             async def receive_or_disconnect() -> Message:
                 if response_sent.is_set():
                     return {"type": "http.disconnect"}
@@ -165,9 +164,6 @@ class BaseHTTPMiddleware:
                     if not message.get("more_body", False):
                         break
 
-                if app_exc is not None:
-                    raise app_exc
-
             response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
             response.raw_headers = message["headers"]
             return response
@@ -181,6 +177,9 @@ class BaseHTTPMiddleware:
                 response_sent.set()
                 recv_stream.close()
 
+        if app_exc is not None:
+            raise app_exc
+
     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
         raise NotImplementedError()  # pragma: no cover
 
index 5964df6780ebcf041106b7863f2416c453ace732..81e89faefed9050e59177af5d91872355410e97a 100644 (file)
@@ -18,6 +18,7 @@ from urllib.parse import quote
 import anyio
 import anyio.to_thread
 
+from starlette._utils import collapse_excgroups
 from starlette.background import BackgroundTask
 from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import URL, Headers, MutableHeaders
@@ -258,14 +259,15 @@ class StreamingResponse(Response):
             except OSError:
                 raise ClientDisconnect()
         else:
-            async with anyio.create_task_group() as task_group:
+            with collapse_excgroups():
+                async with anyio.create_task_group() as task_group:
 
-                async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
-                    await func()
-                    task_group.cancel_scope.cancel()
+                    async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
+                        await func()
+                        task_group.cancel_scope.cancel()
 
-                task_group.start_soon(wrap, partial(self.stream_response, send))
-                await wrap(partial(self.listen_for_disconnect, receive))
+                    task_group.start_soon(wrap, partial(self.stream_response, send))
+                    await wrap(partial(self.listen_for_disconnect, receive))
 
         if self.background is not None:
             await self.background()
index 7232cfd18de02a33056988ffe835a8afe774415d..e4e82077f19659f73efb5c104ac4b0d345c7c148 100644 (file)
@@ -297,6 +297,29 @@ async def test_run_background_tasks_even_if_client_disconnects() -> None:
     assert background_task_run.is_set()
 
 
+def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None:
+    # test for https://github.com/encode/starlette/issues/2625
+
+    async def sleep_and_set() -> None:
+        await anyio.sleep(0.1)
+        raise ValueError("TEST")
+
+    async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
+        return PlainTextResponse(background=BackgroundTask(sleep_and_set))
+
+    async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
+        return await call_next(request)
+
+    app = Starlette(
+        middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
+        routes=[Route("/", endpoint_with_background_task)],
+    )
+
+    client = test_client_factory(app)
+    with pytest.raises(ValueError, match="TEST"):
+        client.get("/")
+
+
 @pytest.mark.anyio
 async def test_do_not_block_on_background_tasks() -> None:
     response_complete = anyio.Event()