]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add support for ASGI `pathsend` extension (#2671)
authorGiovanni Barillari <giovanni.barillari@sentry.io>
Thu, 29 May 2025 13:17:44 +0000 (15:17 +0200)
committerGitHub <noreply@github.com>
Thu, 29 May 2025 13:17:44 +0000 (13:17 +0000)
starlette/middleware/base.py
starlette/middleware/gzip.py
starlette/responses.py
tests/middleware/test_base.py
tests/middleware/test_gzip.py
tests/test_responses.py

index 4d139c25b73322d17eabe16076668bdb1f7ebfb9..577918eb91b18183018b5e48078257e4ae695502 100644 (file)
@@ -1,17 +1,19 @@
 from __future__ import annotations
 
-from collections.abc import AsyncGenerator, Awaitable, Mapping
-from typing import Any, Callable, TypeVar
+from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Mapping, MutableMapping
+from typing import Any, Callable, TypeVar, Union
 
 import anyio
 
 from starlette._utils import collapse_excgroups
 from starlette.requests import ClientDisconnect, Request
-from starlette.responses import AsyncContentStream, Response
+from starlette.responses import Response
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
 DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
+BodyStreamGenerator = AsyncGenerator[Union[bytes, MutableMapping[str, Any]], None]
+AsyncContentStream = AsyncIterable[Union[str, bytes, memoryview, MutableMapping[str, Any]]]
 T = TypeVar("T")
 
 
@@ -159,9 +161,12 @@ class BaseHTTPMiddleware:
 
             assert message["type"] == "http.response.start"
 
-            async def body_stream() -> AsyncGenerator[bytes, None]:
+            async def body_stream() -> BodyStreamGenerator:
                 async for message in recv_stream:
-                    assert message["type"] == "http.response.body"
+                    if message["type"] == "http.response.pathsend":
+                        yield message
+                        break
+                    assert message["type"] == "http.response.body", f"Unexpected message: {message}"
                     body = message.get("body", b"")
                     if body:
                         yield body
@@ -214,10 +219,17 @@ class _StreamingResponse(Response):
             }
         )
 
+        should_close_body = True
         async for chunk in self.body_iterator:
+            if isinstance(chunk, dict):
+                # We got an ASGI message which is not response body (eg: pathsend)
+                should_close_body = False
+                await send(chunk)
+                continue
             await send({"type": "http.response.body", "body": chunk, "more_body": True})
 
-        await send({"type": "http.response.body", "body": b"", "more_body": False})
+        if should_close_body:
+            await send({"type": "http.response.body", "body": b"", "more_body": False})
 
         if self.background:
             await self.background()
index 502f055264a582ce3b7155cb346e5168c564ec9e..abd898b2daffb4c1cc2b47b5931060ce34f2abfd 100644 (file)
@@ -93,7 +93,7 @@ class IdentityResponder:
 
                 await self.send(self.initial_message)
                 await self.send(message)
-        elif message_type == "http.response.body":  # pragma: no branch
+        elif message_type == "http.response.body":
             # Remaining body in streaming response.
             body = message.get("body", b"")
             more_body = message.get("more_body", False)
@@ -101,6 +101,10 @@ class IdentityResponder:
             message["body"] = self.apply_compression(body, more_body=more_body)
 
             await self.send(message)
+        elif message_type == "http.response.pathsend":  # pragma: no branch
+            # Don't apply GZip to pathsend responses
+            await self.send(self.initial_message)
+            await self.send(message)
 
     def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
         """Apply compression on the response body.
index b9d334eaa45a7297769e54a99216a4b9ed24a94d..031633b158c4c5ad5462cced7424dd92055c241f 100644 (file)
@@ -346,6 +346,8 @@ class FileResponse(Response):
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         send_header_only: bool = scope["method"].upper() == "HEAD"
+        send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})
+
         if self.stat_result is None:
             try:
                 stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
@@ -364,7 +366,7 @@ class FileResponse(Response):
         http_if_range = headers.get("if-range")
 
         if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
-            await self._handle_simple(send, send_header_only)
+            await self._handle_simple(send, send_header_only, send_pathsend)
         else:
             try:
                 ranges = self._parse_range_header(http_range, stat_result.st_size)
@@ -383,10 +385,12 @@ class FileResponse(Response):
         if self.background is not None:
             await self.background()
 
-    async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
+    async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
         await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
         if send_header_only:
             await send({"type": "http.response.body", "body": b"", "more_body": False})
+        elif send_pathsend:
+            await send({"type": "http.response.pathsend", "path": str(self.path)})
         else:
             async with await anyio.open_file(self.path, mode="rb") as file:
                 more_body = True
index 427ec44ac1f66ae5fd31dc40698febac60b9a96c..d00bd6ab5f01f8fe3df92b39fa0f946633600b5b 100644 (file)
@@ -3,6 +3,7 @@ from __future__ import annotations
 import contextvars
 from collections.abc import AsyncGenerator, AsyncIterator, Generator
 from contextlib import AsyncExitStack
+from pathlib import Path
 from typing import Any
 
 import anyio
@@ -14,7 +15,7 @@ from starlette.background import BackgroundTask
 from starlette.middleware import Middleware, _MiddlewareFactory
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.requests import ClientDisconnect, Request
-from starlette.responses import PlainTextResponse, Response, StreamingResponse
+from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
 from starlette.testclient import TestClient
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
@@ -1198,3 +1199,47 @@ async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
         {"type": "http.response.body", "body": b"good!", "more_body": True},
         {"type": "http.response.body", "body": b"", "more_body": False},
     ]
+
+
+@pytest.mark.anyio
+async def test_asgi_pathsend_events(tmpdir: Path) -> None:
+    path = tmpdir / "example.txt"
+    with path.open("w") as file:
+        file.write("<file content>")
+
+    response_complete = anyio.Event()
+    events: list[Message] = []
+
+    async def endpoint_with_pathsend(_: Request) -> FileResponse:
+        return FileResponse(path)
+
+    async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
+        return await call_next(request)
+
+    app = Starlette(
+        middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
+        routes=[Route("/", endpoint_with_pathsend)],
+    )
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+        "headers": [],
+        "extensions": {"http.response.pathsend": {}},
+    }
+
+    async def receive() -> Message:
+        raise NotImplementedError("Should not be called!")  # pragma: no cover
+
+    async def send(message: Message) -> None:
+        events.append(message)
+        if message["type"] == "http.response.pathsend":
+            response_complete.set()
+
+    await app(scope, receive, send)
+
+    assert len(events) == 2
+    assert events[0]["type"] == "http.response.start"
+    assert events[1]["type"] == "http.response.pathsend"
index 48ded6ae8a126d9784af922bc483e2f07fc3cef4..3a2f2e0f3cddd017dc1e6cf9de46ef6aab48d765 100644 (file)
@@ -1,11 +1,16 @@
 from __future__ import annotations
 
+from pathlib import Path
+
+import pytest
+
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.gzip import GZipMiddleware
 from starlette.requests import Request
-from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse
+from starlette.responses import ContentStream, FileResponse, PlainTextResponse, StreamingResponse
 from starlette.routing import Route
+from starlette.types import Message
 from tests.types import TestClientFactory
 
 
@@ -156,3 +161,42 @@ def test_gzip_ignored_on_server_sent_events(test_client_factory: TestClientFacto
     assert response.text == "x" * 4000
     assert "Content-Encoding" not in response.headers
     assert "Content-Length" not in response.headers
+
+
+@pytest.mark.anyio
+async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None:
+    path = tmpdir / "example.txt"
+    with path.open("w") as file:
+        file.write("<file content>")
+
+    events: list[Message] = []
+
+    async def endpoint_with_pathsend(request: Request) -> FileResponse:
+        _ = await request.body()
+        return FileResponse(path)
+
+    app = Starlette(
+        routes=[Route("/", endpoint=endpoint_with_pathsend)],
+        middleware=[Middleware(GZipMiddleware)],
+    )
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+        "headers": [(b"accept-encoding", b"gzip, text")],
+        "extensions": {"http.response.pathsend": {}},
+    }
+
+    async def receive() -> Message:
+        return {"type": "http.request", "body": b"", "more_body": False}
+
+    async def send(message: Message) -> None:
+        events.append(message)
+
+    await app(scope, receive, send)
+
+    assert len(events) == 2
+    assert events[0]["type"] == "http.response.start"
+    assert events[1]["type"] == "http.response.pathsend"
index baf7c98f9cef1189cf3647580e66a77cf3389373..561fcad86796d6616e246f1c6893358a15c8fa85 100644 (file)
@@ -354,6 +354,38 @@ def test_file_response_with_range_header(tmp_path: Path, test_client_factory: Te
     assert response.headers["content-range"] == f"bytes 0-4/{len(content)}"
 
 
+@pytest.mark.anyio
+async def test_file_response_with_pathsend(tmpdir: Path) -> None:
+    path = tmpdir / "xyz"
+    content = b"<file content>" * 1000
+    with open(path, "wb") as file:
+        file.write(content)
+
+    app = FileResponse(path=path, filename="example.png")
+
+    async def receive() -> Message:  # type: ignore[empty-body]
+        ...  # pragma: no cover
+
+    async def send(message: Message) -> None:
+        if message["type"] == "http.response.start":
+            assert message["status"] == status.HTTP_200_OK
+            headers = Headers(raw=message["headers"])
+            assert headers["content-type"] == "image/png"
+            assert "content-length" in headers
+            assert "content-disposition" in headers
+            assert "last-modified" in headers
+            assert "etag" in headers
+        elif message["type"] == "http.response.pathsend":  # pragma: no branch
+            assert message["path"] == str(path)
+
+    # Since the TestClient doesn't support `pathsend`, we need to test this directly.
+    await app(
+        {"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}},
+        receive,
+        send,
+    )
+
+
 def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
     # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
     mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)