git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Fix unclosed generator on trio
author: florimondmanca <florimond.manca@protonmail.com>
Sun, 12 Feb 2023 23:25:23 +0000 (00:25 +0100)
committer: florimondmanca <florimond.manca@protonmail.com>
Sun, 12 Feb 2023 23:25:23 +0000 (00:25 +0100)
httpx/_client.py
httpx/_models.py
httpx/_transports/default.py
tests/client/test_async_client.py
tests/client/test_client.py

index 1f9f3beb56aae687a9c2fd6d70f450fc6483bb99..5702f490b1b500a574d70db9c010be8f4ec10ab5 100644 (file)
@@ -142,9 +142,8 @@ class BoundAsyncStream(AsyncByteStream):
         self._response = response
         self._timer = timer
 
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        async for chunk in self._stream:
-            yield chunk
+    def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        return self._stream.__aiter__()
 
     async def aclose(self) -> None:
         seconds = await self._timer.async_elapsed()
index e0e5278cc052e2f9a6d0af0a1cb2107b03de98f4..f3cc00616a8d9715fc63b2d44222bfed2b6dc614 100644 (file)
@@ -3,6 +3,7 @@ import email.message
 import json as jsonlib
 import typing
 import urllib.request
+from contextlib import aclosing
 from collections.abc import Mapping
 from http.cookiejar import Cookie, CookieJar
 
@@ -911,7 +912,7 @@ class Response:
 
     async def aiter_bytes(
         self, chunk_size: typing.Optional[int] = None
-    ) -> typing.AsyncIterator[bytes]:
+    ) -> typing.AsyncGenerator[bytes, None]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
@@ -924,19 +925,20 @@ class Response:
             decoder = self._get_content_decoder()
             chunker = ByteChunker(chunk_size=chunk_size)
             with request_context(request=self._request):
-                async for raw_bytes in self.aiter_raw():
-                    decoded = decoder.decode(raw_bytes)
+                async with aclosing(self.aiter_raw()) as stream:
+                    async for raw_bytes in stream:
+                        decoded = decoder.decode(raw_bytes)
+                        for chunk in chunker.decode(decoded):
+                            yield chunk
+                    decoded = decoder.flush()
                     for chunk in chunker.decode(decoded):
+                        yield chunk  # pragma: no cover
+                    for chunk in chunker.flush():
                         yield chunk
-                decoded = decoder.flush()
-                for chunk in chunker.decode(decoded):
-                    yield chunk  # pragma: no cover
-                for chunk in chunker.flush():
-                    yield chunk
 
     async def aiter_text(
         self, chunk_size: typing.Optional[int] = None
-    ) -> typing.AsyncIterator[str]:
+    ) -> typing.AsyncGenerator[str, None]:
         """
         A str-iterator over the decoded response content
         that handles both gzip, deflate, etc but also detects the content's
@@ -945,28 +947,30 @@ class Response:
         decoder = TextDecoder(encoding=self.encoding or "utf-8")
         chunker = TextChunker(chunk_size=chunk_size)
         with request_context(request=self._request):
-            async for byte_content in self.aiter_bytes():
-                text_content = decoder.decode(byte_content)
+            async with aclosing(self.aiter_bytes()) as stream:
+                async for byte_content in stream:
+                    text_content = decoder.decode(byte_content)
+                    for chunk in chunker.decode(text_content):
+                        yield chunk
+                text_content = decoder.flush()
                 for chunk in chunker.decode(text_content):
                     yield chunk
-            text_content = decoder.flush()
-            for chunk in chunker.decode(text_content):
-                yield chunk
-            for chunk in chunker.flush():
-                yield chunk
+                for chunk in chunker.flush():
+                    yield chunk
 
-    async def aiter_lines(self) -> typing.AsyncIterator[str]:
+    async def aiter_lines(self) -> typing.AsyncGenerator[str, None]:
         decoder = LineDecoder()
         with request_context(request=self._request):
-            async for text in self.aiter_text():
-                for line in decoder.decode(text):
+            async with aclosing(self.aiter_text()) as stream:
+                async for text in stream:
+                    for line in decoder.decode(text):
+                        yield line
+                for line in decoder.flush():
                     yield line
-            for line in decoder.flush():
-                yield line
 
     async def aiter_raw(
         self, chunk_size: typing.Optional[int] = None
-    ) -> typing.AsyncIterator[bytes]:
+    ) -> typing.AsyncGenerator[bytes, None]:
         """
         A byte-iterator over the raw response content.
         """
index dfd274e7bf557f8f88839d5a468a4279fecb9b63..888e28219e4d8305117112b6d872d8860d170e4c 100644 (file)
@@ -232,12 +232,14 @@ class HTTPTransport(BaseTransport):
 
 class AsyncResponseStream(AsyncByteStream):
     def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
-        self._httpcore_stream = httpcore_stream
+        self._httpcore_stream = httpcore_stream.__aiter__()
+
+    def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        return self
 
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+    async def __anext__(self) -> bytes:
         with map_httpcore_exceptions():
-            async for part in self._httpcore_stream:
-                yield part
+            return await self._httpcore_stream.__anext__()
 
     async def aclose(self) -> None:
         if hasattr(self._httpcore_stream, "aclose"):
index 5be0de3b12d83ffdf3515165fa789d57134d8565..2c681b553bd14971dd63057f029304af6f7e5f71 100644 (file)
@@ -1,4 +1,5 @@
 import typing
+from contextlib import aclosing
 from datetime import timedelta
 
 import pytest
@@ -76,6 +77,34 @@ async def test_stream_response(server):
     assert response.content == b"Hello, world!"
 
 
+@pytest.mark.anyio
+async def test_stream_iterator(server):
+    body = b""
+
+    async with httpx.AsyncClient() as client:
+        async with client.stream("GET", server.url) as response:
+            async for chunk in response.aiter_bytes():
+                body += chunk
+
+    assert response.status_code == 200
+    assert body == b"Hello, world!"
+
+
+@pytest.mark.anyio
+async def test_stream_iterator_partial(server):
+    body = ""
+
+    async with httpx.AsyncClient() as client:
+        async with client.stream("GET", server.url) as response:
+            async with aclosing(response.aiter_text(5)) as stream:
+                async for chunk in stream:
+                    body += chunk
+                    break
+
+    assert response.status_code == 200
+    assert body == "Hello"
+
+
 @pytest.mark.anyio
 async def test_access_content_stream_response(server):
     async with httpx.AsyncClient() as client:
index 268cd106899674883814d3650cc452931a2146e5..c35725ecfe35abd877c3d8a20d374cb81594a2d3 100644 (file)
@@ -107,6 +107,19 @@ def test_stream_iterator(server):
     assert body == b"Hello, world!"
 
 
+def test_stream_iterator_partial(server):
+    body = ""
+
+    with httpx.Client() as client:
+        with client.stream("GET", server.url) as response:
+            for chunk in response.iter_text(5):
+                body += chunk
+                break
+
+    assert response.status_code == 200
+    assert body == "Hello"
+
+
 def test_raw_iterator(server):
     body = b""