From: Gabriel Strauss Date: Tue, 31 Dec 2019 12:01:43 +0000 (-0500) Subject: Adds check to enforce single consumption of AsyncIteratorStream. (#697) X-Git-Tag: 0.11.0~32 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=de8b95533dd09a5014c3c024667737ef1a8bb2e1;p=thirdparty%2Fhttpx.git Adds check to enforce single consumption of AsyncIteratorStream. (#697) --- diff --git a/httpx/content_streams.py b/httpx/content_streams.py index 62d150b4..73ef59ef 100644 --- a/httpx/content_streams.py +++ b/httpx/content_streams.py @@ -7,6 +7,7 @@ from json import dumps as json_dumps from pathlib import Path from urllib.parse import urlencode +from .exceptions import StreamConsumed from .utils import format_form_param RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]] @@ -81,6 +82,7 @@ class AsyncIteratorStream(ContentStream): ) -> None: self.aiterator = aiterator self.close_func = close_func + self.is_stream_consumed = False def can_replay(self) -> bool: return False @@ -89,6 +91,9 @@ class AsyncIteratorStream(ContentStream): return {"Transfer-Encoding": "chunked"} async def __aiter__(self) -> typing.AsyncIterator[bytes]: + if self.is_stream_consumed: + raise StreamConsumed() + self.is_stream_consumed = True async for part in self.aiterator: yield part diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index 64146ed9..764e00ed 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -3,6 +3,7 @@ import io import pytest from httpx.content_streams import encode +from httpx.exceptions import StreamConsumed @pytest.mark.asyncio @@ -39,6 +40,21 @@ async def test_aiterator_content(): assert content == b"Hello, world!" +@pytest.mark.asyncio +async def test_aiterator_is_stream_consumed(): + async def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode(data=hello_world()) + b"".join([part async for part in stream]) + + assert stream.is_stream_consumed + + with pytest.raises(StreamConsumed) as _: + b"".join([part async for part in stream]) + + @pytest.mark.asyncio async def test_json_content(): stream = encode(json={"Hello": "world!"})