]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Implement Response.stream_text() (#183)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Thu, 15 Aug 2019 02:56:17 +0000 (21:56 -0500)
committerGitHub <noreply@github.com>
Thu, 15 Aug 2019 02:56:17 +0000 (21:56 -0500)
httpx/decoders.py
httpx/models.py
tests/models/test_responses.py
tests/test_decoders.py

index a2d43fa9acfdb0e5ae11443cec1c182bc3c86c2d..bfa965404c60c5d775a918cedeab9002967b1e16 100644 (file)
@@ -3,9 +3,12 @@ Handlers for Content-Encoding.
 
 See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
 """
+import codecs
 import typing
 import zlib
 
+import chardet
+
 from .exceptions import DecodingError
 
 try:
@@ -138,6 +141,70 @@ class MultiDecoder(Decoder):
         return data
 
 
+class TextDecoder:
+    """
+    Handles incrementally decoding bytes into text
+    """
+
+    def __init__(self, encoding: typing.Optional[str] = None):
+        self.decoder: typing.Optional[codecs.IncrementalDecoder] = (
+            None if encoding is None else codecs.getincrementaldecoder(encoding)()
+        )
+        self.detector = chardet.universaldetector.UniversalDetector()
+
+        # This buffer is only needed if 'decoder' is 'None'
+        # we want to trigger errors if data is getting added to
+        # our internal buffer for some silly reason while
+        # a decoder is discovered.
+        self.buffer: typing.Optional[bytearray] = None if self.decoder else bytearray()
+
+    def decode(self, data: bytes) -> str:
+        try:
+            if self.decoder is not None:
+                text = self.decoder.decode(data)
+            else:
+                assert self.buffer is not None
+                text = ""
+                self.detector.feed(data)
+                self.buffer += data
+
+                # Should be more than enough data to process, we don't
+                # want to buffer too long as chardet will wait until
+                # detector.close() is used to give back common
+                # encodings like 'utf-8'.
+                if len(self.buffer) >= 4096:
+                    self.decoder = codecs.getincrementaldecoder(
+                        self._detector_result()
+                    )()
+                    text = self.decoder.decode(bytes(self.buffer), False)
+                    self.buffer = None
+
+            return text
+        except UnicodeDecodeError:  # pragma: nocover
+            raise DecodingError() from None
+
+    def flush(self) -> str:
+        try:
+            if self.decoder is None:
+                # Empty string case as chardet is guaranteed to not have a guess.
+                assert self.buffer is not None
+                if len(self.buffer) == 0:
+                    return ""
+                return bytes(self.buffer).decode(self._detector_result())
+
+            return self.decoder.decode(b"", True)
+        except UnicodeDecodeError:  # pragma: nocover
+            raise DecodingError() from None
+
+    def _detector_result(self) -> str:
+        self.detector.close()
+        result = self.detector.result["encoding"]
+        if not result:  # pragma: nocover
+            raise DecodingError("Unable to determine encoding of content")
+
+        return result
+
+
 SUPPORTED_DECODERS = {
     "identity": IdentityDecoder,
     "gzip": GZipDecoder,
index 210c8e21059963f35b5db5e3cf2dc01c4d9d16ea..2ffc290e94ef6c23b372ae64a72fe06db9580247 100644 (file)
@@ -17,6 +17,7 @@ from .decoders import (
     Decoder,
     IdentityDecoder,
     MultiDecoder,
+    TextDecoder,
 )
 from .exceptions import (
     CookieConflict,
@@ -890,6 +891,17 @@ class AsyncResponse(BaseResponse):
                 yield self.decoder.decode(chunk)
             yield self.decoder.flush()
 
+    async def stream_text(self) -> typing.AsyncIterator[str]:
+        """
+        A str-iterator over the decoded response content
+        that handles both gzip, deflate, etc but also detects the content's
+        string encoding.
+        """
+        decoder = TextDecoder(encoding=self.charset_encoding)
+        async for chunk in self.stream():
+            yield decoder.decode(chunk)
+        yield decoder.flush()
+
     async def raw(self) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the raw response content.
@@ -969,6 +981,17 @@ class Response(BaseResponse):
                 yield self.decoder.decode(chunk)
             yield self.decoder.flush()
 
+    def stream_text(self) -> typing.Iterator[str]:
+        """
+        A str-iterator over the decoded response content
+        that handles both gzip, deflate, etc but also detects the content's
+        string encoding.
+        """
+        decoder = TextDecoder(encoding=self.charset_encoding)
+        for chunk in self.stream():
+            yield decoder.decode(chunk)
+        yield decoder.flush()
+
     def raw(self) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the raw response content.
index ddec72e6d133194c20d764aee10aec47c9902bff..a66daf5a28ad11a0a56b8bda28e7b789b9be7cb3 100644 (file)
@@ -288,3 +288,18 @@ def test_json_without_specified_encoding_decode_error():
         response = httpx.Response(200, content=content, headers=headers)
         with pytest.raises(json.JSONDecodeError):
             response.json()
+
+
+@pytest.mark.asyncio
+async def test_stream_text():
+    async def iterator():
+        yield b"Hello, world!"
+
+    response = httpx.AsyncResponse(200, content=iterator().__aiter__())
+
+    await response.read()
+
+    content = ""
+    async for part in response.stream_text():
+        content += part
+    assert content == "Hello, world!"
index 83ad1ec126e473869249f642974232e86b43b412..036a4168703591a885b56b6d3ccb9eedbec44b57 100644 (file)
@@ -4,6 +4,7 @@ import brotli
 import pytest
 
 import httpx
+from httpx.decoders import TextDecoder
 
 
 def test_deflate():
@@ -88,6 +89,57 @@ def test_decoding_errors(header_value):
         response.content
 
 
+@pytest.mark.parametrize(
+    ["data", "encoding"],
+    [
+        ((b"Hello,", b" world!"), "ascii"),
+        ((b"\xe3\x83", b"\x88\xe3\x83\xa9", b"\xe3", b"\x83\x99\xe3\x83\xab"), "utf-8"),
+        ((b"\x83g\x83\x89\x83x\x83\x8b",) * 64, "shift-jis"),
+        ((b"\x83g\x83\x89\x83x\x83\x8b",) * 600, "shift-jis"),
+        (
+            (b"\xcb\xee\xf0\xe5\xec \xe8\xef\xf1\xf3\xec \xe4\xee\xeb\xee\xf0",) * 64,
+            "MacCyrillic",
+        ),
+        (
+            (b"\xa5\xa6\xa5\xa7\xa5\xd6\xa4\xce\xb9\xf1\xba\xdd\xb2\xbd",) * 512,
+            "euc-jp",
+        ),
+    ],
+)
+def test_text_decoder(data, encoding):
+    def iterator():
+        nonlocal data
+        for chunk in data:
+            yield chunk
+
+    response = httpx.Response(200, content=iterator())
+    assert "".join(response.stream_text()) == (b"".join(data)).decode(encoding)
+
+
+def test_text_decoder_known_encoding():
+    def iterator():
+        yield b"\x83g"
+        yield b"\x83"
+        yield b"\x89\x83x\x83\x8b"
+
+    response = httpx.Response(
+        200,
+        headers=[(b"Content-Type", b"text/html; charset=shift-jis")],
+        content=iterator(),
+    )
+
+    assert "".join(response.stream_text()) == "トラベル"
+
+
+def test_text_decoder_empty_cases():
+    decoder = TextDecoder()
+    assert decoder.flush() == ""
+
+    decoder = TextDecoder()
+    assert decoder.decode(b"") == ""
+    assert decoder.flush() == ""
+
+
 def test_invalid_content_encoding_header():
     headers = [(b"Content-Encoding", b"invalid-header")]
     body = b"test 123"