]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Support for `chunk_size` (#1277)
authorTom Christie <tom@tomchristie.com>
Wed, 25 Nov 2020 15:28:06 +0000 (15:28 +0000)
committerGitHub <noreply@github.com>
Wed, 25 Nov 2020 15:28:06 +0000 (15:28 +0000)
* Support iter_raw(chunk_size=...) and aiter_raw(chunk_size=...)

* Unit tests for ByteChunker

* Support iter_bytes(chunk_size=...)

* Add TextChunker

* Support iter_text(chunk_size=...)

* Fix merge with master

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_decoders.py
httpx/_models.py
tests/models/test_responses.py
tests/test_decoders.py

index bac5f9c86f11ddd39eb17b1f8d00b05e4bf23583..8ef0157e6f5d14afdb0fc1a1c472bb405d41146c 100644 (file)
@@ -4,6 +4,7 @@ Handlers for Content-Encoding.
 See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
 """
 import codecs
+import io
 import typing
 import zlib
 
@@ -155,6 +156,84 @@ class MultiDecoder(ContentDecoder):
         return data
 
 
+class ByteChunker:
+    """
+    Handles returning byte content in fixed-size chunks.
+    """
+
+    def __init__(self, chunk_size: int = None) -> None:
+        self._buffer = io.BytesIO()
+        self._chunk_size = chunk_size
+
+    def decode(self, content: bytes) -> typing.List[bytes]:
+        if self._chunk_size is None:
+            return [content]
+
+        self._buffer.write(content)
+        if self._buffer.tell() >= self._chunk_size:
+            value = self._buffer.getvalue()
+            chunks = [
+                value[i : i + self._chunk_size]
+                for i in range(0, len(value), self._chunk_size)
+            ]
+            if len(chunks[-1]) == self._chunk_size:
+                self._buffer.seek(0)
+                self._buffer.truncate()
+                return chunks
+            else:
+                self._buffer.seek(0)
+                self._buffer.write(chunks[-1])
+                self._buffer.truncate()
+                return chunks[:-1]
+        else:
+            return []
+
+    def flush(self) -> typing.List[bytes]:
+        value = self._buffer.getvalue()
+        self._buffer.seek(0)
+        self._buffer.truncate()
+        return [value] if value else []
+
+
+class TextChunker:
+    """
+    Handles returning text content in fixed-size chunks.
+    """
+
+    def __init__(self, chunk_size: int = None) -> None:
+        self._buffer = io.StringIO()
+        self._chunk_size = chunk_size
+
+    def decode(self, content: str) -> typing.List[str]:
+        if self._chunk_size is None:
+            return [content]
+
+        self._buffer.write(content)
+        if self._buffer.tell() >= self._chunk_size:
+            value = self._buffer.getvalue()
+            chunks = [
+                value[i : i + self._chunk_size]
+                for i in range(0, len(value), self._chunk_size)
+            ]
+            if len(chunks[-1]) == self._chunk_size:
+                self._buffer.seek(0)
+                self._buffer.truncate()
+                return chunks
+            else:
+                self._buffer.seek(0)
+                self._buffer.write(chunks[-1])
+                self._buffer.truncate()
+                return chunks[:-1]
+        else:
+            return []
+
+    def flush(self) -> typing.List[str]:
+        value = self._buffer.getvalue()
+        self._buffer.seek(0)
+        self._buffer.truncate()
+        return [value] if value else []
+
+
 class TextDecoder:
     """
     Handles incrementally decoding bytes into text
index 5e68a4c84283be92ddf8cebe972c7b42b0240cc2..3310ff51976df093b144008cc6a0eb7e7c94d1cc 100644 (file)
@@ -15,10 +15,12 @@ import rfc3986.exceptions
 from ._content import PlainByteStream, encode_request, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
+    ByteChunker,
     ContentDecoder,
     IdentityDecoder,
     LineDecoder,
     MultiDecoder,
+    TextChunker,
     TextDecoder,
 )
 from ._exceptions import (
@@ -1162,31 +1164,47 @@ class Response:
             self._content = b"".join(self.iter_bytes())
         return self._content
 
-    def iter_bytes(self) -> typing.Iterator[bytes]:
+    def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
         if hasattr(self, "_content"):
-            yield self._content
+            chunk_size = len(self._content) if chunk_size is None else chunk_size
+            for i in range(0, len(self._content), chunk_size):
+                yield self._content[i : i + chunk_size]
         else:
             decoder = self._get_content_decoder()
+            chunker = ByteChunker(chunk_size=chunk_size)
             with self._wrap_decoder_errors():
-                for chunk in self.iter_raw():
-                    yield decoder.decode(chunk)
-                yield decoder.flush()
-
-    def iter_text(self) -> typing.Iterator[str]:
+                for raw_bytes in self.iter_raw():
+                    decoded = decoder.decode(raw_bytes)
+                    for chunk in chunker.decode(decoded):
+                        yield chunk
+                decoded = decoder.flush()
+                for chunk in chunker.decode(decoded):
+                    yield chunk
+                for chunk in chunker.flush():
+                    yield chunk
+
+    def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:
         """
         A str-iterator over the decoded response content
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
         decoder = TextDecoder(encoding=self.encoding)
+        chunker = TextChunker(chunk_size=chunk_size)
         with self._wrap_decoder_errors():
-            for chunk in self.iter_bytes():
-                yield decoder.decode(chunk)
-            yield decoder.flush()
+            for byte_content in self.iter_bytes():
+                text_content = decoder.decode(byte_content)
+                for chunk in chunker.decode(text_content):
+                    yield chunk
+            text_content = decoder.flush()
+            for chunk in chunker.decode(text_content):
+                yield chunk
+            for chunk in chunker.flush():
+                yield chunk
 
     def iter_lines(self) -> typing.Iterator[str]:
         decoder = LineDecoder()
@@ -1197,7 +1215,7 @@ class Response:
             for line in decoder.flush():
                 yield line
 
-    def iter_raw(self) -> typing.Iterator[bytes]:
+    def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the raw response content.
         """
@@ -1210,10 +1228,17 @@ class Response:
 
         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
+        chunker = ByteChunker(chunk_size=chunk_size)
+
         with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
-            for part in self.stream:
-                self._num_bytes_downloaded += len(part)
-                yield part
+            for raw_stream_bytes in self.stream:
+                self._num_bytes_downloaded += len(raw_stream_bytes)
+                for chunk in chunker.decode(raw_stream_bytes):
+                    yield chunk
+
+        for chunk in chunker.flush():
+            yield chunk
+
         self.close()
 
     def close(self) -> None:
@@ -1234,31 +1259,47 @@ class Response:
             self._content = b"".join([part async for part in self.aiter_bytes()])
         return self._content
 
-    async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
+    async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
         if hasattr(self, "_content"):
-            yield self._content
+            chunk_size = len(self._content) if chunk_size is None else chunk_size
+            for i in range(0, len(self._content), chunk_size):
+                yield self._content[i : i + chunk_size]
         else:
             decoder = self._get_content_decoder()
+            chunker = ByteChunker(chunk_size=chunk_size)
             with self._wrap_decoder_errors():
-                async for chunk in self.aiter_raw():
-                    yield decoder.decode(chunk)
-                yield decoder.flush()
-
-    async def aiter_text(self) -> typing.AsyncIterator[str]:
+                async for raw_bytes in self.aiter_raw():
+                    decoded = decoder.decode(raw_bytes)
+                    for chunk in chunker.decode(decoded):
+                        yield chunk
+                decoded = decoder.flush()
+                for chunk in chunker.decode(decoded):
+                    yield chunk
+                for chunk in chunker.flush():
+                    yield chunk
+
+    async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:
         """
         A str-iterator over the decoded response content
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
         decoder = TextDecoder(encoding=self.encoding)
+        chunker = TextChunker(chunk_size=chunk_size)
         with self._wrap_decoder_errors():
-            async for chunk in self.aiter_bytes():
-                yield decoder.decode(chunk)
-            yield decoder.flush()
+            async for byte_content in self.aiter_bytes():
+                text_content = decoder.decode(byte_content)
+                for chunk in chunker.decode(text_content):
+                    yield chunk
+            text_content = decoder.flush()
+            for chunk in chunker.decode(text_content):
+                yield chunk
+            for chunk in chunker.flush():
+                yield chunk
 
     async def aiter_lines(self) -> typing.AsyncIterator[str]:
         decoder = LineDecoder()
@@ -1269,7 +1310,7 @@ class Response:
             for line in decoder.flush():
                 yield line
 
-    async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
+    async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the raw response content.
         """
@@ -1282,10 +1323,17 @@ class Response:
 
         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
+        chunker = ByteChunker(chunk_size=chunk_size)
+
         with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
-            async for part in self.stream:
-                self._num_bytes_downloaded += len(part)
-                yield part
+            async for raw_stream_bytes in self.stream:
+                self._num_bytes_downloaded += len(raw_stream_bytes)
+                for chunk in chunker.decode(raw_stream_bytes):
+                    yield chunk
+
+        for chunk in chunker.flush():
+            yield chunk
+
         await self.aclose()
 
     async def aclose(self) -> None:
index ef26beda09685b4aea2e636994a031c885a0ab29..cb46719c17d43b2cb7b5c30826de723ad32872f8 100644 (file)
@@ -343,6 +343,23 @@ def test_iter_raw():
     assert raw == b"Hello, world!"
 
 
+def test_iter_raw_with_chunksize():
+    response = httpx.Response(200, content=streaming_body())
+
+    parts = [part for part in response.iter_raw(chunk_size=5)]
+    assert parts == [b"Hello", b", wor", b"ld!"]
+
+    response = httpx.Response(200, content=streaming_body())
+
+    parts = [part for part in response.iter_raw(chunk_size=13)]
+    assert parts == [b"Hello, world!"]
+
+    response = httpx.Response(200, content=streaming_body())
+
+    parts = [part for part in response.iter_raw(chunk_size=20)]
+    assert parts == [b"Hello, world!"]
+
+
 def test_iter_raw_on_iterable():
     response = httpx.Response(
         200,
@@ -384,6 +401,24 @@ async def test_aiter_raw():
     assert raw == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_aiter_raw_with_chunksize():
+    response = httpx.Response(200, content=async_streaming_body())
+
+    parts = [part async for part in response.aiter_raw(chunk_size=5)]
+    assert parts == [b"Hello", b", wor", b"ld!"]
+
+    response = httpx.Response(200, content=async_streaming_body())
+
+    parts = [part async for part in response.aiter_raw(chunk_size=13)]
+    assert parts == [b"Hello, world!"]
+
+    response = httpx.Response(200, content=async_streaming_body())
+
+    parts = [part async for part in response.aiter_raw(chunk_size=20)]
+    assert parts == [b"Hello, world!"]
+
+
 @pytest.mark.asyncio
 async def test_aiter_raw_on_sync():
     response = httpx.Response(
@@ -406,10 +441,7 @@ async def test_aiter_raw_increments_updates_counter():
 
 
 def test_iter_bytes():
-    response = httpx.Response(
-        200,
-        content=b"Hello, world!",
-    )
+    response = httpx.Response(200, content=b"Hello, world!")
 
     content = b""
     for part in response.iter_bytes():
@@ -417,6 +449,20 @@ def test_iter_bytes():
     assert content == b"Hello, world!"
 
 
+def test_iter_bytes_with_chunk_size():
+    response = httpx.Response(200, content=streaming_body())
+    parts = [part for part in response.iter_bytes(chunk_size=5)]
+    assert parts == [b"Hello", b", wor", b"ld!"]
+
+    response = httpx.Response(200, content=streaming_body())
+    parts = [part for part in response.iter_bytes(chunk_size=13)]
+    assert parts == [b"Hello, world!"]
+
+    response = httpx.Response(200, content=streaming_body())
+    parts = [part for part in response.iter_bytes(chunk_size=20)]
+    assert parts == [b"Hello, world!"]
+
+
 @pytest.mark.asyncio
 async def test_aiter_bytes():
     response = httpx.Response(
@@ -430,6 +476,21 @@ async def test_aiter_bytes():
     assert content == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_aiter_bytes_with_chunk_size():
+    response = httpx.Response(200, content=async_streaming_body())
+    parts = [part async for part in response.aiter_bytes(chunk_size=5)]
+    assert parts == [b"Hello", b", wor", b"ld!"]
+
+    response = httpx.Response(200, content=async_streaming_body())
+    parts = [part async for part in response.aiter_bytes(chunk_size=13)]
+    assert parts == [b"Hello, world!"]
+
+    response = httpx.Response(200, content=async_streaming_body())
+    parts = [part async for part in response.aiter_bytes(chunk_size=20)]
+    assert parts == [b"Hello, world!"]
+
+
 def test_iter_text():
     response = httpx.Response(
         200,
@@ -442,6 +503,20 @@ def test_iter_text():
     assert content == "Hello, world!"
 
 
+def test_iter_text_with_chunk_size():
+    response = httpx.Response(200, content=b"Hello, world!")
+    parts = [part for part in response.iter_text(chunk_size=5)]
+    assert parts == ["Hello", ", wor", "ld!"]
+
+    response = httpx.Response(200, content=b"Hello, world!")
+    parts = [part for part in response.iter_text(chunk_size=13)]
+    assert parts == ["Hello, world!"]
+
+    response = httpx.Response(200, content=b"Hello, world!")
+    parts = [part for part in response.iter_text(chunk_size=20)]
+    assert parts == ["Hello, world!"]
+
+
 @pytest.mark.asyncio
 async def test_aiter_text():
     response = httpx.Response(
@@ -455,6 +530,21 @@ async def test_aiter_text():
     assert content == "Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_aiter_text_with_chunk_size():
+    response = httpx.Response(200, content=b"Hello, world!")
+    parts = [part async for part in response.aiter_text(chunk_size=5)]
+    assert parts == ["Hello", ", wor", "ld!"]
+
+    response = httpx.Response(200, content=b"Hello, world!")
+    parts = [part async for part in response.aiter_text(chunk_size=13)]
+    assert parts == ["Hello, world!"]
+
+    response = httpx.Response(200, content=b"Hello, world!")
+    parts = [part async for part in response.aiter_text(chunk_size=20)]
+    assert parts == ["Hello, world!"]
+
+
 def test_iter_lines():
     response = httpx.Response(
         200,
index 351fce0520485b13c8b4c202729aff4bf631f2e8..f8c432cc8981dc5df79b398afe3a465255488a67 100644 (file)
@@ -6,10 +6,12 @@ import pytest
 import httpx
 from httpx._decoders import (
     BrotliDecoder,
+    ByteChunker,
     DeflateDecoder,
     GZipDecoder,
     IdentityDecoder,
     LineDecoder,
+    TextChunker,
     TextDecoder,
 )
 
@@ -300,6 +302,50 @@ def test_line_decoder_crnl():
     assert decoder.flush() == []
 
 
+def test_byte_chunker():
+    decoder = ByteChunker()
+    assert decoder.decode(b"1234567") == [b"1234567"]
+    assert decoder.decode(b"89") == [b"89"]
+    assert decoder.flush() == []
+
+    decoder = ByteChunker(chunk_size=3)
+    assert decoder.decode(b"1234567") == [b"123", b"456"]
+    assert decoder.decode(b"89") == [b"789"]
+    assert decoder.flush() == []
+
+    decoder = ByteChunker(chunk_size=3)
+    assert decoder.decode(b"123456") == [b"123", b"456"]
+    assert decoder.decode(b"789") == [b"789"]
+    assert decoder.flush() == []
+
+    decoder = ByteChunker(chunk_size=3)
+    assert decoder.decode(b"123456") == [b"123", b"456"]
+    assert decoder.decode(b"78") == []
+    assert decoder.flush() == [b"78"]
+
+
+def test_text_chunker():
+    decoder = TextChunker()
+    assert decoder.decode("1234567") == ["1234567"]
+    assert decoder.decode("89") == ["89"]
+    assert decoder.flush() == []
+
+    decoder = TextChunker(chunk_size=3)
+    assert decoder.decode("1234567") == ["123", "456"]
+    assert decoder.decode("89") == ["789"]
+    assert decoder.flush() == []
+
+    decoder = TextChunker(chunk_size=3)
+    assert decoder.decode("123456") == ["123", "456"]
+    assert decoder.decode("789") == ["789"]
+    assert decoder.flush() == []
+
+    decoder = TextChunker(chunk_size=3)
+    assert decoder.decode("123456") == ["123", "456"]
+    assert decoder.decode("78") == []
+    assert decoder.flush() == ["78"]
+
+
 def test_invalid_content_encoding_header():
     headers = [(b"Content-Encoding", b"invalid-header")]
     body = b"test 123"