From: Tom Christie
Date: Mon, 22 Nov 2021 13:15:39 +0000 (+0000)
Subject: Read upload files using `read(CHUNK_SIZE)` rather than `iter()`. (#1948)
X-Git-Tag: 0.21.2~14
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6f5865f8605b2cf15872388c3d9bc3623bbf265e;p=thirdparty%2Fhttpx.git

Read upload files using `read(CHUNK_SIZE)` rather than `iter()`. (#1948)

* Cap upload chunk sizes

* Use '.read' for file streaming, where possible

* Direct iteration should not apply chunk sizes
---
diff --git a/httpx/_content.py b/httpx/_content.py
index d7e8aa09..9032c1c0 100644
--- a/httpx/_content.py
+++ b/httpx/_content.py
@@ -38,6 +38,8 @@ class ByteStream(AsyncByteStream, SyncByteStream):
 
 
 class IteratorByteStream(SyncByteStream):
+    CHUNK_SIZE = 65_536
+
     def __init__(self, stream: Iterable[bytes]):
         self._stream = stream
         self._is_stream_consumed = False
@@ -48,11 +50,21 @@
             raise StreamConsumed()
         self._is_stream_consumed = True
 
-        for part in self._stream:
-            yield part
+        if hasattr(self._stream, "read"):
+            # File-like interfaces should use 'read' directly.
+            chunk = self._stream.read(self.CHUNK_SIZE)  # type: ignore
+            while chunk:
+                yield chunk
+                chunk = self._stream.read(self.CHUNK_SIZE)  # type: ignore
+        else:
+            # Otherwise iterate.
+            for part in self._stream:
+                yield part
 
 
 class AsyncIteratorByteStream(AsyncByteStream):
+    CHUNK_SIZE = 65_536
+
     def __init__(self, stream: AsyncIterable[bytes]):
         self._stream = stream
         self._is_stream_consumed = False
@@ -63,8 +75,16 @@
             raise StreamConsumed()
         self._is_stream_consumed = True
 
-        async for part in self._stream:
-            yield part
+        if hasattr(self._stream, "aread"):
+            # File-like interfaces should use 'aread' directly.
+            chunk = await self._stream.aread(self.CHUNK_SIZE)  # type: ignore
+            while chunk:
+                yield chunk
+                chunk = await self._stream.aread(self.CHUNK_SIZE)  # type: ignore
+        else:
+            # Otherwise iterate.
+            async for part in self._stream:
+                yield part
 
 
 class UnattachedStream(AsyncByteStream, SyncByteStream):
diff --git a/httpx/_multipart.py b/httpx/_multipart.py
index 4dfb838a..f05f02b4 100644
--- a/httpx/_multipart.py
+++ b/httpx/_multipart.py
@@ -71,6 +71,8 @@ class FileField:
     A single file field item, within a multipart form field.
     """
 
+    CHUNK_SIZE = 64 * 1024
+
     def __init__(self, name: str, value: FileTypes) -> None:
         self.name = name
 
@@ -142,8 +144,10 @@
             self.file.seek(0)
         self._consumed = True
 
-        for chunk in self.file:
+        chunk = self.file.read(self.CHUNK_SIZE)
+        while chunk:
             yield to_bytes(chunk)
+            chunk = self.file.read(self.CHUNK_SIZE)
 
     def render(self) -> typing.Iterator[bytes]:
         yield self.render_headers()
diff --git a/tests/test_content.py b/tests/test_content.py
index 2a273a13..afd91039 100644
--- a/tests/test_content.py
+++ b/tests/test_content.py
@@ -60,6 +60,31 @@
     assert content == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_async_bytesio_content():
+    class AsyncBytesIO:
+        def __init__(self, content):
+            self._idx = 0
+            self._content = content
+
+        async def aread(self, chunk_size: int):
+            chunk = self._content[self._idx : self._idx + chunk_size]
+            self._idx = self._idx + chunk_size
+            return chunk
+
+        async def __aiter__(self):
+            yield self._content  # pragma: nocover
+
+    headers, stream = encode_request(content=AsyncBytesIO(b"Hello, world!"))
+    assert not isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
+    content = b"".join([part async for part in stream])
+
+    assert headers == {"Transfer-Encoding": "chunked"}
+    assert content == b"Hello, world!"
+
+
 @pytest.mark.asyncio
 async def test_iterator_content():
     def hello_world():