]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Adds check to enforce single consumption of AsyncIteratorStream. (#697)
authorGabriel Strauss <gstrauss5@gmail.com>
Tue, 31 Dec 2019 12:01:43 +0000 (07:01 -0500)
committerTom Christie <tom@tomchristie.com>
Tue, 31 Dec 2019 12:01:43 +0000 (12:01 +0000)
httpx/content_streams.py
tests/test_content_streams.py

index 62d150b4c127b0d787084c468608340b4c8f3de7..73ef59ef0b0f7c8e5932f23f28ce05f53b001269 100644 (file)
@@ -7,6 +7,7 @@ from json import dumps as json_dumps
 from pathlib import Path
 from urllib.parse import urlencode
 
+from .exceptions import StreamConsumed
 from .utils import format_form_param
 
 RequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
@@ -81,6 +82,7 @@ class AsyncIteratorStream(ContentStream):
     ) -> None:
         self.aiterator = aiterator
         self.close_func = close_func
+        self.is_stream_consumed = False
 
     def can_replay(self) -> bool:
         return False
@@ -89,6 +91,9 @@ class AsyncIteratorStream(ContentStream):
         return {"Transfer-Encoding": "chunked"}
 
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        if self.is_stream_consumed:
+            raise StreamConsumed()
+        self.is_stream_consumed = True
         async for part in self.aiterator:
             yield part
 
index 64146ed941f535a8b84f4b3177feafe73bf4e33e..764e00ed7a8a598efa483f9d3bf33d256de5d9b6 100644 (file)
@@ -3,6 +3,7 @@ import io
 import pytest
 
 from httpx.content_streams import encode
+from httpx.exceptions import StreamConsumed
 
 
 @pytest.mark.asyncio
@@ -39,6 +40,21 @@ async def test_aiterator_content():
     assert content == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_aiterator_is_stream_consumed():
+    async def hello_world():
+        yield b"Hello, "
+        yield b"world!"
+
+    stream = encode(data=hello_world())
+    b"".join([part async for part in stream])
+
+    assert stream.is_stream_consumed
+
+    with pytest.raises(StreamConsumed) as _:
+        b"".join([part async for part in stream])
+
+
 @pytest.mark.asyncio
 async def test_json_content():
     stream = encode(json={"Hello": "world!"})