From: Marcelo Trylesinski Date: Mon, 18 Nov 2024 20:56:13 +0000 (+0100) Subject: Support for `send` client disconnect to HTTP (#2732) X-Git-Tag: 0.42.0~16 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7d586f7e39f3e83e8b86636cd41d75e9a879a791;p=thirdparty%2Fstarlette.git Support for `send` client disconnect to HTTP (#2732) --- diff --git a/starlette/responses.py b/starlette/responses.py index 790aa7eb..fd156f84 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -21,6 +21,7 @@ from starlette._compat import md5_hexdigest from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders +from starlette.requests import ClientDisconnect from starlette.types import Receive, Scope, Send @@ -249,14 +250,22 @@ class StreamingResponse(Response): await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - async with anyio.create_task_group() as task_group: + spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() + if spec_version >= (2, 4): + try: + await self.stream_response(send) + except OSError: + raise ClientDisconnect() + else: + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() diff --git a/tests/test_responses.py b/tests/test_responses.py index be0701a5..298bc3b0 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -4,7 +4,7 @@ import datetime as dt import time from http.cookies import SimpleCookie from pathlib import Path -from typing import Any, AsyncIterator, Iterator +from typing import Any, AsyncGenerator, AsyncIterator, Iterator import anyio import pytest @@ -12,7 +12,7 @@ import pytest from starlette import status from starlette.background import BackgroundTask from starlette.datastructures import Headers -from starlette.requests import Request +from starlette.requests import ClientDisconnect, Request from starlette.responses import FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send @@ -542,6 +542,39 @@ async def test_streaming_response_stops_if_receiving_http_disconnect() -> None: assert not cancel_scope.cancel_called, "Content streaming should stop itself." +@pytest.mark.anyio +async def test_streaming_response_on_client_disconnects() -> None: + chunks = bytearray() + streamed = False + + async def receive_disconnect() -> Message: + raise NotImplementedError + + async def send(message: Message) -> None: + nonlocal streamed + if message["type"] == "http.response.body": + if not streamed: + chunks.extend(message.get("body", b"")) + streamed = True + else: + raise OSError + + async def stream_indefinitely() -> AsyncGenerator[bytes, None]: + while True: + await anyio.sleep(0) + yield b"chunk" + + stream = stream_indefinitely() + response = StreamingResponse(content=stream) + + with anyio.move_on_after(1) as cancel_scope: + with pytest.raises(ClientDisconnect): + await response({"asgi": {"spec_version": "2.4"}}, receive_disconnect, send) + assert not cancel_scope.cancel_called, "Content streaming should stop itself." + assert chunks == b"chunk" + await stream.aclose() + + README = """\ # BáiZé