From: florimondmanca Date: Sun, 12 Feb 2023 23:25:23 +0000 (+0100) Subject: Fix unclosed generator on trio X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5a3603726a1b77068b06f9c4e5328e6ba7fa42c1;p=thirdparty%2Fhttpx.git Fix unclosed generator on trio --- diff --git a/httpx/_client.py b/httpx/_client.py index 1f9f3beb..5702f490 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -142,9 +142,8 @@ class BoundAsyncStream(AsyncByteStream): self._response = response self._timer = timer - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - async for chunk in self._stream: - yield chunk + def __aiter__(self) -> typing.AsyncIterator[bytes]: + return self._stream.__aiter__() async def aclose(self) -> None: seconds = await self._timer.async_elapsed() diff --git a/httpx/_models.py b/httpx/_models.py index e0e5278c..f3cc0061 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -3,6 +3,7 @@ import email.message import json as jsonlib import typing import urllib.request +from contextlib import aclosing from collections.abc import Mapping from http.cookiejar import Cookie, CookieJar @@ -911,7 +912,7 @@ class Response: async def aiter_bytes( self, chunk_size: typing.Optional[int] = None - ) -> typing.AsyncIterator[bytes]: + ) -> typing.AsyncGenerator[bytes, None]: """ A byte-iterator over the decoded response content. This allows us to handle gzip, deflate, and brotli encoded responses. @@ -924,19 +925,20 @@ class Response: decoder = self._get_content_decoder() chunker = ByteChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for raw_bytes in self.aiter_raw(): - decoded = decoder.decode(raw_bytes) + async with aclosing(self.aiter_raw()) as stream: + async for raw_bytes in stream: + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): yield chunk - decoded = decoder.flush() - for chunk in chunker.decode(decoded): - yield chunk # pragma: no cover - for chunk in chunker.flush(): - yield chunk async def aiter_text( self, chunk_size: typing.Optional[int] = None - ) -> typing.AsyncIterator[str]: + ) -> typing.AsyncGenerator[str, None]: """ A str-iterator over the decoded response content that handles both gzip, deflate, etc but also detects the content's @@ -945,28 +947,30 @@ class Response: decoder = TextDecoder(encoding=self.encoding or "utf-8") chunker = TextChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for byte_content in self.aiter_bytes(): - text_content = decoder.decode(byte_content) + async with aclosing(self.aiter_bytes()) as stream: + async for byte_content in stream: + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() for chunk in chunker.decode(text_content): yield chunk - text_content = decoder.flush() - for chunk in chunker.decode(text_content): - yield chunk - for chunk in chunker.flush(): - yield chunk + for chunk in chunker.flush(): + yield chunk - async def aiter_lines(self) -> typing.AsyncIterator[str]: + async def aiter_lines(self) -> typing.AsyncGenerator[str, None]: decoder = LineDecoder() with request_context(request=self._request): - async for text in self.aiter_text(): - for line in decoder.decode(text): + async with aclosing(self.aiter_text()) as stream: + async for text in stream: + for line in decoder.decode(text): + yield line + for line in decoder.flush(): yield line - for line in decoder.flush(): - yield line async def aiter_raw( self, chunk_size: typing.Optional[int] = None - ) -> typing.AsyncIterator[bytes]: + ) -> typing.AsyncGenerator[bytes, None]: """ A byte-iterator over the raw response content. """ diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index dfd274e7..888e2821 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -232,12 +232,14 @@ class HTTPTransport(BaseTransport): class AsyncResponseStream(AsyncByteStream): def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]): - self._httpcore_stream = httpcore_stream + self._httpcore_stream = httpcore_stream.__aiter__() + + def __aiter__(self) -> typing.AsyncIterator[bytes]: + return self - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __anext__(self) -> bytes: with map_httpcore_exceptions(): - async for part in self._httpcore_stream: - yield part + return await self._httpcore_stream.__anext__() async def aclose(self) -> None: if hasattr(self._httpcore_stream, "aclose"): diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 5be0de3b..2c681b55 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,4 +1,5 @@ import typing +from contextlib import aclosing from datetime import timedelta import pytest @@ -76,6 +77,34 @@ async def test_stream_response(server): assert response.content == b"Hello, world!" +@pytest.mark.anyio +async def test_stream_iterator(server): + body = b"" + + async with httpx.AsyncClient() as client: + async with client.stream("GET", server.url) as response: + async for chunk in response.aiter_bytes(): + body += chunk + + assert response.status_code == 200 + assert body == b"Hello, world!" + + +@pytest.mark.anyio +async def test_stream_iterator_partial(server): + body = "" + + async with httpx.AsyncClient() as client: + async with client.stream("GET", server.url) as response: + async with aclosing(response.aiter_text(5)) as stream: + async for chunk in stream: + body += chunk + break + + assert response.status_code == 200 + assert body == "Hello" + + @pytest.mark.anyio async def test_access_content_stream_response(server): async with httpx.AsyncClient() as client: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 268cd106..c35725ec 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -107,6 +107,19 @@ def test_stream_iterator(server): assert body == b"Hello, world!" +def test_stream_iterator_partial(server): + body = "" + + with httpx.Client() as client: + with client.stream("GET", server.url) as response: + for chunk in response.iter_text(5): + body += chunk + break + + assert response.status_code == 200 + assert body == "Hello" + + def test_raw_iterator(server): body = b""