git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Read upload files using `read(CHUNK_SIZE)` rather than `iter()`. (#1948)
author: Tom Christie <tom@tomchristie.com>
Mon, 22 Nov 2021 13:15:39 +0000 (13:15 +0000)
committer: GitHub <noreply@github.com>
Mon, 22 Nov 2021 13:15:39 +0000 (13:15 +0000)
* Cap upload chunk sizes

* Use '.read' for file streaming, where possible

* Direct iteration should not apply chunk sizes

httpx/_content.py
httpx/_multipart.py
tests/test_content.py

index d7e8aa097493af731b5bdea56e2bcfdedc0ad388..9032c1c056c89f2aa9a353b5839fb4b320484348 100644 (file)
@@ -38,6 +38,8 @@ class ByteStream(AsyncByteStream, SyncByteStream):
 
 
 class IteratorByteStream(SyncByteStream):
+    CHUNK_SIZE = 65_536
+
     def __init__(self, stream: Iterable[bytes]):
         self._stream = stream
         self._is_stream_consumed = False
@@ -48,11 +50,21 @@ class IteratorByteStream(SyncByteStream):
             raise StreamConsumed()
 
         self._is_stream_consumed = True
-        for part in self._stream:
-            yield part
+        if hasattr(self._stream, "read"):
+            # File-like interfaces should use 'read' directly.
+            chunk = self._stream.read(self.CHUNK_SIZE)  # type: ignore
+            while chunk:
+                yield chunk
+                chunk = self._stream.read(self.CHUNK_SIZE)  # type: ignore
+        else:
+            # Otherwise iterate.
+            for part in self._stream:
+                yield part
 
 
 class AsyncIteratorByteStream(AsyncByteStream):
+    CHUNK_SIZE = 65_536
+
     def __init__(self, stream: AsyncIterable[bytes]):
         self._stream = stream
         self._is_stream_consumed = False
@@ -63,8 +75,16 @@ class AsyncIteratorByteStream(AsyncByteStream):
             raise StreamConsumed()
 
         self._is_stream_consumed = True
-        async for part in self._stream:
-            yield part
+        if hasattr(self._stream, "aread"):
+            # File-like interfaces should use 'aread' directly.
+            chunk = await self._stream.aread(self.CHUNK_SIZE)  # type: ignore
+            while chunk:
+                yield chunk
+                chunk = await self._stream.aread(self.CHUNK_SIZE)  # type: ignore
+        else:
+            # Otherwise iterate.
+            async for part in self._stream:
+                yield part
 
 
 class UnattachedStream(AsyncByteStream, SyncByteStream):
index 4dfb838a68109ab296f32d85ce812b6273675f68..f05f02b4c5ab24c0643f710d583a741b0cac44d9 100644 (file)
@@ -71,6 +71,8 @@ class FileField:
     A single file field item, within a multipart form field.
     """
 
+    CHUNK_SIZE = 64 * 1024
+
     def __init__(self, name: str, value: FileTypes) -> None:
         self.name = name
 
@@ -142,8 +144,10 @@ class FileField:
             self.file.seek(0)
         self._consumed = True
 
-        for chunk in self.file:
+        chunk = self.file.read(self.CHUNK_SIZE)
+        while chunk:
             yield to_bytes(chunk)
+            chunk = self.file.read(self.CHUNK_SIZE)
 
     def render(self) -> typing.Iterator[bytes]:
         yield self.render_headers()
index 2a273a133f79ea6eb7c1b063438fd37b08cb3a36..afd910399ededd14babcf4084201ed324fd95c96 100644 (file)
@@ -60,6 +60,31 @@ async def test_bytesio_content():
     assert content == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_async_bytesio_content():
+    class AsyncBytesIO:
+        def __init__(self, content):
+            self._idx = 0
+            self._content = content
+
+        async def aread(self, chunk_size: int):
+            chunk = self._content[self._idx : self._idx + chunk_size]
+            self._idx = self._idx + chunk_size
+            return chunk
+
+        async def __aiter__(self):
+            yield self._content  # pragma: nocover
+
+    headers, stream = encode_request(content=AsyncBytesIO(b"Hello, world!"))
+    assert not isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
+    content = b"".join([part async for part in stream])
+
+    assert headers == {"Transfer-Encoding": "chunked"}
+    assert content == b"Hello, world!"
+
+
 @pytest.mark.asyncio
 async def test_iterator_content():
     def hello_world():