Support `Response(content=<bytes iterator>)` (#1265)
author    Tom Christie <tom@tomchristie.com>
          Fri, 11 Sep 2020 09:28:18 +0000 (10:28 +0100)
committer GitHub <noreply@github.com>
          Fri, 11 Sep 2020 09:28:18 +0000 (10:28 +0100)
* Support Response(content=<bytes iterator>)

* Update test for merged master
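
The net effect is that `Response(...)` now accepts the same kind of streaming bodies that requests already did. A minimal usage sketch of the new behaviour (the generator below is illustrative, matching the helpers used in the tests):

    import httpx

    def streaming_body():
        yield b"Hello, "
        yield b"world!"

    # Streaming content is not read eagerly; consume it via read()/iter_raw().
    response = httpx.Response(200, content=streaming_body())
    assert response.read() == b"Hello, world!"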

httpx/_content_streams.py
httpx/_models.py
httpx/_types.py
tests/models/test_responses.py
tests/test_content_streams.py
tests/test_decoders.py

diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py
index 402fa959c8f07282d585573140dc15e76c2c2c0e..3cd2196ab482609ce5b3020efbdcf82014dfdd59 100644
@@ -8,7 +8,7 @@ from urllib.parse import urlencode
 import httpcore
 
 from ._exceptions import StreamConsumed
-from ._types import FileContent, FileTypes, RequestData, RequestFiles
+from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent
 from ._utils import (
     format_form_param,
     guess_content_type,
@@ -72,11 +72,8 @@ class IteratorStream(ContentStream):
     Request content encoded as plain bytes, using a byte iterator.
     """
 
-    def __init__(
-        self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None
-    ) -> None:
+    def __init__(self, iterator: typing.Iterator[bytes]) -> None:
         self.iterator = iterator
-        self.close_func = close_func
         self.is_stream_consumed = False
 
     def can_replay(self) -> bool:
@@ -95,21 +92,14 @@ class IteratorStream(ContentStream):
     def __aiter__(self) -> typing.AsyncIterator[bytes]:
         raise RuntimeError("Attempted to call an async iterator on a sync stream.")
 
-    def close(self) -> None:
-        if self.close_func is not None:
-            self.close_func()
-
 
 class AsyncIteratorStream(ContentStream):
     """
     Request content encoded as plain bytes, using an async byte iterator.
     """
 
-    def __init__(
-        self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None
-    ) -> None:
+    def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
         self.aiterator = aiterator
-        self.close_func = close_func
         self.is_stream_consumed = False
 
     def can_replay(self) -> bool:
@@ -128,10 +118,6 @@ class AsyncIteratorStream(ContentStream):
         async for part in self.aiterator:
             yield part
 
-    async def aclose(self) -> None:
-        if self.close_func is not None:
-            await self.close_func()
-
 
 class JSONStream(ContentStream):
     """
@@ -402,3 +388,18 @@ def encode(
         return IteratorStream(iterator=data)
 
     raise TypeError(f"Unexpected type for 'data', {type(data)!r}")
+
+
+def encode_response(content: ResponseContent = None) -> ContentStream:
+    if content is None:
+        return ByteStream(b"")
+    elif isinstance(content, bytes):
+        return ByteStream(body=content)
+    elif hasattr(content, "__aiter__"):
+        content = typing.cast(typing.AsyncIterator[bytes], content)
+        return AsyncIteratorStream(aiterator=content)
+    elif hasattr(content, "__iter__"):
+        content = typing.cast(typing.Iterator[bytes], content)
+        return IteratorStream(iterator=content)
+
+    raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
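
Informally, the new `encode_response()` helper mirrors the existing `encode()` used for request bodies: `None` and `bytes` become a replayable `ByteStream`, while sync/async iterators become one-shot streams. A rough sketch of the dispatch (generator name is illustrative), consistent with the new tests further down:

    from httpx._content_streams import ByteStream, IteratorStream, encode_response

    def hello_world():
        yield b"Hello, "
        yield b"world!"

    assert isinstance(encode_response(b"Hello, world!"), ByteStream)   # replayable, Content-Length header
    assert isinstance(encode_response(hello_world()), IteratorStream)  # one-shot, Transfer-Encoding: chunked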
diff --git a/httpx/_models.py b/httpx/_models.py
index 6d90f90b4efb044bdf5556885378cb4da7e4158c..526ee2cebf1e12745c29d44c69c2538008cfbde7 100644
@@ -14,7 +14,7 @@ import chardet
 import rfc3986
 import rfc3986.exceptions
 
-from ._content_streams import ByteStream, ContentStream, encode
+from ._content_streams import ByteStream, ContentStream, encode, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
     ContentDecoder,
@@ -44,6 +44,7 @@ from ._types import (
     QueryParamTypes,
     RequestData,
     RequestFiles,
+    ResponseContent,
     URLTypes,
 )
 from ._utils import (
@@ -674,7 +675,7 @@ class Response:
         http_version: str = None,
         headers: HeaderTypes = None,
         stream: ContentStream = None,
-        content: bytes = None,
+        content: ResponseContent = None,
         history: typing.List["Response"] = None,
         elapsed_func: typing.Callable = None,
     ):
@@ -694,8 +695,10 @@ class Response:
         if stream is not None:
             self._raw_stream = stream
         else:
-            self._raw_stream = ByteStream(body=content or b"")
-            self.read()
+            self._raw_stream = encode_response(content)
+            if content is None or isinstance(content, bytes):
+                # Load the response body, except for streaming content.
+                self.read()
 
         self._num_bytes_downloaded = 0
 
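
With this change the constructor only loads the body eagerly for `None` or `bytes` content; iterator content stays unread until the caller consumes it. A hedged sketch of the resulting behaviour, in line with the tests in this commit:

    import httpx

    eager = httpx.Response(200, content=b"Hello, world!")
    assert eager.content == b"Hello, world!"                  # read during __init__

    lazy = httpx.Response(200, content=iter([b"Hello, ", b"world!"]))
    assert lazy.read() == b"Hello, world!"                    # only consumed on demand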
diff --git a/httpx/_types.py b/httpx/_types.py
index 3a90ee42e7a3a7d5bd962a1ad7e583310d51700f..8989b2826c7172228b6d6ea01aa7df79e223d137 100644
@@ -63,6 +63,8 @@ AuthTypes = Union[
     None,
 ]
 
+ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]
+
 RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]]
 
 FileContent = Union[IO[str], IO[bytes], str, bytes]
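
`ResponseContent` simply names the union of accepted body forms, so callers can annotate pass-through helpers. A hypothetical example (the `as_response` wrapper is not part of this change):

    import httpx
    from httpx._types import ResponseContent

    def as_response(content: ResponseContent = None) -> httpx.Response:
        # Illustrative only: forwards any accepted content form to the constructor.
        return httpx.Response(200, content=content)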
diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py
index 30d600086ac86419ba61d5f02647d2bd4e99e05b..b52e4846f3675814f46455000d1d78f8e3d5901f 100644
@@ -5,7 +5,6 @@ import brotli
 import pytest
 
 import httpx
-from httpx._content_streams import AsyncIteratorStream, IteratorStream
 
 
 def streaming_body():
@@ -215,10 +214,9 @@ async def test_aread():
 
 
 def test_iter_raw():
-    stream = IteratorStream(iterator=streaming_body())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=streaming_body(),
     )
 
     raw = b""
@@ -228,12 +226,7 @@ def test_iter_raw():
 
 
 def test_iter_raw_increments_updates_counter():
-    stream = IteratorStream(iterator=streaming_body())
-
-    response = httpx.Response(
-        200,
-        stream=stream,
-    )
+    response = httpx.Response(200, content=streaming_body())
 
     num_downloaded = response.num_bytes_downloaded
     for part in response.iter_raw():
@@ -243,11 +236,7 @@ def test_iter_raw_increments_updates_counter():
 
 @pytest.mark.asyncio
 async def test_aiter_raw():
-    stream = AsyncIteratorStream(aiterator=async_streaming_body())
-    response = httpx.Response(
-        200,
-        stream=stream,
-    )
+    response = httpx.Response(200, content=async_streaming_body())
 
     raw = b""
     async for part in response.aiter_raw():
@@ -257,12 +246,7 @@ async def test_aiter_raw():
 
 @pytest.mark.asyncio
 async def test_aiter_raw_increments_updates_counter():
-    stream = AsyncIteratorStream(aiterator=async_streaming_body())
-
-    response = httpx.Response(
-        200,
-        stream=stream,
-    )
+    response = httpx.Response(200, content=async_streaming_body())
 
     num_downloaded = response.num_bytes_downloaded
     async for part in response.aiter_raw():
@@ -346,10 +330,9 @@ async def test_aiter_lines():
 
 
 def test_sync_streaming_response():
-    stream = IteratorStream(iterator=streaming_body())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=streaming_body(),
     )
 
     assert response.status_code == 200
@@ -364,10 +347,9 @@ def test_sync_streaming_response():
 
 @pytest.mark.asyncio
 async def test_async_streaming_response():
-    stream = AsyncIteratorStream(aiterator=async_streaming_body())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=async_streaming_body(),
     )
 
     assert response.status_code == 200
@@ -381,10 +363,9 @@ async def test_async_streaming_response():
 
 
 def test_cannot_read_after_stream_consumed():
-    stream = IteratorStream(iterator=streaming_body())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=streaming_body(),
     )
 
     content = b""
@@ -397,10 +378,9 @@ def test_cannot_read_after_stream_consumed():
 
 @pytest.mark.asyncio
 async def test_cannot_aread_after_stream_consumed():
-    stream = AsyncIteratorStream(aiterator=async_streaming_body())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=async_streaming_body(),
     )
 
     content = b""
@@ -412,54 +392,33 @@ async def test_cannot_aread_after_stream_consumed():
 
 
 def test_cannot_read_after_response_closed():
-    is_closed = False
-
-    def close_func():
-        nonlocal is_closed
-        is_closed = True
-
-    stream = IteratorStream(iterator=streaming_body(), close_func=close_func)
     response = httpx.Response(
         200,
-        stream=stream,
+        content=streaming_body(),
     )
 
     response.close()
-    assert is_closed
-
     with pytest.raises(httpx.ResponseClosed):
         response.read()
 
 
 @pytest.mark.asyncio
 async def test_cannot_aread_after_response_closed():
-    is_closed = False
-
-    async def close_func():
-        nonlocal is_closed
-        is_closed = True
-
-    stream = AsyncIteratorStream(
-        aiterator=async_streaming_body(), close_func=close_func
-    )
     response = httpx.Response(
         200,
-        stream=stream,
+        content=async_streaming_body(),
     )
 
     await response.aclose()
-    assert is_closed
-
     with pytest.raises(httpx.ResponseClosed):
         await response.aread()
 
 
 @pytest.mark.asyncio
 async def test_elapsed_not_available_until_closed():
-    stream = AsyncIteratorStream(aiterator=async_streaming_body())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=async_streaming_body(),
     )
 
     with pytest.raises(RuntimeError):
diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py
index 140aa8d2af4f4a36b786f03c0b7a12bf777f8705..2d1de1f1c05fe9a0d4fd608de231bd085fd42a02 100644
@@ -3,7 +3,7 @@ import io
 import pytest
 
 from httpx import StreamConsumed
-from httpx._content_streams import ContentStream, encode
+from httpx._content_streams import ContentStream, encode, encode_response
 
 
 @pytest.mark.asyncio
@@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content():
             b"--+++--\r\n",
         ]
     )
+
+
+@pytest.mark.asyncio
+async def test_response_empty_content():
+    stream = encode_response()
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
+
+    assert stream.can_replay()
+    assert stream.get_headers() == {}
+    assert sync_content == b""
+    assert async_content == b""
+
+
+@pytest.mark.asyncio
+async def test_response_bytes_content():
+    stream = encode_response(content=b"Hello, world!")
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
+
+    assert stream.can_replay()
+    assert stream.get_headers() == {"Content-Length": "13"}
+    assert sync_content == b"Hello, world!"
+    assert async_content == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_response_iterator_content():
+    def hello_world():
+        yield b"Hello, "
+        yield b"world!"
+
+    stream = encode_response(content=hello_world())
+    content = b"".join([part for part in stream])
+
+    assert not stream.can_replay()
+    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert content == b"Hello, world!"
+
+    with pytest.raises(RuntimeError):
+        [part async for part in stream]
+
+    with pytest.raises(StreamConsumed):
+        [part for part in stream]
+
+
+@pytest.mark.asyncio
+async def test_response_aiterator_content():
+    async def hello_world():
+        yield b"Hello, "
+        yield b"world!"
+
+    stream = encode_response(content=hello_world())
+    content = b"".join([part async for part in stream])
+
+    assert not stream.can_replay()
+    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert content == b"Hello, world!"
+
+    with pytest.raises(RuntimeError):
+        [part for part in stream]
+
+    with pytest.raises(StreamConsumed):
+        [part async for part in stream]
+
+
+def test_response_invalid_argument():
+    with pytest.raises(TypeError):
+        encode_response(123)  # type: ignore
diff --git a/tests/test_decoders.py b/tests/test_decoders.py
index dbbaac5450877797acde1ef3ee597445dd34e58d..7dfca9ef50e5151b40b0be0a2e15c61b138db625 100644
@@ -4,7 +4,6 @@ import brotli
 import pytest
 
 import httpx
-from httpx._content_streams import AsyncIteratorStream
 from httpx._decoders import (
     BrotliDecoder,
     DeflateDecoder,
@@ -130,11 +129,10 @@ async def test_streaming():
         yield compressor.flush()
 
     headers = [(b"Content-Encoding", b"gzip")]
-    stream = AsyncIteratorStream(aiterator=compress(body))
     response = httpx.Response(
         200,
         headers=headers,
-        stream=stream,
+        content=compress(body),
     )
     assert not hasattr(response, "body")
     assert await response.aread() == body
@@ -199,19 +197,17 @@ async def test_text_decoder(data, encoding):
             yield chunk
 
     # Accessing `.text` on a read response.
-    stream = AsyncIteratorStream(aiterator=iterator())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=iterator(),
     )
     await response.aread()
     assert response.text == (b"".join(data)).decode(encoding)
 
     # Streaming `.aiter_text` iteratively.
-    stream = AsyncIteratorStream(aiterator=iterator())
     response = httpx.Response(
         200,
-        stream=stream,
+        content=iterator(),
     )
     text = "".join([part async for part in response.aiter_text()])
     assert text == (b"".join(data)).decode(encoding)
@@ -224,11 +220,10 @@ async def test_text_decoder_known_encoding():
         yield b"\x83"
         yield b"\x89\x83x\x83\x8b"
 
-    stream = AsyncIteratorStream(aiterator=iterator())
     response = httpx.Response(
         200,
         headers=[(b"Content-Type", b"text/html; charset=shift-jis")],
-        stream=stream,
+        content=iterator(),
     )
 
     await response.aread()