Sync streaming interface on responses (#695)
author Tom Christie <tom@tomchristie.com>
Thu, 2 Jan 2020 12:56:11 +0000 (12:56 +0000)
committer GitHub <noreply@github.com>
Thu, 2 Jan 2020 12:56:11 +0000 (12:56 +0000)
* Sync streaming interface on responses

* Fix test case

* Test coverage for sync response APIs

* Address review comments

httpx/content_streams.py
httpx/models.py
tests/models/test_responses.py
tests/test_content_streams.py
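
A minimal sketch of the synchronous response API this commit introduces, mirroring the test cases below; the request URL and byte payloads are illustrative:

    import httpx
    from httpx.content_streams import IteratorStream

    request = httpx.Request("GET", "https://example.org")

    def streaming_body():
        # A plain (sync) generator can now back a streamed response body.
        yield b"Hello, "
        yield b"world!"

    # Non-streamed responses are read eagerly at construction time.
    response = httpx.Response(200, content=b"Hello, world!", request=request)
    assert response.text == "Hello, world!"
    assert response.is_closed

    # Streamed responses gain sync counterparts of the async API:
    # read(), iter_bytes(), iter_text(), iter_lines(), iter_raw(), close().
    streamed = httpx.Response(
        200, stream=IteratorStream(iterator=streaming_body()), request=request
    )
    assert b"".join(streamed.iter_bytes()) == b"Hello, world!"
    assert streamed.is_closed  # iterating to completion closes the response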

index 73ef59ef0b0f7c8e5932f23f28ce05f53b001269..e3e343cf0826afdaa7fb907eb1292f9876130937 100644 (file)
@@ -10,7 +10,9 @@ from urllib.parse import urlencode
 from .exceptions import StreamConsumed
 from .utils import format_form_param
 
-RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
+RequestData = typing.Union[
+    dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]
+]
 
 RequestFiles = typing.Dict[
     str,
@@ -47,6 +49,12 @@ class ContentStream:
         """
         return True
 
+    def __iter__(self) -> typing.Iterator[bytes]:
+        yield b""
+
+    def close(self) -> None:
+        pass
+
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         yield b""
 
@@ -68,10 +76,46 @@ class ByteStream(ContentStream):
         content_length = str(len(self.body))
         return {"Content-Length": content_length}
 
+    def __iter__(self) -> typing.Iterator[bytes]:
+        yield self.body
+
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         yield self.body
 
 
+class IteratorStream(ContentStream):
+    """
+    Request content encoded as plain bytes, using a byte iterator.
+    """
+
+    def __init__(
+        self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
+    ) -> None:
+        self.iterator = iterator
+        self.close_func = close_func
+        self.is_stream_consumed = False
+
+    def can_replay(self) -> bool:
+        return False
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        return {"Transfer-Encoding": "chunked"}
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        if self.is_stream_consumed:
+            raise StreamConsumed()
+        self.is_stream_consumed = True
+        for part in self.iterator:
+            yield part
+
+    def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        raise RuntimeError("Attempted to call a async iterator on an sync stream.")
+
+    def close(self) -> None:
+        if self.close_func is not None:
+            self.close_func()
+
+
 class AsyncIteratorStream(ContentStream):
     """
     Request content encoded as plain bytes, using an async byte iterator.
@@ -90,6 +134,9 @@ class AsyncIteratorStream(ContentStream):
     def get_headers(self) -> typing.Dict[str, str]:
         return {"Transfer-Encoding": "chunked"}
 
+    def __iter__(self) -> typing.Iterator[bytes]:
+        raise RuntimeError("Attempted to call a sync iterator on an async stream.")
+
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         if self.is_stream_consumed:
             raise StreamConsumed()
@@ -115,6 +162,9 @@ class JSONStream(ContentStream):
         content_type = "application/json"
         return {"Content-Length": content_length, "Content-Type": content_type}
 
+    def __iter__(self) -> typing.Iterator[bytes]:
+        yield self.body
+
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         yield self.body
 
@@ -132,6 +182,9 @@ class URLEncodedStream(ContentStream):
         content_type = "application/x-www-form-urlencoded"
         return {"Content-Length": content_length, "Content-Type": content_type}
 
+    def __iter__(self) -> typing.Iterator[bytes]:
+        yield self.body
+
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         yield self.body
 
@@ -252,6 +305,9 @@ class MultipartStream(ContentStream):
         content_type = self.content_type
         return {"Content-Length": content_length, "Content-Type": content_type}
 
+    def __iter__(self) -> typing.Iterator[bytes]:
+        yield self.body
+
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
         yield self.body
 
@@ -280,5 +336,11 @@ def encode(
             return URLEncodedStream(data=data)
     elif isinstance(data, (str, bytes)):
         return ByteStream(body=data)
-    else:
+    elif hasattr(data, "__aiter__"):
+        data = typing.cast(typing.AsyncIterator[bytes], data)
         return AsyncIteratorStream(aiterator=data)
+    elif hasattr(data, "__iter__"):
+        data = typing.cast(typing.Iterator[bytes], data)
+        return IteratorStream(iterator=data)
+
+    raise TypeError(f"Unexpected type for 'data', {type(data)!r}")
index 91cf5a5e28d5f0992a2b2954925696ec07810ec2..dc98b302e1a8e3359cad1a4fa7f3aaf26f764f52 100644 (file)
@@ -13,7 +13,13 @@ import chardet
 import rfc3986
 
 from .config import USER_AGENT
-from .content_streams import ContentStream, RequestData, RequestFiles, encode
+from .content_streams import (
+    ByteStream,
+    ContentStream,
+    RequestData,
+    RequestFiles,
+    encode,
+)
 from .decoders import (
     ACCEPT_ENCODING,
     SUPPORTED_DECODERS,
@@ -665,15 +671,13 @@ class Response:
 
         self.history = [] if history is None else list(history)
 
-        if stream is None:
-            self.is_closed = True
-            self.is_stream_consumed = True
-            self._raw_content = content or b""
-            self._elapsed = request.timer.elapsed
-        else:
-            self.is_closed = False
-            self.is_stream_consumed = False
+        self.is_closed = False
+        self.is_stream_consumed = False
+        if stream is not None:
             self._raw_stream = stream
+        else:
+            self._raw_stream = ByteStream(body=content or b"")
+            self.read()
 
     @property
     def elapsed(self) -> datetime.timedelta:
@@ -702,13 +706,7 @@ class Response:
     @property
     def content(self) -> bytes:
         if not hasattr(self, "_content"):
-            if hasattr(self, "_raw_content"):
-                raw_content = self._raw_content  # type: ignore
-                content = self.decoder.decode(raw_content)
-                content += self.decoder.flush()
-                self._content = content
-            else:
-                raise ResponseNotRead()
+            raise ResponseNotRead()
         return self._content
 
     @property
@@ -850,14 +848,6 @@ class Response:
     def __repr__(self) -> str:
         return f"<Response [{self.status_code} {self.reason_phrase}]>"
 
-    async def aread(self) -> bytes:
-        """
-        Read and return the response content.
-        """
-        if not hasattr(self, "_content"):
-            self._content = b"".join([part async for part in self.aiter_bytes()])
-        return self._content
-
     @property
     def stream(self):  # type: ignore
         warnings.warn(  # pragma: nocover
@@ -874,6 +864,78 @@ class Response:
         )
         return self.aiter_raw  # pragma: nocover
 
+    def read(self) -> bytes:
+        """
+        Read and return the response content.
+        """
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part for part in self.iter_bytes()])
+        return self._content
+
+    def iter_bytes(self) -> typing.Iterator[bytes]:
+        """
+        A byte-iterator over the decoded response content.
+        This allows us to handle gzip, deflate, and brotli encoded responses.
+        """
+        if hasattr(self, "_content"):
+            yield self._content
+        else:
+            for chunk in self.iter_raw():
+                yield self.decoder.decode(chunk)
+            yield self.decoder.flush()
+
+    def iter_text(self) -> typing.Iterator[str]:
+        """
+        A str-iterator over the decoded response content
+        that handles gzip, deflate, etc., and also detects the content's
+        string encoding.
+        """
+        decoder = TextDecoder(encoding=self.charset_encoding)
+        for chunk in self.iter_bytes():
+            yield decoder.decode(chunk)
+        yield decoder.flush()
+
+    def iter_lines(self) -> typing.Iterator[str]:
+        decoder = LineDecoder()
+        for text in self.iter_text():
+            for line in decoder.decode(text):
+                yield line
+        for line in decoder.flush():
+            yield line
+
+    def iter_raw(self) -> typing.Iterator[bytes]:
+        """
+        A byte-iterator over the raw response content.
+        """
+        if self.is_stream_consumed:
+            raise StreamConsumed()
+        if self.is_closed:
+            raise ResponseClosed()
+
+        self.is_stream_consumed = True
+        for part in self._raw_stream:
+            yield part
+        self.close()
+
+    def close(self) -> None:
+        """
+        Close the response and release the connection.
+        Automatically called if the response body is read to completion.
+        """
+        if not self.is_closed:
+            self.is_closed = True
+            self._elapsed = self.request.timer.elapsed
+            if hasattr(self, "_raw_stream"):
+                self._raw_stream.close()
+
+    async def aread(self) -> bytes:
+        """
+        Read and return the response content.
+        """
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part async for part in self.aiter_bytes()])
+        return self._content
+
     async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the decoded response content.
@@ -909,18 +971,15 @@ class Response:
         """
         A byte-iterator over the raw response content.
         """
-        if hasattr(self, "_raw_content"):
-            yield self._raw_content
-        else:
-            if self.is_stream_consumed:
-                raise StreamConsumed()
-            if self.is_closed:
-                raise ResponseClosed()
-
-            self.is_stream_consumed = True
-            async for part in self._raw_stream:
-                yield part
-            await self.aclose()
+        if self.is_stream_consumed:
+            raise StreamConsumed()
+        if self.is_closed:
+            raise ResponseClosed()
+
+        self.is_stream_consumed = True
+        async for part in self._raw_stream:
+            yield part
+        await self.aclose()
 
     async def anext(self) -> "Response":
         """
index d7d519cc7f6985ed24543154547f2816870a10b8..80cbfb08e1d6d8ec7777caca789a7d663b24bf0c 100644 (file)
@@ -5,7 +5,7 @@ from unittest import mock
 import pytest
 
 import httpx
-from httpx.content_streams import AsyncIteratorStream
+from httpx.content_streams import AsyncIteratorStream, IteratorStream
 
 REQUEST = httpx.Request("GET", "https://example.org")
 
@@ -124,8 +124,23 @@ def test_response_force_encoding():
     assert response.encoding == "iso-8859-1"
 
 
+def test_read():
+    response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
+
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+    assert response.encoding == "ascii"
+    assert response.is_closed
+
+    content = response.read()
+
+    assert content == b"Hello, world!"
+    assert response.content == b"Hello, world!"
+    assert response.is_closed
+
+
 @pytest.mark.asyncio
-async def test_read_response():
+async def test_aread():
     response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
 
     assert response.status_code == 200
@@ -140,9 +155,20 @@ async def test_read_response():
     assert response.is_closed
 
 
+def test_iter_raw():
+    stream = IteratorStream(iterator=streaming_body())
+    response = httpx.Response(200, stream=stream, request=REQUEST)
+
+    raw = b""
+    for part in response.iter_raw():
+        raw += part
+    assert raw == b"Hello, world!"
+
+
 @pytest.mark.asyncio
-async def test_raw_interface():
-    response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
+async def test_aiter_raw():
+    stream = AsyncIteratorStream(aiterator=async_streaming_body())
+    response = httpx.Response(200, stream=stream, request=REQUEST)
 
     raw = b""
     async for part in response.aiter_raw():
@@ -150,8 +176,17 @@ async def test_raw_interface():
     assert raw == b"Hello, world!"
 
 
+def test_iter_bytes():
+    response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
+
+    content = b""
+    for part in response.iter_bytes():
+        content += part
+    assert content == b"Hello, world!"
+
+
 @pytest.mark.asyncio
-async def test_bytes_interface():
+async def test_aiter_bytes():
     response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
 
     content = b""
@@ -160,11 +195,18 @@ async def test_bytes_interface():
     assert content == b"Hello, world!"
 
 
-@pytest.mark.asyncio
-async def test_text_interface():
+def test_iter_text():
     response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
 
-    await response.aread()
+    content = ""
+    for part in response.iter_text():
+        content += part
+    assert content == "Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_aiter_text():
+    response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
 
     content = ""
     async for part in response.aiter_text():
@@ -172,11 +214,18 @@ async def test_text_interface():
     assert content == "Hello, world!"
 
 
-@pytest.mark.asyncio
-async def test_lines_interface():
+def test_iter_lines():
     response = httpx.Response(200, content=b"Hello,\nworld!", request=REQUEST)
 
-    await response.aread()
+    content = []
+    for line in response.iter_lines():
+        content.append(line)
+    assert content == ["Hello,\n", "world!"]
+
+
+@pytest.mark.asyncio
+async def test_aiter_lines():
+    response = httpx.Response(200, content=b"Hello,\nworld!", request=REQUEST)
 
     content = []
     async for line in response.aiter_lines():
@@ -184,20 +233,22 @@ async def test_lines_interface():
     assert content == ["Hello,\n", "world!"]
 
 
-@pytest.mark.asyncio
-async def test_stream_interface_after_read():
-    response = httpx.Response(200, content=b"Hello, world!", request=REQUEST)
+def test_sync_streaming_response():
+    stream = IteratorStream(iterator=streaming_body())
+    response = httpx.Response(200, stream=stream, request=REQUEST)
 
-    await response.aread()
+    assert response.status_code == 200
+    assert not response.is_closed
+
+    content = response.read()
 
-    content = b""
-    async for part in response.aiter_bytes():
-        content += part
     assert content == b"Hello, world!"
+    assert response.content == b"Hello, world!"
+    assert response.is_closed
 
 
 @pytest.mark.asyncio
-async def test_streaming_response():
+async def test_async_streaming_response():
     stream = AsyncIteratorStream(aiterator=async_streaming_body())
     response = httpx.Response(200, stream=stream, request=REQUEST)
 
@@ -211,8 +262,20 @@ async def test_streaming_response():
     assert response.is_closed
 
 
+def test_cannot_read_after_stream_consumed():
+    stream = IteratorStream(iterator=streaming_body())
+    response = httpx.Response(200, stream=stream, request=REQUEST)
+
+    content = b""
+    for part in response.iter_bytes():
+        content += part
+
+    with pytest.raises(httpx.StreamConsumed):
+        response.read()
+
+
 @pytest.mark.asyncio
-async def test_cannot_read_after_stream_consumed():
+async def test_cannot_aread_after_stream_consumed():
     stream = AsyncIteratorStream(aiterator=async_streaming_body())
     response = httpx.Response(200, stream=stream, request=REQUEST)
 
@@ -224,12 +287,38 @@ async def test_cannot_read_after_stream_consumed():
         await response.aread()
 
 
+def test_cannot_read_after_response_closed():
+    is_closed = False
+
+    def close_func():
+        nonlocal is_closed
+        is_closed = True
+
+    stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
+    response = httpx.Response(200, stream=stream, request=REQUEST)
+
+    response.close()
+    assert is_closed
+
+    with pytest.raises(httpx.ResponseClosed):
+        response.read()
+
+
 @pytest.mark.asyncio
-async def test_cannot_read_after_response_closed():
-    stream = AsyncIteratorStream(aiterator=async_streaming_body())
+async def test_cannot_aread_after_response_closed():
+    is_closed = False
+
+    async def close_func():
+        nonlocal is_closed
+        is_closed = True
+
+    stream = AsyncIteratorStream(
+        aiterator=async_streaming_body(), close_func=close_func
+    )
     response = httpx.Response(200, stream=stream, request=REQUEST)
 
     await response.aclose()
+    assert is_closed
 
     with pytest.raises(httpx.ResponseClosed):
         await response.aread()
index 764e00ed7a8a598efa483f9d3bf33d256de5d9b6..93baf8e7013254b95348df82755ac6dc540831a1 100644 (file)
@@ -2,97 +2,139 @@ import io
 
 import pytest
 
-from httpx.content_streams import encode
+from httpx.content_streams import ContentStream, encode
 from httpx.exceptions import StreamConsumed
 
 
+@pytest.mark.asyncio
+async def test_base_content():
+    stream = ContentStream()
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
+
+    assert stream.can_replay()
+    assert stream.get_headers() == {}
+    assert sync_content == b""
+    assert async_content == b""
+
+
 @pytest.mark.asyncio
 async def test_empty_content():
     stream = encode()
-    content = b"".join([part async for part in stream])
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
 
     assert stream.can_replay()
     assert stream.get_headers() == {}
-    assert content == b""
+    assert sync_content == b""
+    assert async_content == b""
 
 
 @pytest.mark.asyncio
 async def test_bytes_content():
     stream = encode(data=b"Hello, world!")
-    content = b"".join([part async for part in stream])
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
 
     assert stream.can_replay()
     assert stream.get_headers() == {"Content-Length": "13"}
-    assert content == b"Hello, world!"
+    assert sync_content == b"Hello, world!"
+    assert async_content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
-async def test_aiterator_content():
-    async def hello_world():
+async def test_iterator_content():
+    def hello_world():
         yield b"Hello, "
         yield b"world!"
 
     stream = encode(data=hello_world())
-    content = b"".join([part async for part in stream])
+    content = b"".join([part for part in stream])
 
     assert not stream.can_replay()
     assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
+    with pytest.raises(RuntimeError):
+        [part async for part in stream]
+
+    with pytest.raises(StreamConsumed):
+        [part for part in stream]
+
 
 @pytest.mark.asyncio
-async def test_aiterator_is_stream_consumed():
+async def test_aiterator_content():
     async def hello_world():
         yield b"Hello, "
         yield b"world!"
 
     stream = encode(data=hello_world())
-    b"".join([part async for part in stream])
+    content = b"".join([part async for part in stream])
+
+    assert not stream.can_replay()
+    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert content == b"Hello, world!"
 
-    assert stream.is_stream_consumed
+    with pytest.raises(RuntimeError):
+        [part for part in stream]
 
-    with pytest.raises(StreamConsumed) as _:
-        b"".join([part async for part in stream])
+    with pytest.raises(StreamConsumed):
+        [part async for part in stream]
 
 
 @pytest.mark.asyncio
 async def test_json_content():
     stream = encode(json={"Hello": "world!"})
-    content = b"".join([part async for part in stream])
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
 
     assert stream.can_replay()
     assert stream.get_headers() == {
         "Content-Length": "19",
         "Content-Type": "application/json",
     }
-    assert content == b'{"Hello": "world!"}'
+    assert sync_content == b'{"Hello": "world!"}'
+    assert async_content == b'{"Hello": "world!"}'
 
 
 @pytest.mark.asyncio
 async def test_urlencoded_content():
     stream = encode(data={"Hello": "world!"})
-    content = b"".join([part async for part in stream])
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
 
     assert stream.can_replay()
     assert stream.get_headers() == {
         "Content-Length": "14",
         "Content-Type": "application/x-www-form-urlencoded",
     }
-    assert content == b"Hello=world%21"
+    assert sync_content == b"Hello=world%21"
+    assert async_content == b"Hello=world%21"
 
 
 @pytest.mark.asyncio
 async def test_multipart_files_content():
     files = {"file": io.BytesIO(b"<file content>")}
     stream = encode(files=files, boundary=b"+++")
-    content = b"".join([part async for part in stream])
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
 
     assert stream.can_replay()
     assert stream.get_headers() == {
         "Content-Length": "138",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
-    assert content == b"".join(
+    assert sync_content == b"".join(
+        [
+            b"--+++\r\n",
+            b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
+            b"Content-Type: application/octet-stream\r\n",
+            b"\r\n",
+            b"<file content>\r\n",
+            b"--+++--\r\n",
+        ]
+    )
+    assert async_content == b"".join(
         [
             b"--+++\r\n",
             b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
@@ -109,14 +151,29 @@ async def test_multipart_data_and_files_content():
     data = {"message": "Hello, world!"}
     files = {"file": io.BytesIO(b"<file content>")}
     stream = encode(data=data, files=files, boundary=b"+++")
-    content = b"".join([part async for part in stream])
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
 
     assert stream.can_replay()
     assert stream.get_headers() == {
         "Content-Length": "210",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
-    assert content == b"".join(
+    assert sync_content == b"".join(
+        [
+            b"--+++\r\n",
+            b'Content-Disposition: form-data; name="message"\r\n',
+            b"\r\n",
+            b"Hello, world!\r\n",
+            b"--+++\r\n",
+            b'Content-Disposition: form-data; name="file"; filename="upload"\r\n',
+            b"Content-Type: application/octet-stream\r\n",
+            b"\r\n",
+            b"<file content>\r\n",
+            b"--+++--\r\n",
+        ]
+    )
+    assert async_content == b"".join(
         [
             b"--+++\r\n",
             b'Content-Disposition: form-data; name="message"\r\n',
@@ -130,3 +187,8 @@ async def test_multipart_data_and_files_content():
             b"--+++--\r\n",
         ]
     )
+
+
+def test_invalid_argument():
+    with pytest.raises(TypeError):
+        encode(123)