]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Fix `BackgroundTasks` with `BaseHTTPMiddleware` (#2688)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Sat, 7 Sep 2024 12:42:54 +0000 (13:42 +0100)
committerGitHub <noreply@github.com>
Sat, 7 Sep 2024 12:42:54 +0000 (13:42 +0100)
* Streaming response early disconnect mode

* Fix BackgroundTasks with BaseHTTPMiddleware

* move comment

* initialize field

---------

Co-authored-by: Dmitry Maliuga <dmaliuga@fireworks.ai>
starlette/middleware/base.py
tests/middleware/test_base.py

index 2ac6f7f7f0afd4e5c2e7d1b5f226536262b433f6..f51b13f733a9739a0ebd7111b82096bf576a8b87 100644 (file)
@@ -206,6 +206,7 @@ class _StreamingResponse(Response):
         self.status_code = status_code
         self.media_type = media_type
         self.init_headers(headers)
+        self.background = None
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if self.info is not None:
@@ -222,3 +223,6 @@ class _StreamingResponse(Response):
             await send({"type": "http.response.body", "body": chunk, "more_body": True})
 
         await send({"type": "http.response.body", "body": b"", "more_body": False})
+
+        if self.background:
+            await self.background()
index 22503865088bafcb62776500f566fa1571fc0d78..15080e5c59675e14d4491db77edba64936d9d872 100644 (file)
@@ -1006,16 +1006,29 @@ def test_pr_1519_comment_1236166180_example() -> None:
 
 @pytest.mark.anyio
 async def test_multiple_middlewares_stacked_client_disconnected() -> None:
+    """
+    Tests for:
+    - https://github.com/encode/starlette/issues/2516
+    - https://github.com/encode/starlette/pull/2687
+    """
+    ordered_events: list[str] = []
+    unordered_events: list[str] = []
+
     class MyMiddleware(BaseHTTPMiddleware):
-        def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
+        def __init__(self, app: ASGIApp, version: int) -> None:
             self.version = version
-            self.events = events
             super().__init__(app)
 
         async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
-            self.events.append(f"{self.version}:STARTED")
+            ordered_events.append(f"{self.version}:STARTED")
             res = await call_next(request)
-            self.events.append(f"{self.version}:COMPLETED")
+            ordered_events.append(f"{self.version}:COMPLETED")
+
+            def background() -> None:
+                unordered_events.append(f"{self.version}:BACKGROUND")
+
+            assert res.background is None
+            res.background = BackgroundTask(background)
             return res
 
     async def sleepy(request: Request) -> Response:
@@ -1027,11 +1040,9 @@ async def test_multiple_middlewares_stacked_client_disconnected() -> None:
             raise AssertionError("Should have raised ClientDisconnect")
         return Response(b"")
 
-    events: list[str] = []
-
     app = Starlette(
         routes=[Route("/", sleepy)],
-        middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
+        middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)],
     )
 
     scope = {
@@ -1051,7 +1062,7 @@ async def test_multiple_middlewares_stacked_client_disconnected() -> None:
 
     await app(scope, receive().__anext__, send)
 
-    assert events == [
+    assert ordered_events == [
         "1:STARTED",
         "2:STARTED",
         "3:STARTED",
@@ -1074,6 +1085,21 @@ async def test_multiple_middlewares_stacked_client_disconnected() -> None:
         "1:COMPLETED",
     ]
 
+    assert sorted(unordered_events) == sorted(
+        [
+            "1:BACKGROUND",
+            "2:BACKGROUND",
+            "3:BACKGROUND",
+            "4:BACKGROUND",
+            "5:BACKGROUND",
+            "6:BACKGROUND",
+            "7:BACKGROUND",
+            "8:BACKGROUND",
+            "9:BACKGROUND",
+            "10:BACKGROUND",
+        ]
+    )
+
     assert sent == [
         {
             "type": "http.response.start",