]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Document and type annotate UploadFile as a bytes-only interface. Not bytes or text...
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Mon, 10 Jan 2022 17:41:43 +0000 (09:41 -0800)
committerGitHub <noreply@github.com>
Mon, 10 Jan 2022 17:41:43 +0000 (09:41 -0800)
docs/requests.md
starlette/datastructures.py
tests/test_datastructures.py

index f4d867ab1a06506ed095830dec51eb5849bae701..87294663837f8b03987641ba6f2826594a7969a8 100644 (file)
@@ -126,8 +126,8 @@ multidict, containing both file uploads and text input. File upload items are re
 
 `UploadFile` has the following `async` methods. They all call the corresponding file methods underneath (using the internal `SpooledTemporaryFile`).
 
-* `async write(data)`: Writes `data` (`str` or `bytes`) to the file.
-* `async read(size)`: Reads `size` (`int`) bytes/characters of the file.
+* `async write(data)`: Writes `data` (`bytes`) to the file.
+* `async read(size)`: Reads `size` (`int`) bytes of the file.
 * `async seek(offset)`: Goes to the byte position `offset` (`int`) in the file.
     * E.g., `await myfile.seek(0)` would go to the start of the file.
 * `async close()`: Closes the file.
index 64f964a915912450c7a166e2c2f934e5c0e2d0da..52904d51e000af8f979521f9ab7989f70a518e8c 100644 (file)
@@ -415,12 +415,13 @@ class UploadFile:
     """
 
     spool_max_size = 1024 * 1024
+    file: typing.BinaryIO
     headers: "Headers"
 
     def __init__(
         self,
         filename: str,
-        file: typing.IO = None,
+        file: typing.Optional[typing.BinaryIO] = None,
         content_type: str = "",
         *,
         headers: "typing.Optional[Headers]" = None,
@@ -428,8 +429,9 @@ class UploadFile:
         self.filename = filename
         self.content_type = content_type
         if file is None:
-            file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
-        self.file = file
+            self.file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)  # type: ignore  # noqa: E501
+        else:
+            self.file = file
         self.headers = headers or Headers()
 
     @property
@@ -437,13 +439,13 @@ class UploadFile:
         rolled_to_disk = getattr(self.file, "_rolled", True)
         return not rolled_to_disk
 
-    async def write(self, data: typing.Union[bytes, str]) -> None:
+    async def write(self, data: bytes) -> None:
         if self._in_memory:
             self.file.write(data)  # type: ignore
         else:
             await run_in_threadpool(self.file.write, data)
 
-    async def read(self, size: int = -1) -> typing.Union[bytes, str]:
+    async def read(self, size: int = -1) -> bytes:
         if self._in_memory:
             return self.file.read(size)
         return await run_in_threadpool(self.file.read, size)
index bb71ba870cc48c2abf081618846e63c885f4528e..5d44bdc13ef6a9cd328246a40451b6b1b7d999b9 100644 (file)
@@ -227,6 +227,18 @@ async def test_upload_file():
     await big_file.close()
 
 
+@pytest.mark.anyio
+async def test_upload_file_file_input():
+    """Test passing file/stream into the UploadFile constructor"""
+    stream = io.BytesIO(b"data")
+    file = UploadFile(filename="file", file=stream)
+    assert await file.read() == b"data"
+    await file.write(b" and more data!")
+    assert await file.read() == b""
+    await file.seek(0)
+    assert await file.read() == b"data and more data!"
+
+
 def test_formdata():
     upload = io.BytesIO(b"test")
     form = FormData([("a", "123"), ("a", "456"), ("b", upload)])