]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Make UploadFile check for future rollover (#2962)
authorMichael Honaker <37811263+HonakerM@users.noreply.github.com>
Sun, 20 Jul 2025 17:24:02 +0000 (02:24 +0900)
committerGitHub <noreply@github.com>
Sun, 20 Jul 2025 17:24:02 +0000 (19:24 +0200)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/datastructures.py
tests/test_formparsers.py

index 70dacd024c2c0871982e34b0b4636fba8adc352f..38eabec52743cb04dde53b17fdc19fe074a12841 100644 (file)
@@ -428,6 +428,10 @@ class UploadFile:
         self.size = size
         self.headers = headers or Headers()
 
+        # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
+        # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
+        self._max_mem_size = getattr(self.file, "_max_size", 0)
+
     @property
     def content_type(self) -> str | None:
         return self.headers.get("content-type", None)
@@ -438,14 +442,24 @@ class UploadFile:
         rolled_to_disk = getattr(self.file, "_rolled", True)
         return not rolled_to_disk
 
+    def _will_roll(self, size_to_add: int) -> bool:
+        # If we're not in_memory then we will always roll
+        if not self._in_memory:
+            return True
+
+        # Check for SpooledTemporaryFile._max_size
+        future_size = self.file.tell() + size_to_add
+        return bool(future_size > self._max_mem_size) if self._max_mem_size else False
+
     async def write(self, data: bytes) -> None:
+        new_data_len = len(data)
         if self.size is not None:
-            self.size += len(data)
+            self.size += new_data_len
 
-        if self._in_memory:
-            self.file.write(data)
-        else:
+        if self._will_roll(new_data_len):
             await run_in_threadpool(self.file.write, data)
+        else:
+            self.file.write(data)
 
     async def read(self, size: int = -1) -> bytes:
         if self._in_memory:
index 35681e78d392f8802483a9d059790b92ac83deeb..58f6a0c738814e00a7e095c0ce9346773c1c5df6 100644 (file)
@@ -1,15 +1,20 @@
 from __future__ import annotations
 
 import os
+import threading
+from collections.abc import Generator
 from contextlib import AbstractContextManager, nullcontext as does_not_raise
+from io import BytesIO
 from pathlib import Path
-from typing import Any
+from tempfile import SpooledTemporaryFile
+from typing import Any, ClassVar
+from unittest import mock
 
 import pytest
 
 from starlette.applications import Starlette
 from starlette.datastructures import UploadFile
-from starlette.formparsers import MultiPartException, _user_safe_decode
+from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
@@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
     await response(scope, receive, send)
 
 
+async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
+    """Helper app to monitor what thread the app was called on.
+
+    This can later be used to validate thread/event loop operations.
+    """
+    request = Request(scope, receive)
+
+    # Make sure we parse the form
+    await request.form()
+    await request.close()
+
+    # Send back the current thread id
+    response = JSONResponse({"thread_ident": threading.current_thread().ident})
+    await response(scope, receive, send)
+
+
 def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
@@ -303,6 +324,47 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
     }
 
 
+class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
+    """Helper class to track which threads performed the rollover operation.
+
+    This is not threadsafe/multi-test safe.
+    """
+
+    rollover_threads: ClassVar[set[int | None]] = set()
+
+    def rollover(self) -> None:
+        ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
+        super().rollover()
+
+
+@pytest.fixture
+def mock_spooled_temporary_file() -> Generator[None]:
+    try:
+        with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
+            yield
+    finally:
+        ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()
+
+
+def test_multipart_request_large_file_rollover_in_background_thread(
+    mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
+) -> None:
+    """Test that Spooled file rollovers happen in background threads."""
+    data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))
+
+    client = test_client_factory(app_monitor_thread)
+    response = client.post("/", files=[("test_large", data)])
+    assert response.status_code == 200
+
+    # Parse the event thread id from the API response and ensure we have one
+    app_thread_ident = response.json().get("thread_ident")
+    assert app_thread_ident is not None
+
+    # Ensure the app thread was not the same as the rollover one and that a rollover thread exists
+    assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
+    assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1
+
+
 def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post(