]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Stop `body_stream` in case `more_body=False` (#2194)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 13 Jul 2023 07:48:10 +0000 (09:48 +0200)
committerGitHub <noreply@github.com>
Thu, 13 Jul 2023 07:48:10 +0000 (01:48 -0600)
starlette/middleware/base.py
tests/middleware/test_base.py

index 2ff0e047b0c179c8a6d6e73a92ff7ce8953b1d9a..170a805a758452f5ccc5b11cb6b96987f9cdbdcd 100644 (file)
@@ -170,6 +170,8 @@ class BaseHTTPMiddleware:
                         body = message.get("body", b"")
                         if body:
                             yield body
+                        if not message.get("more_body", False):
+                            break
 
                 if app_exc is not None:
                     raise app_exc
index f7dcf521c908f4ec0fac96ea9f2979a9e458cfe6..cf4780cce725dfef250f2f5d883205ef848e86f3 100644 (file)
@@ -265,6 +265,71 @@ async def test_run_background_tasks_even_if_client_disconnects():
     assert background_task_run.is_set()
 
 
+@pytest.mark.anyio
+async def test_do_not_block_on_background_tasks():
+    request_body_sent = False
+    response_complete = anyio.Event()
+    events: List[Union[str, Message]] = []
+
+    async def sleep_and_set():
+        events.append("Background task started")
+        await anyio.sleep(0.1)
+        events.append("Background task finished")
+
+    async def endpoint_with_background_task(_):
+        return PlainTextResponse(
+            content="Hello", background=BackgroundTask(sleep_and_set)
+        )
+
+    async def passthrough(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        return await call_next(request)
+
+    app = Starlette(
+        middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
+        routes=[Route("/", endpoint_with_background_task)],
+    )
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+    }
+
+    async def receive() -> Message:
+        nonlocal request_body_sent
+        if not request_body_sent:
+            request_body_sent = True
+            return {"type": "http.request", "body": b"", "more_body": False}
+        await response_complete.wait()
+        return {"type": "http.disconnect"}
+
+    async def send(message: Message):
+        if message["type"] == "http.response.body":
+            events.append(message)
+            if not message.get("more_body", False):
+                response_complete.set()
+
+    async with anyio.create_task_group() as tg:
+        tg.start_soon(app, scope, receive, send)
+        tg.start_soon(app, scope, receive, send)
+
+    # Without the fix, the background tasks would start and finish before the
+    # last http.response.body is sent.
+    assert events == [
+        {"body": b"Hello", "more_body": True, "type": "http.response.body"},
+        {"body": b"", "more_body": False, "type": "http.response.body"},
+        {"body": b"Hello", "more_body": True, "type": "http.response.body"},
+        {"body": b"", "more_body": False, "type": "http.response.body"},
+        "Background task started",
+        "Background task started",
+        "Background task finished",
+        "Background task finished",
+    ]
+
+
 @pytest.mark.anyio
 async def test_run_context_manager_exit_even_if_client_disconnects():
     # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042