From: Michael Honaker <37811263+HonakerM@users.noreply.github.com> Date: Sun, 20 Jul 2025 17:24:02 +0000 (+0900) Subject: Make UploadFile check for future rollover (#2962) X-Git-Tag: 0.47.2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9f7ec2eb512fcc3fe90b43cb9dd9e1d08696bec1;p=thirdparty%2Fstarlette.git Make UploadFile check for future rollover (#2962) Co-authored-by: Marcelo Trylesinski --- diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 70dacd02..38eabec5 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -428,6 +428,10 @@ class UploadFile: self.size = size self.headers = headers or Headers() + # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks. + # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__ + self._max_mem_size = getattr(self.file, "_max_size", 0) + @property def content_type(self) -> str | None: return self.headers.get("content-type", None) @@ -438,14 +442,24 @@ class UploadFile: rolled_to_disk = getattr(self.file, "_rolled", True) return not rolled_to_disk + def _will_roll(self, size_to_add: int) -> bool: + # If we're not in_memory then we will always roll + if not self._in_memory: + return True + + # Check for SpooledTemporaryFile._max_size + future_size = self.file.tell() + size_to_add + return bool(future_size > self._max_mem_size) if self._max_mem_size else False + async def write(self, data: bytes) -> None: + new_data_len = len(data) if self.size is not None: - self.size += len(data) + self.size += new_data_len - if self._in_memory: - self.file.write(data) - else: + if self._will_roll(new_data_len): await run_in_threadpool(self.file.write, data) + else: + self.file.write(data) async def read(self, size: int = -1) -> bytes: if self._in_memory: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 35681e78..58f6a0c7 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,15 +1,20 @@ from __future__ import annotations import os +import threading +from collections.abc import Generator from contextlib import AbstractContextManager, nullcontext as does_not_raise +from io import BytesIO from pathlib import Path -from typing import Any +from tempfile import SpooledTemporaryFile +from typing import Any, ClassVar +from unittest import mock import pytest from starlette.applications import Starlette from starlette.datastructures import UploadFile -from starlette.formparsers import MultiPartException, _user_safe_decode +from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount @@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) +async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None: + """Helper app to monitor what thread the app was called on. + + This can later be used to validate thread/event loop operations. + """ + request = Request(scope, receive) + + # Make sure we parse the form + await request.form() + await request.close() + + # Send back the current thread id + response = JSONResponse({"thread_ident": threading.current_thread().ident}) + await response(scope, receive, send) + + def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) @@ -303,6 +324,47 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor } +class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]): + """Helper class to track which threads performed the rollover operation. + + This is not threadsafe/multi-test safe. + """ + + rollover_threads: ClassVar[set[int | None]] = set() + + def rollover(self) -> None: + ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident) + super().rollover() + + +@pytest.fixture +def mock_spooled_temporary_file() -> Generator[None]: + try: + with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile): + yield + finally: + ThreadTrackingSpooledTemporaryFile.rollover_threads.clear() + + +def test_multipart_request_large_file_rollover_in_background_thread( + mock_spooled_temporary_file: None, test_client_factory: TestClientFactory +) -> None: + """Test that Spooled file rollovers happen in background threads.""" + data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1)) + + client = test_client_factory(app_monitor_thread) + response = client.post("/", files=[("test_large", data)]) + assert response.status_code == 200 + + # Parse the event thread id from the API response and ensure we have one + app_thread_ident = response.json().get("thread_ident") + assert app_thread_ident is not None + + # Ensure the app thread was not the same as the rollover one and that a rollover thread exists + assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads + assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1 + + def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post(