]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Merge commit from fork
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Tue, 15 Oct 2024 06:40:51 +0000 (08:40 +0200)
committerGitHub <noreply@github.com>
Tue, 15 Oct 2024 06:40:51 +0000 (08:40 +0200)
starlette/formparsers.py
tests/test_formparsers.py

index 48b0f2aad66a2bf4f0018f22c59821eb3005d90a..0d7e5616275b0a20274dfc7c5a17232c7466c8bf 100644 (file)
@@ -31,12 +31,12 @@ class FormMessage(Enum):
 class MultipartPart:
     content_disposition: bytes | None = None
     field_name: str = ""
-    data: bytes = b""
+    data: bytearray = field(default_factory=bytearray)
     file: UploadFile | None = None
     item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
 
 
-def _user_safe_decode(src: bytes, codec: str) -> str:
+def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
     try:
         return src.decode(codec)
     except (UnicodeDecodeError, LookupError):
@@ -117,7 +117,8 @@ class FormParser:
 
 
 class MultiPartParser:
-    max_file_size = 1024 * 1024
+    max_file_size = 1024 * 1024  # 1MB
+    max_part_size = 1024 * 1024  # 1MB
 
     def __init__(
         self,
@@ -149,7 +150,9 @@ class MultiPartParser:
     def on_part_data(self, data: bytes, start: int, end: int) -> None:
         message_bytes = data[start:end]
         if self._current_part.file is None:
-            self._current_part.data += message_bytes
+            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
+                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
+            self._current_part.data.extend(message_bytes)
         else:
             self._file_parts_to_write.append((self._current_part, message_bytes))
 
index 61c1bede19f8af000453e43705ea11a05df2e7bf..64efdbdbcb3441d15bef19574ac78dc52f2451bd 100644 (file)
@@ -640,9 +640,7 @@ def test_max_files_is_customizable_low_raises(
         assert res.text == "Too many files. Maximum number of files is 1."
 
 
-def test_max_fields_is_customizable_high(
-    test_client_factory: TestClientFactory,
-) -> None:
+def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
     fields = []
     for i in range(2000):
@@ -664,3 +662,40 @@ def test_max_fields_is_customizable_high(
         "content": "",
         "content_type": None,
     }
+
+
+@pytest.mark.parametrize(
+    "app,expectation",
+    [
+        (app, pytest.raises(MultiPartException)),
+        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
+    ],
+)
+def test_max_part_size_exceeds_limit(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
+    client = test_client_factory(app)
+    boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"
+
+    multipart_data = (
+        f"--{boundary}\r\n"
+        f'Content-Disposition: form-data; name="small"\r\n\r\n'
+        "small content\r\n"
+        f"--{boundary}\r\n"
+        f'Content-Disposition: form-data; name="large"\r\n\r\n'
+        + ("x" * 1024 * 1024 + "x")  # 1MB + 1 byte of data
+        + "\r\n"
+        f"--{boundary}--\r\n"
+    ).encode("utf-8")
+
+    headers = {
+        "Content-Type": f"multipart/form-data; boundary={boundary}",
+        "Transfer-Encoding": "chunked",
+    }
+
+    with expectation:
+        response = client.post("/", data=multipart_data, headers=headers)  # type: ignore
+        assert response.status_code == 400
+        assert response.text == "Part exceeded maximum size of 1024KB."