]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Consider `FileResponse.chunk_size` when handling multiple ranges (#2703)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Wed, 25 Sep 2024 07:16:08 +0000 (09:16 +0200)
committerGitHub <noreply@github.com>
Wed, 25 Sep 2024 07:16:08 +0000 (09:16 +0200)
* Take in consideration the `FileResponse.chunk_size` on multiple ranges

* Update starlette/responses.py

* Update starlette/responses.py

* Update starlette/responses.py

Co-authored-by: Frost Ming <mianghong@gmail.com>
---------

Co-authored-by: Frost Ming <mianghong@gmail.com>
pyproject.toml
starlette/responses.py
tests/middleware/test_base.py
tests/test_responses.py

index 52156528e8cd86b0854a3dcb3eafd688e06f5321..b993c2a2d62094cd5b336c0d0b828eaf5cb83bae 100644 (file)
@@ -103,4 +103,5 @@ exclude_lines = [
     "pragma: nocover",
     "if typing.TYPE_CHECKING:",
     "@typing.overload",
+    "raise NotImplementedError",
 ]
index fc92cbab136230184d0a096c318c639741edad79..b951e51254d144bc2d457a3baad8cb3a9427f093 100644 (file)
@@ -374,13 +374,7 @@ class FileResponse(Response):
                 while more_body:
                     chunk = await file.read(self.chunk_size)
                     more_body = len(chunk) == self.chunk_size
-                    await send(
-                        {
-                            "type": "http.response.body",
-                            "body": chunk,
-                            "more_body": more_body,
-                        }
-                    )
+                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
 
     async def _handle_single_range(
         self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
@@ -419,10 +413,12 @@ class FileResponse(Response):
         else:
             async with await anyio.open_file(self.path, mode="rb") as file:
                 for start, end in ranges:
-                    await file.seek(start)
-                    chunk = await file.read(min(self.chunk_size, end - start))
                     await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
-                    await send({"type": "http.response.body", "body": chunk, "more_body": True})
+                    await file.seek(start)
+                    while start < end:
+                        chunk = await file.read(min(self.chunk_size, end - start))
+                        start += len(chunk)
+                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                     await send({"type": "http.response.body", "body": b"\n", "more_body": True})
                 await send(
                     {
index 15080e5c59675e14d4491db77edba64936d9d872..e2375a7b9f866de5f0b2427c32d7e775be76df46 100644 (file)
@@ -289,7 +289,7 @@ async def test_run_background_tasks_even_if_client_disconnects() -> None:
     }
 
     async def receive() -> Message:
-        raise NotImplementedError("Should not be called!")  # pragma: no cover
+        raise NotImplementedError("Should not be called!")
 
     async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
@@ -330,7 +330,7 @@ async def test_do_not_block_on_background_tasks() -> None:
     }
 
     async def receive() -> Message:
-        raise NotImplementedError("Should not be called!")  # pragma: no cover
+        raise NotImplementedError("Should not be called!")
 
     async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
@@ -403,7 +403,7 @@ async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
     }
 
     async def receive() -> Message:
-        raise NotImplementedError("Should not be called!")  # pragma: no cover
+        raise NotImplementedError("Should not be called!")
 
     async def send(message: Message) -> None:
         if message["type"] == "http.response.body":
index 359516c75d4c19530b595e981c56a55f6b48780e..645d26a682eaa26a47a77393398be2123c04cf20 100644 (file)
@@ -4,7 +4,7 @@ import datetime as dt
 import time
 from http.cookies import SimpleCookie
 from pathlib import Path
-from typing import AsyncIterator, Iterator
+from typing import Any, AsyncIterator, Iterator
 
 import anyio
 import pytest
@@ -682,3 +682,58 @@ def test_file_response_insert_ranges(file_response_client: TestClient) -> None:
         "",
         f"--{boundary}--",
     ]
+
+
+@pytest.mark.anyio
+async def test_file_response_multi_small_chunk_size(readme_file: Path) -> None:
+    class SmallChunkSizeFileResponse(FileResponse):
+        chunk_size = 10
+
+    app = SmallChunkSizeFileResponse(path=str(readme_file))
+
+    received_chunks: list[bytes] = []
+    start_message: dict[str, Any] = {}
+
+    async def receive() -> Message:
+        raise NotImplementedError("Should not be called!")
+
+    async def send(message: Message) -> None:
+        if message["type"] == "http.response.start":
+            start_message.update(message)
+        elif message["type"] == "http.response.body":
+            received_chunks.append(message["body"])
+
+    await app({"type": "http", "method": "get", "headers": [(b"range", b"bytes=0-15,20-35,35-50")]}, receive, send)
+    assert start_message["status"] == 206
+
+    headers = Headers(raw=start_message["headers"])
+    assert headers.get("content-type") == "text/plain; charset=utf-8"
+    assert headers.get("accept-ranges") == "bytes"
+    assert "content-length" in headers
+    assert "last-modified" in headers
+    assert "etag" in headers
+    assert headers["content-range"].startswith("multipart/byteranges; boundary=")
+    boundary = headers["content-range"].split("boundary=")[1]
+
+    assert received_chunks == [
+        # Send the part headers.
+        f"--{boundary}\nContent-Type: text/plain; charset=utf-8\nContent-Range: bytes 0-15/526\n\n".encode(),
+        # Send the first chunk (10 bytes).
+        b"# B\xc3\xa1iZ\xc3\xa9\n",
+        # Send the second chunk (6 bytes).
+        b"\nPower",
+        # Send the new line to separate the parts.
+        b"\n",
+        # Send the part headers. We merge the ranges 20-35 and 35-50 into a single part.
+        f"--{boundary}\nContent-Type: text/plain; charset=utf-8\nContent-Range: bytes 20-50/526\n\n".encode(),
+        # Send the first chunk (10 bytes).
+        b"and exquis",
+        # Send the second chunk (10 bytes).
+        b"ite WSGI/A",
+        # Send the third chunk (10 bytes).
+        b"SGI framew",
+        # Send the last chunk (1 byte).
+        b"o",
+        b"\n",
+        f"\n--{boundary}--\n".encode(),
+    ]