]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse (#2620)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Sun, 1 Sep 2024 07:04:50 +0000 (02:04 -0500)
committerGitHub <noreply@github.com>
Sun, 1 Sep 2024 07:04:50 +0000 (09:04 +0200)
* Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse

Fixes #2516

* add test

* fmt

* Update tests/middleware/test_base.py

Co-authored-by: Mikkel Duif <mikkel@duifs.dk>
* add test for line now missing coverage

* more coverage, fix test

* add comment

* fmt

* tweak test

* fix

* fix coverage

* relint

---------

Co-authored-by: Mikkel Duif <mikkel@duifs.dk>
starlette/middleware/base.py
tests/middleware/test_base.py

index 4e5054d7a294223754b0c1f054d7f56d9248380e..87c0f51f8ba7cf5704006361c1c9e79b2fba6083 100644 (file)
@@ -6,9 +6,8 @@ import anyio
 from anyio.abc import ObjectReceiveStream, ObjectSendStream
 
 from starlette._utils import collapse_excgroups
-from starlette.background import BackgroundTask
 from starlette.requests import ClientDisconnect, Request
-from starlette.responses import ContentStream, Response, StreamingResponse
+from starlette.responses import AsyncContentStream, Response
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
@@ -56,6 +55,7 @@ class _CachedRequest(Request):
                 # at this point a disconnect is all that we should be receiving
                 # if we get something else, things went wrong somewhere
                 raise RuntimeError(f"Unexpected message received: {msg['type']}")
+            self._wrapped_rcv_disconnected = True
             return msg
 
         # wrapped_rcv state 3: not yet consumed
@@ -198,20 +198,33 @@ class BaseHTTPMiddleware:
         raise NotImplementedError()  # pragma: no cover
 
 
-class _StreamingResponse(StreamingResponse):
+class _StreamingResponse(Response):
     def __init__(
         self,
-        content: ContentStream,
+        content: AsyncContentStream,
         status_code: int = 200,
         headers: typing.Mapping[str, str] | None = None,
         media_type: str | None = None,
-        background: BackgroundTask | None = None,
         info: typing.Mapping[str, typing.Any] | None = None,
     ) -> None:
-        self._info = info
-        super().__init__(content, status_code, headers, media_type, background)
+        self.info = info
+        self.body_iterator = content
+        self.status_code = status_code
+        self.media_type = media_type
+        self.init_headers(headers)
 
-    async def stream_response(self, send: Send) -> None:
-        if self._info:
-            await send({"type": "http.response.debug", "info": self._info})
-        return await super().stream_response(send)
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if self.info is not None:
+            await send({"type": "http.response.debug", "info": self.info})
+        await send(
+            {
+                "type": "http.response.start",
+                "status": self.status_code,
+                "headers": self.raw_headers,
+            }
+        )
+
+        async for chunk in self.body_iterator:
+            await send({"type": "http.response.body", "body": chunk, "more_body": True})
+
+        await send({"type": "http.response.body", "body": b"", "more_body": False})
index 3ad1751a20a425abd07e2593a9dc12d4d6ebc65f..8e410cb1515da7e21c292a259b6f3bf8966058cf 100644 (file)
@@ -5,6 +5,7 @@ from contextlib import AsyncExitStack
 from typing import (
     Any,
     AsyncGenerator,
+    AsyncIterator,
     Generator,
 )
 
@@ -16,7 +17,7 @@ from starlette.applications import Starlette
 from starlette.background import BackgroundTask
 from starlette.middleware import Middleware, _MiddlewareClass
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
-from starlette.requests import Request
+from starlette.requests import ClientDisconnect, Request
 from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
 from starlette.testclient import TestClient
@@ -260,7 +261,6 @@ def test_contextvars(
 @pytest.mark.anyio
 async def test_run_background_tasks_even_if_client_disconnects() -> None:
     # test for https://github.com/encode/starlette/issues/1438
-    request_body_sent = False
     response_complete = anyio.Event()
     background_task_run = anyio.Event()
 
@@ -293,13 +293,7 @@ async def test_run_background_tasks_even_if_client_disconnects() -> None:
     }
 
     async def receive() -> Message:
-        nonlocal request_body_sent
-        if not request_body_sent:
-            request_body_sent = True
-            return {"type": "http.request", "body": b"", "more_body": False}
-        # We simulate a client that disconnects immediately after receiving the response
-        await response_complete.wait()
-        return {"type": "http.disconnect"}
+        raise NotImplementedError("Should not be called!")  # pragma: no cover
 
     async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
@@ -313,7 +307,6 @@ async def test_run_background_tasks_even_if_client_disconnects() -> None:
 
 @pytest.mark.anyio
 async def test_do_not_block_on_background_tasks() -> None:
-    request_body_sent = False
     response_complete = anyio.Event()
     events: list[str | Message] = []
 
@@ -345,12 +338,7 @@ async def test_do_not_block_on_background_tasks() -> None:
     }
 
     async def receive() -> Message:
-        nonlocal request_body_sent
-        if not request_body_sent:
-            request_body_sent = True
-            return {"type": "http.request", "body": b"", "more_body": False}
-        await response_complete.wait()
-        return {"type": "http.disconnect"}
+        raise NotImplementedError("Should not be called!")  # pragma: no cover
 
     async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
@@ -379,7 +367,6 @@ async def test_do_not_block_on_background_tasks() -> None:
 @pytest.mark.anyio
 async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
     # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
-    request_body_sent = False
     response_complete = anyio.Event()
     context_manager_exited = anyio.Event()
 
@@ -424,13 +411,7 @@ async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
     }
 
     async def receive() -> Message:
-        nonlocal request_body_sent
-        if not request_body_sent:
-            request_body_sent = True
-            return {"type": "http.request", "body": b"", "more_body": False}
-        # We simulate a client that disconnects immediately after receiving the response
-        await response_complete.wait()
-        return {"type": "http.disconnect"}
+        raise NotImplementedError("Should not be called!")  # pragma: no cover
 
     async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
@@ -778,7 +759,9 @@ async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None
         yield {"type": "http.request", "body": b"1", "more_body": True}
         yield {"type": "http.request", "body": b"2", "more_body": True}
         yield {"type": "http.request", "body": b"3"}
-        await anyio.sleep(float("inf"))
+        raise AssertionError(  # pragma: no cover
+            "Should not be called, no need to poll for disconnect"
+        )
 
     sent: list[Message] = []
 
@@ -1033,3 +1016,139 @@ def test_pr_1519_comment_1236166180_example() -> None:
     resp.raise_for_status()
 
     assert bodies == [b"Hello, World!-foo"]
+
+
+@pytest.mark.anyio
+async def test_multiple_middlewares_stacked_client_disconnected() -> None:
+    class MyMiddleware(BaseHTTPMiddleware):
+        def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
+            self.version = version
+            self.events = events
+            super().__init__(app)
+
+        async def dispatch(
+            self, request: Request, call_next: RequestResponseEndpoint
+        ) -> Response:
+            self.events.append(f"{self.version}:STARTED")
+            res = await call_next(request)
+            self.events.append(f"{self.version}:COMPLETED")
+            return res
+
+    async def sleepy(request: Request) -> Response:
+        try:
+            await request.body()
+        except ClientDisconnect:
+            pass
+        else:  # pragma: no cover
+            raise AssertionError("Should have raised ClientDisconnect")
+        return Response(b"")
+
+    events: list[str] = []
+
+    app = Starlette(
+        routes=[Route("/", sleepy)],
+        middleware=[
+            Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
+        ],
+    )
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+    }
+
+    async def receive() -> AsyncIterator[Message]:
+        yield {"type": "http.disconnect"}
+
+    sent: list[Message] = []
+
+    async def send(message: Message) -> None:
+        sent.append(message)
+
+    await app(scope, receive().__anext__, send)
+
+    assert events == [
+        "1:STARTED",
+        "2:STARTED",
+        "3:STARTED",
+        "4:STARTED",
+        "5:STARTED",
+        "6:STARTED",
+        "7:STARTED",
+        "8:STARTED",
+        "9:STARTED",
+        "10:STARTED",
+        "10:COMPLETED",
+        "9:COMPLETED",
+        "8:COMPLETED",
+        "7:COMPLETED",
+        "6:COMPLETED",
+        "5:COMPLETED",
+        "4:COMPLETED",
+        "3:COMPLETED",
+        "2:COMPLETED",
+        "1:COMPLETED",
+    ]
+
+    assert sent == [
+        {
+            "type": "http.response.start",
+            "status": 200,
+            "headers": [(b"content-length", b"0")],
+        },
+        {"type": "http.response.body", "body": b"", "more_body": False},
+    ]
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize("send_body", [True, False])
+async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
+    async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
+        for _ in range(2):
+            msg = await receive()
+            while msg["type"] == "http.request":
+                msg = await receive()
+            assert msg["type"] == "http.disconnect"
+        await Response(b"good!")(scope, receive, send)
+
+    class MyMiddleware(BaseHTTPMiddleware):
+        async def dispatch(
+            self, request: Request, call_next: RequestResponseEndpoint
+        ) -> Response:
+            return await call_next(request)
+
+    app = MyMiddleware(app_poll_disconnect)
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+    }
+
+    async def receive() -> AsyncIterator[Message]:
+        # the key here is that we only ever send 1 htt.disconnect message
+        if send_body:
+            yield {"type": "http.request", "body": b"hello", "more_body": True}
+            yield {"type": "http.request", "body": b"", "more_body": False}
+        yield {"type": "http.disconnect"}
+        raise AssertionError("Should not be called, would hang")  # pragma: no cover
+
+    sent: list[Message] = []
+
+    async def send(message: Message) -> None:
+        sent.append(message)
+
+    await app(scope, receive().__anext__, send)
+
+    assert sent == [
+        {
+            "type": "http.response.start",
+            "status": 200,
+            "headers": [(b"content-length", b"5")],
+        },
+        {"type": "http.response.body", "body": b"good!", "more_body": True},
+        {"type": "http.response.body", "body": b"", "more_body": False},
+    ]