]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Support for `send` client disconnect to HTTP (#2732)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Mon, 18 Nov 2024 20:56:13 +0000 (21:56 +0100)
committerGitHub <noreply@github.com>
Mon, 18 Nov 2024 20:56:13 +0000 (21:56 +0100)
starlette/responses.py
tests/test_responses.py

index 790aa7ebc1cf72b6e0e480c6b7c7530aa0164e60..fd156f84120cd582be7189f0e017250775f5c22e 100644 (file)
@@ -21,6 +21,7 @@ from starlette._compat import md5_hexdigest
 from starlette.background import BackgroundTask
 from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import URL, Headers, MutableHeaders
+from starlette.requests import ClientDisconnect
 from starlette.types import Receive, Scope, Send
 
 
@@ -249,14 +250,22 @@ class StreamingResponse(Response):
         await send({"type": "http.response.body", "body": b"", "more_body": False})
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        async with anyio.create_task_group() as task_group:
+        spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
 
-            async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
-                await func()
-                task_group.cancel_scope.cancel()
+        if spec_version >= (2, 4):
+            try:
+                await self.stream_response(send)
+            except OSError:
+                raise ClientDisconnect()
+        else:
+            async with anyio.create_task_group() as task_group:
+
+                async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
+                    await func()
+                    task_group.cancel_scope.cancel()
 
-            task_group.start_soon(wrap, partial(self.stream_response, send))
-            await wrap(partial(self.listen_for_disconnect, receive))
+                task_group.start_soon(wrap, partial(self.stream_response, send))
+                await wrap(partial(self.listen_for_disconnect, receive))
 
         if self.background is not None:
             await self.background()
index be0701a5cfc5d4a55f8ab45b4c9058e276b4af19..298bc3b0bbd6708dc1de63864b35ba7fe7ee08cd 100644 (file)
@@ -4,7 +4,7 @@ import datetime as dt
 import time
 from http.cookies import SimpleCookie
 from pathlib import Path
-from typing import Any, AsyncIterator, Iterator
+from typing import Any, AsyncGenerator, AsyncIterator, Iterator
 
 import anyio
 import pytest
@@ -12,7 +12,7 @@ import pytest
 from starlette import status
 from starlette.background import BackgroundTask
 from starlette.datastructures import Headers
-from starlette.requests import Request
+from starlette.requests import ClientDisconnect, Request
 from starlette.responses import FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
 from starlette.testclient import TestClient
 from starlette.types import Message, Receive, Scope, Send
@@ -542,6 +542,39 @@ async def test_streaming_response_stops_if_receiving_http_disconnect() -> None:
     assert not cancel_scope.cancel_called, "Content streaming should stop itself."
 
 
+@pytest.mark.anyio
+async def test_streaming_response_on_client_disconnects() -> None:
+    chunks = bytearray()
+    streamed = False
+
+    async def receive_disconnect() -> Message:
+        raise NotImplementedError
+
+    async def send(message: Message) -> None:
+        nonlocal streamed
+        if message["type"] == "http.response.body":
+            if not streamed:
+                chunks.extend(message.get("body", b""))
+                streamed = True
+            else:
+                raise OSError
+
+    async def stream_indefinitely() -> AsyncGenerator[bytes, None]:
+        while True:
+            await anyio.sleep(0)
+            yield b"chunk"
+
+    stream = stream_indefinitely()
+    response = StreamingResponse(content=stream)
+
+    with anyio.move_on_after(1) as cancel_scope:
+        with pytest.raises(ClientDisconnect):
+            await response({"asgi": {"spec_version": "2.4"}}, receive_disconnect, send)
+    assert not cancel_scope.cancel_called, "Content streaming should stop itself."
+    assert chunks == b"chunk"
+    await stream.aclose()
+
+
 README = """\
 # BáiZé