From: Tom Christie Date: Thu, 2 Jan 2020 12:56:11 +0000 (+0000) Subject: Sync streaming interface on responses (#695) X-Git-Tag: 0.11.0~22 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=11e7604d1ad0a59907245fa898b74da36cbaf77c;p=thirdparty%2Fhttpx.git Sync streaming interface on responses (#695) * Sync streaming interface on responses * Fix test case * Test coverage for sync response APIs * Address review comments --- diff --git a/httpx/content_streams.py b/httpx/content_streams.py index 73ef59ef..e3e343cf 100644 --- a/httpx/content_streams.py +++ b/httpx/content_streams.py @@ -10,7 +10,9 @@ from urllib.parse import urlencode from .exceptions import StreamConsumed from .utils import format_form_param -RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]] +RequestData = typing.Union[ + dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes] +] RequestFiles = typing.Dict[ str, @@ -47,6 +49,12 @@ class ContentStream: """ return True + def __iter__(self) -> typing.Iterator[bytes]: + yield b"" + + def close(self) -> None: + pass + async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield b"" @@ -68,10 +76,46 @@ class ByteStream(ContentStream): content_length = str(len(self.body)) return {"Content-Length": content_length} + def __iter__(self) -> typing.Iterator[bytes]: + yield self.body + async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield self.body +class IteratorStream(ContentStream): + """ + Request content encoded as plain bytes, using an byte iterator. + """ + + def __init__( + self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None + ) -> None: + self.iterator = iterator + self.close_func = close_func + self.is_stream_consumed = False + + def can_replay(self) -> bool: + return False + + def get_headers(self) -> typing.Dict[str, str]: + return {"Transfer-Encoding": "chunked"} + + def __iter__(self) -> typing.Iterator[bytes]: + if self.is_stream_consumed: + raise StreamConsumed() + self.is_stream_consumed = True + for part in self.iterator: + yield part + + def __aiter__(self) -> typing.AsyncIterator[bytes]: + raise RuntimeError("Attempted to call a async iterator on an sync stream.") + + def close(self) -> None: + if self.close_func is not None: + self.close_func() + + class AsyncIteratorStream(ContentStream): """ Request content encoded as plain bytes, using an async byte iterator. @@ -90,6 +134,9 @@ class AsyncIteratorStream(ContentStream): def get_headers(self) -> typing.Dict[str, str]: return {"Transfer-Encoding": "chunked"} + def __iter__(self) -> typing.Iterator[bytes]: + raise RuntimeError("Attempted to call a sync iterator on an async stream.") + async def __aiter__(self) -> typing.AsyncIterator[bytes]: if self.is_stream_consumed: raise StreamConsumed() @@ -115,6 +162,9 @@ class JSONStream(ContentStream): content_type = "application/json" return {"Content-Length": content_length, "Content-Type": content_type} + def __iter__(self) -> typing.Iterator[bytes]: + yield self.body + async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield self.body @@ -132,6 +182,9 @@ class URLEncodedStream(ContentStream): content_type = "application/x-www-form-urlencoded" return {"Content-Length": content_length, "Content-Type": content_type} + def __iter__(self) -> typing.Iterator[bytes]: + yield self.body + async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield self.body @@ -252,6 +305,9 @@ class MultipartStream(ContentStream): content_type = self.content_type return {"Content-Length": content_length, "Content-Type": content_type} + def __iter__(self) -> typing.Iterator[bytes]: + yield self.body + async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield self.body @@ -280,5 +336,11 @@ def encode( return URLEncodedStream(data=data) elif isinstance(data, (str, bytes)): return ByteStream(body=data) - else: + elif hasattr(data, "__aiter__"): + data = typing.cast(typing.AsyncIterator[bytes], data) return AsyncIteratorStream(aiterator=data) + elif hasattr(data, "__iter__"): + data = typing.cast(typing.Iterator[bytes], data) + return IteratorStream(iterator=data) + + raise TypeError(f"Unexpected type for 'data', {type(data)!r}") diff --git a/httpx/models.py b/httpx/models.py index 91cf5a5e..dc98b302 100644 --- a/httpx/models.py +++ b/httpx/models.py @@ -13,7 +13,13 @@ import chardet import rfc3986 from .config import USER_AGENT -from .content_streams import ContentStream, RequestData, RequestFiles, encode +from .content_streams import ( + ByteStream, + ContentStream, + RequestData, + RequestFiles, + encode, +) from .decoders import ( ACCEPT_ENCODING, SUPPORTED_DECODERS, @@ -665,15 +671,13 @@ class Response: self.history = [] if history is None else list(history) - if stream is None: - self.is_closed = True - self.is_stream_consumed = True - self._raw_content = content or b"" - self._elapsed = request.timer.elapsed - else: - self.is_closed = False - self.is_stream_consumed = False + self.is_closed = False + self.is_stream_consumed = False + if stream is not None: self._raw_stream = stream + else: + self._raw_stream = ByteStream(body=content or b"") + self.read() @property def elapsed(self) -> datetime.timedelta: @@ -702,13 +706,7 @@ class Response: @property def content(self) -> bytes: if not hasattr(self, "_content"): - if hasattr(self, "_raw_content"): - raw_content = self._raw_content # type: ignore - content = self.decoder.decode(raw_content) - content += self.decoder.flush() - self._content = content - else: - raise ResponseNotRead() + raise ResponseNotRead() return self._content @property @@ -850,14 +848,6 @@ class Response: def __repr__(self) -> str: return f"" - async def aread(self) -> bytes: - """ - Read and return the response content. - """ - if not hasattr(self, "_content"): - self._content = b"".join([part async for part in self.aiter_bytes()]) - return self._content - @property def stream(self): # type: ignore warnings.warn( # pragma: nocover @@ -874,6 +864,78 @@ class Response: ) return self.aiter_raw # pragma: nocover + def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join([part for part in self.iter_bytes()]) + return self._content + + def iter_bytes(self) -> typing.Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "_content"): + yield self._content + else: + for chunk in self.iter_raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() + + def iter_text(self) -> typing.Iterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + decoder = TextDecoder(encoding=self.charset_encoding) + for chunk in self.iter_bytes(): + yield decoder.decode(chunk) + yield decoder.flush() + + def iter_lines(self) -> typing.Iterator[str]: + decoder = LineDecoder() + for text in self.iter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + def iter_raw(self) -> typing.Iterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() + + self.is_stream_consumed = True + for part in self._raw_stream: + yield part + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not self.is_closed: + self.is_closed = True + self._elapsed = self.request.timer.elapsed + if hasattr(self, "_raw_stream"): + self._raw_stream.close() + + async def aread(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join([part async for part in self.aiter_bytes()]) + return self._content + async def aiter_bytes(self) -> typing.AsyncIterator[bytes]: """ A byte-iterator over the decoded response content. @@ -909,18 +971,15 @@ class Response: """ A byte-iterator over the raw response content. """ - if hasattr(self, "_raw_content"): - yield self._raw_content - else: - if self.is_stream_consumed: - raise StreamConsumed() - if self.is_closed: - raise ResponseClosed() - - self.is_stream_consumed = True - async for part in self._raw_stream: - yield part - await self.aclose() + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() + + self.is_stream_consumed = True + async for part in self._raw_stream: + yield part + await self.aclose() async def anext(self) -> "Response": """ diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index d7d519cc..80cbfb08 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -5,7 +5,7 @@ from unittest import mock import pytest import httpx -from httpx.content_streams import AsyncIteratorStream +from httpx.content_streams import AsyncIteratorStream, IteratorStream REQUEST = httpx.Request("GET", "https://example.org") @@ -124,8 +124,23 @@ def test_response_force_encoding(): assert response.encoding == "iso-8859-1" +def test_read(): + response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) + + assert response.status_code == 200 + assert response.text == "Hello, world!" + assert response.encoding == "ascii" + assert response.is_closed + + content = response.read() + + assert content == b"Hello, world!" + assert response.content == b"Hello, world!" + assert response.is_closed + + @pytest.mark.asyncio -async def test_read_response(): +async def test_aread(): response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) assert response.status_code == 200 @@ -140,9 +155,20 @@ async def test_read_response(): assert response.is_closed +def test_iter_raw(): + stream = IteratorStream(iterator=streaming_body()) + response = httpx.Response(200, stream=stream, request=REQUEST) + + raw = b"" + for part in response.iter_raw(): + raw += part + assert raw == b"Hello, world!" + + @pytest.mark.asyncio -async def test_raw_interface(): - response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) +async def test_aiter_raw(): + stream = AsyncIteratorStream(aiterator=async_streaming_body()) + response = httpx.Response(200, stream=stream, request=REQUEST) raw = b"" async for part in response.aiter_raw(): @@ -150,8 +176,17 @@ async def test_raw_interface(): assert raw == b"Hello, world!" +def test_iter_bytes(): + response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) + + content = b"" + for part in response.iter_bytes(): + content += part + assert content == b"Hello, world!" + + @pytest.mark.asyncio -async def test_bytes_interface(): +async def test_aiter_bytes(): response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) content = b"" @@ -160,11 +195,18 @@ async def test_bytes_interface(): assert content == b"Hello, world!" -@pytest.mark.asyncio -async def test_text_interface(): +def test_iter_text(): response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) - await response.aread() + content = "" + for part in response.iter_text(): + content += part + assert content == "Hello, world!" + + +@pytest.mark.asyncio +async def test_aiter_text(): + response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) content = "" async for part in response.aiter_text(): @@ -172,11 +214,18 @@ async def test_text_interface(): assert content == "Hello, world!" -@pytest.mark.asyncio -async def test_lines_interface(): +def test_iter_lines(): response = httpx.Response(200, content=b"Hello,\nworld!", request=REQUEST) - await response.aread() + content = [] + for line in response.iter_lines(): + content.append(line) + assert content == ["Hello,\n", "world!"] + + +@pytest.mark.asyncio +async def test_aiter_lines(): + response = httpx.Response(200, content=b"Hello,\nworld!", request=REQUEST) content = [] async for line in response.aiter_lines(): @@ -184,20 +233,22 @@ async def test_lines_interface(): assert content == ["Hello,\n", "world!"] -@pytest.mark.asyncio -async def test_stream_interface_after_read(): - response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) +def test_sync_streaming_response(): + stream = IteratorStream(iterator=streaming_body()) + response = httpx.Response(200, stream=stream, request=REQUEST) - await response.aread() + assert response.status_code == 200 + assert not response.is_closed + + content = response.read() - content = b"" - async for part in response.aiter_bytes(): - content += part assert content == b"Hello, world!" + assert response.content == b"Hello, world!" + assert response.is_closed @pytest.mark.asyncio -async def test_streaming_response(): +async def test_async_streaming_response(): stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response(200, stream=stream, request=REQUEST) @@ -211,8 +262,20 @@ async def test_streaming_response(): assert response.is_closed +def test_cannot_read_after_stream_consumed(): + stream = IteratorStream(iterator=streaming_body()) + response = httpx.Response(200, stream=stream, request=REQUEST) + + content = b"" + for part in response.iter_bytes(): + content += part + + with pytest.raises(httpx.StreamConsumed): + response.read() + + @pytest.mark.asyncio -async def test_cannot_read_after_stream_consumed(): +async def test_cannot_aread_after_stream_consumed(): stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response(200, stream=stream, request=REQUEST) @@ -224,12 +287,38 @@ async def test_cannot_read_after_stream_consumed(): await response.aread() +def test_cannot_read_after_response_closed(): + is_closed = False + + def close_func(): + nonlocal is_closed + is_closed = True + + stream = IteratorStream(iterator=streaming_body(), close_func=close_func) + response = httpx.Response(200, stream=stream, request=REQUEST) + + response.close() + assert is_closed + + with pytest.raises(httpx.ResponseClosed): + response.read() + + @pytest.mark.asyncio -async def test_cannot_read_after_response_closed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) +async def test_cannot_aread_after_response_closed(): + is_closed = False + + async def close_func(): + nonlocal is_closed + is_closed = True + + stream = AsyncIteratorStream( + aiterator=async_streaming_body(), close_func=close_func + ) response = httpx.Response(200, stream=stream, request=REQUEST) await response.aclose() + assert is_closed with pytest.raises(httpx.ResponseClosed): await response.aread() diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index 764e00ed..93baf8e7 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -2,97 +2,139 @@ import io import pytest -from httpx.content_streams import encode +from httpx.content_streams import ContentStream, encode from httpx.exceptions import StreamConsumed +@pytest.mark.asyncio +async def test_base_content(): + stream = ContentStream() + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == {} + assert sync_content == b"" + assert async_content == b"" + + @pytest.mark.asyncio async def test_empty_content(): stream = encode() - content = b"".join([part async for part in stream]) + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) assert stream.can_replay() assert stream.get_headers() == {} - assert content == b"" + assert sync_content == b"" + assert async_content == b"" @pytest.mark.asyncio async def test_bytes_content(): stream = encode(data=b"Hello, world!") - content = b"".join([part async for part in stream]) + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) assert stream.can_replay() assert stream.get_headers() == {"Content-Length": "13"} - assert content == b"Hello, world!" + assert sync_content == b"Hello, world!" + assert async_content == b"Hello, world!" @pytest.mark.asyncio -async def test_aiterator_content(): - async def hello_world(): +async def test_iterator_content(): + def hello_world(): yield b"Hello, " yield b"world!" stream = encode(data=hello_world()) - content = b"".join([part async for part in stream]) + content = b"".join([part for part in stream]) assert not stream.can_replay() assert stream.get_headers() == {"Transfer-Encoding": "chunked"} assert content == b"Hello, world!" + with pytest.raises(RuntimeError): + [part async for part in stream] + + with pytest.raises(StreamConsumed): + [part for part in stream] + @pytest.mark.asyncio -async def test_aiterator_is_stream_consumed(): +async def test_aiterator_content(): async def hello_world(): yield b"Hello, " yield b"world!" stream = encode(data=hello_world()) - b"".join([part async for part in stream]) + content = b"".join([part async for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" - assert stream.is_stream_consumed + with pytest.raises(RuntimeError): + [part for part in stream] - with pytest.raises(StreamConsumed) as _: - b"".join([part async for part in stream]) + with pytest.raises(StreamConsumed): + [part async for part in stream] @pytest.mark.asyncio async def test_json_content(): stream = encode(json={"Hello": "world!"}) - content = b"".join([part async for part in stream]) + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) assert stream.can_replay() assert stream.get_headers() == { "Content-Length": "19", "Content-Type": "application/json", } - assert content == b'{"Hello": "world!"}' + assert sync_content == b'{"Hello": "world!"}' + assert async_content == b'{"Hello": "world!"}' @pytest.mark.asyncio async def test_urlencoded_content(): stream = encode(data={"Hello": "world!"}) - content = b"".join([part async for part in stream]) + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) assert stream.can_replay() assert stream.get_headers() == { "Content-Length": "14", "Content-Type": "application/x-www-form-urlencoded", } - assert content == b"Hello=world%21" + assert sync_content == b"Hello=world%21" + assert async_content == b"Hello=world%21" @pytest.mark.asyncio async def test_multipart_files_content(): files = {"file": io.BytesIO(b"")} stream = encode(files=files, boundary=b"+++") - content = b"".join([part async for part in stream]) + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) assert stream.can_replay() assert stream.get_headers() == { "Content-Length": "138", "Content-Type": "multipart/form-data; boundary=+++", } - assert content == b"".join( + assert sync_content == b"".join( + [ + b"--+++\r\n", + b'Content-Disposition: form-data; name="file"; filename="upload"\r\n', + b"Content-Type: application/octet-stream\r\n", + b"\r\n", + b"\r\n", + b"--+++--\r\n", + ] + ) + assert async_content == b"".join( [ b"--+++\r\n", b'Content-Disposition: form-data; name="file"; filename="upload"\r\n', @@ -109,14 +151,29 @@ async def test_multipart_data_and_files_content(): data = {"message": "Hello, world!"} files = {"file": io.BytesIO(b"")} stream = encode(data=data, files=files, boundary=b"+++") - content = b"".join([part async for part in stream]) + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) assert stream.can_replay() assert stream.get_headers() == { "Content-Length": "210", "Content-Type": "multipart/form-data; boundary=+++", } - assert content == b"".join( + assert sync_content == b"".join( + [ + b"--+++\r\n", + b'Content-Disposition: form-data; name="message"\r\n', + b"\r\n", + b"Hello, world!\r\n", + b"--+++\r\n", + b'Content-Disposition: form-data; name="file"; filename="upload"\r\n', + b"Content-Type: application/octet-stream\r\n", + b"\r\n", + b"\r\n", + b"--+++--\r\n", + ] + ) + assert async_content == b"".join( [ b"--+++\r\n", b'Content-Disposition: form-data; name="message"\r\n', @@ -130,3 +187,8 @@ async def test_multipart_data_and_files_content(): b"--+++--\r\n", ] ) + + +def test_invalid_argument(): + with pytest.raises(TypeError): + encode(123)