Stream interface (#1550)
author Tom Christie <tom@tomchristie.com>
Tue, 13 Apr 2021 12:14:04 +0000 (13:14 +0100)
committer GitHub <noreply@github.com>
Tue, 13 Apr 2021 12:14:04 +0000 (13:14 +0100)
* Add SyncByteStream, AsyncByteStream to interface

* request.stream and response.stream as httpx.SyncByteStream/httpx.AsyncByteStream

* Update httpx/_transports/base.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Update httpx/_transports/default.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Move response classes in transports to module level

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
13 files changed:
docs/advanced.md
httpx/__init__.py
httpx/_client.py
httpx/_content.py
httpx/_models.py
httpx/_multipart.py
httpx/_transports/asgi.py
httpx/_transports/base.py
httpx/_transports/default.py
httpx/_transports/mock.py
httpx/_transports/wsgi.py
httpx/_types.py
tests/test_content.py

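Before the per-file diffs, a quick orientation sketch (illustrative only, not part of the change itself): handle_request() now accepts and returns httpx.SyncByteStream objects, and stream.read() replaces the old extensions["close"] bookkeeping. The request values below are placeholders.

import httpx

with httpx.HTTPTransport() as transport:
    # handle_request() now returns an httpx.SyncByteStream rather than a plain
    # byte iterator; read() drains the stream and then closes it.
    status_code, headers, stream, extensions = transport.handle_request(
        method=b"GET",
        url=(b"https", b"example.org", 443, b"/"),
        headers=[(b"host", b"example.org")],
        stream=httpx.ByteStream(b""),
        extensions={},
    )
    body = stream.read()
    print(status_code, headers, body)
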
diff --git a/docs/advanced.md b/docs/advanced.md
index 1902b0eeb5f39493a61d504c81284108dad6f610..4438cb2d6f280af3001f50dc29401626f3b5988a 100644 (file)
@@ -1070,7 +1070,7 @@ class HelloWorldTransport(httpx.BaseTransport):
     def handle_request(self, method, url, headers, stream, extensions):
         message = {"text": "Hello, world!"}
         content = json.dumps(message).encode("utf-8")
-        stream = [content]
+        stream = httpx.ByteStream(content)
         headers = [(b"content-type", b"application/json")]
         extensions = {}
         return 200, headers, stream, extensions
@@ -1131,7 +1131,7 @@ class HTTPSRedirectTransport(httpx.BaseTransport):
             location = b"https://%s%s" % (host, path)
         else:
             location = b"https://%s:%d%s" % (host, port, path)
-        stream = [b""]
+        stream = httpx.ByteStream(b"")
         headers = [(b"location", location)]
         extensions = {}
         return 303, headers, stream, extensions
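
For completeness, a usage sketch of the documented transport above (assuming the HelloWorldTransport class from the docs snippet; the URL is illustrative):

import httpx

# Mount the custom transport from the docs example above; any request routed
# through it returns the canned JSON body built via httpx.ByteStream.
client = httpx.Client(transport=HelloWorldTransport())
response = client.get("https://example.org/")
print(response.json())  # {'text': 'Hello, world!'}
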
diff --git a/httpx/__init__.py b/httpx/__init__.py
index a441669bf6a75fd9038609571b9225f92656cdea..af38f8a91285d7a1b826fadf0fd33f5e807d45da 100644 (file)
@@ -3,6 +3,7 @@ from ._api import delete, get, head, options, patch, post, put, request, stream
 from ._auth import Auth, BasicAuth, DigestAuth
 from ._client import AsyncClient, Client
 from ._config import Limits, Proxy, Timeout, create_ssl_context
+from ._content import ByteStream
 from ._exceptions import (
     CloseError,
     ConnectError,
@@ -36,7 +37,12 @@ from ._exceptions import (
 from ._models import URL, Cookies, Headers, QueryParams, Request, Response
 from ._status_codes import StatusCode, codes
 from ._transports.asgi import ASGITransport
-from ._transports.base import AsyncBaseTransport, BaseTransport
+from ._transports.base import (
+    AsyncBaseTransport,
+    AsyncByteStream,
+    BaseTransport,
+    SyncByteStream,
+)
 from ._transports.default import AsyncHTTPTransport, HTTPTransport
 from ._transports.mock import MockTransport
 from ._transports.wsgi import WSGITransport
@@ -47,11 +53,13 @@ __all__ = [
     "__version__",
     "ASGITransport",
     "AsyncBaseTransport",
+    "AsyncByteStream",
     "AsyncClient",
     "AsyncHTTPTransport",
     "Auth",
     "BaseTransport",
     "BasicAuth",
+    "ByteStream",
     "Client",
     "CloseError",
     "codes",
@@ -97,6 +105,7 @@ __all__ = [
     "stream",
     "StreamConsumed",
     "StreamError",
+    "SyncByteStream",
     "Timeout",
     "TimeoutException",
     "TooManyRedirects",
diff --git a/httpx/_client.py b/httpx/_client.py
index 7f8ce53101f3492e1a75f225b6d0b8e42b0221f0..ce466aa3a48d41b4031fe55ce511c46489bab479 100644 (file)
@@ -26,12 +26,16 @@ from ._exceptions import (
 from ._models import URL, Cookies, Headers, QueryParams, Request, Response
 from ._status_codes import codes
 from ._transports.asgi import ASGITransport
-from ._transports.base import AsyncBaseTransport, BaseTransport
+from ._transports.base import (
+    AsyncBaseTransport,
+    AsyncByteStream,
+    BaseTransport,
+    SyncByteStream,
+)
 from ._transports.default import AsyncHTTPTransport, HTTPTransport
 from ._transports.wsgi import WSGITransport
 from ._types import (
     AuthTypes,
-    ByteStream,
     CertTypes,
     CookieTypes,
     HeaderTypes,
@@ -509,7 +513,7 @@ class BaseClient:
 
     def _redirect_stream(
         self, request: Request, method: str
-    ) -> typing.Optional[ByteStream]:
+    ) -> typing.Optional[typing.Union[SyncByteStream, AsyncByteStream]]:
         """
         Return the body that should be used for the redirect request.
         """
@@ -880,8 +884,7 @@ class Client(BaseClient):
 
         def on_close(response: Response) -> None:
             response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
-            if "close" in extensions:
-                extensions["close"]()
+            stream.close()
 
         response = Response(
             status_code,
@@ -1524,8 +1527,7 @@ class AsyncClient(BaseClient):
 
         async def on_close(response: Response) -> None:
             response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
-            if "aclose" in extensions:
-                await extensions["aclose"]()
+            await stream.aclose()
 
         response = Response(
             status_code,
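
The user-facing behaviour is unchanged, only the plumbing is simpler: closing a streamed response now calls stream.close() / stream.aclose() directly rather than looking up an extensions callback. A sketch (URL illustrative):

import httpx

client = httpx.Client()
# Leaving the `with` block triggers on_close(), which now calls close() on the
# underlying httpx.SyncByteStream.
with client.stream("GET", "https://example.org/") as response:
    for chunk in response.iter_bytes():
        ...
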
diff --git a/httpx/_content.py b/httpx/_content.py
index 0b9672be3fb94327cefe5bc88f74ede3d26b9d8b..4b16d1e970af2e6a21e33f23348090550ffb44fa 100644 (file)
@@ -14,92 +14,69 @@ from urllib.parse import urlencode
 
 from ._exceptions import StreamConsumed
 from ._multipart import MultipartStream
-from ._types import (
-    ByteStream,
-    RequestContent,
-    RequestData,
-    RequestFiles,
-    ResponseContent,
-)
+from ._transports.base import AsyncByteStream, SyncByteStream
+from ._types import RequestContent, RequestData, RequestFiles, ResponseContent
 from ._utils import primitive_value_to_str
 
 
-class PlainByteStream:
-    """
-    Request content encoded as plain bytes.
-    """
-
-    def __init__(self, body: bytes) -> None:
-        self._body = body
+class ByteStream(AsyncByteStream, SyncByteStream):
+    def __init__(self, stream: bytes) -> None:
+        self._stream = stream
 
     def __iter__(self) -> Iterator[bytes]:
-        yield self._body
+        yield self._stream
 
     async def __aiter__(self) -> AsyncIterator[bytes]:
-        yield self._body
+        yield self._stream
 
 
-class GeneratorStream:
-    """
-    Request content encoded as plain bytes, using an byte generator.
-    """
-
-    def __init__(self, generator: Iterable[bytes]) -> None:
-        self._generator = generator
+class IteratorByteStream(SyncByteStream):
+    def __init__(self, stream: Iterable[bytes]):
+        self._stream = stream
         self._is_stream_consumed = False
+        self._is_generator = inspect.isgenerator(stream)
 
     def __iter__(self) -> Iterator[bytes]:
-        if self._is_stream_consumed:
+        if self._is_stream_consumed and self._is_generator:
             raise StreamConsumed()
 
         self._is_stream_consumed = True
-        for part in self._generator:
+        for part in self._stream:
             yield part
 
 
-class AsyncGeneratorStream:
-    """
-    Request content encoded as plain bytes, using an async byte iterator.
-    """
-
-    def __init__(self, agenerator: AsyncIterable[bytes]) -> None:
-        self._agenerator = agenerator
+class AsyncIteratorByteStream(AsyncByteStream):
+    def __init__(self, stream: AsyncIterable[bytes]):
+        self._stream = stream
         self._is_stream_consumed = False
+        self._is_generator = inspect.isasyncgen(stream)
 
     async def __aiter__(self) -> AsyncIterator[bytes]:
-        if self._is_stream_consumed:
+        if self._is_stream_consumed and self._is_generator:
             raise StreamConsumed()
 
         self._is_stream_consumed = True
-        async for part in self._agenerator:
+        async for part in self._stream:
             yield part
 
 
 def encode_content(
-    content: Union[str, bytes, ByteStream]
-) -> Tuple[Dict[str, str], ByteStream]:
-    if isinstance(content, (str, bytes)):
+    content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
+) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
+
+    if isinstance(content, (bytes, str)):
         body = content.encode("utf-8") if isinstance(content, str) else content
         content_length = str(len(body))
         headers = {"Content-Length": content_length} if body else {}
-        stream = PlainByteStream(body)
-        return headers, stream
+        return headers, ByteStream(body)
 
-    elif isinstance(content, (Iterable, AsyncIterable)):
+    elif isinstance(content, Iterable):
         headers = {"Transfer-Encoding": "chunked"}
+        return headers, IteratorByteStream(content)  # type: ignore
 
-        # Generators should be wrapped in GeneratorStream/AsyncGeneratorStream
-        # which will raise `StreamConsumed` if the stream is accessed more
-        # than once. (Eg. Following HTTP 307 or HTTP 308 redirects.)
-        if inspect.isgenerator(content):
-            generator_stream = GeneratorStream(content)  # type: ignore
-            return headers, generator_stream
-        if inspect.isasyncgen(content):
-            agenerator_stream = AsyncGeneratorStream(content)  # type: ignore
-            return headers, agenerator_stream
-
-        # Other iterables may be passed through as-is.
-        return headers, content  # type: ignore
+    elif isinstance(content, AsyncIterable):
+        headers = {"Transfer-Encoding": "chunked"}
+        return headers, AsyncIteratorByteStream(content)
 
     raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
 
@@ -117,15 +94,15 @@ def encode_urlencoded_data(
     content_length = str(len(body))
     content_type = "application/x-www-form-urlencoded"
     headers = {"Content-Length": content_length, "Content-Type": content_type}
-    return headers, PlainByteStream(body)
+    return headers, ByteStream(body)
 
 
 def encode_multipart_data(
     data: dict, files: RequestFiles, boundary: bytes = None
-) -> Tuple[Dict[str, str], ByteStream]:
-    stream = MultipartStream(data=data, files=files, boundary=boundary)
-    headers = stream.get_headers()
-    return headers, stream
+) -> Tuple[Dict[str, str], MultipartStream]:
+    multipart = MultipartStream(data=data, files=files, boundary=boundary)
+    headers = multipart.get_headers()
+    return headers, multipart
 
 
 def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
@@ -133,7 +110,7 @@ def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
     content_length = str(len(body))
     content_type = "text/plain; charset=utf-8"
     headers = {"Content-Length": content_length, "Content-Type": content_type}
-    return headers, PlainByteStream(body)
+    return headers, ByteStream(body)
 
 
 def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
@@ -141,7 +118,7 @@ def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]:
     content_length = str(len(body))
     content_type = "text/html; charset=utf-8"
     headers = {"Content-Length": content_length, "Content-Type": content_type}
-    return headers, PlainByteStream(body)
+    return headers, ByteStream(body)
 
 
 def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
@@ -149,7 +126,7 @@ def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]:
     content_length = str(len(body))
     content_type = "application/json"
     headers = {"Content-Length": content_length, "Content-Type": content_type}
-    return headers, PlainByteStream(body)
+    return headers, ByteStream(body)
 
 
 def encode_request(
@@ -158,7 +135,7 @@ def encode_request(
     files: RequestFiles = None,
     json: Any = None,
     boundary: bytes = None,
-) -> Tuple[Dict[str, str], ByteStream]:
+) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
     """
     Handles encoding the given `content`, `data`, `files`, and `json`,
     returning a two-tuple of (<headers>, <stream>).
@@ -182,7 +159,7 @@ def encode_request(
     elif json is not None:
         return encode_json(json)
 
-    return {}, PlainByteStream(b"")
+    return {}, ByteStream(b"")
 
 
 def encode_response(
@@ -190,7 +167,7 @@ def encode_response(
     text: str = None,
     html: str = None,
     json: Any = None,
-) -> Tuple[Dict[str, str], ByteStream]:
+) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
     """
     Handles encoding the given `content`, returning a two-tuple of
     (<headers>, <stream>).
@@ -204,4 +181,4 @@ def encode_response(
     elif json is not None:
         return encode_json(json)
 
-    return {}, PlainByteStream(b"")
+    return {}, ByteStream(b"")
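
A sketch of what the reworked encoding helpers return, using the private httpx._content module exactly as the tests further down do (the byte values are illustrative):

import httpx
from httpx._content import encode_request

# Plain bytes become a replayable httpx.ByteStream with a Content-Length header.
headers, stream = encode_request(content=b"Hello, world!")
assert headers == {"Content-Length": "13"}
assert isinstance(stream, httpx.ByteStream)

# A generator becomes an IteratorByteStream using chunked transfer encoding;
# consuming it a second time raises httpx.StreamConsumed.
def byte_chunks():
    yield b"Hello, "
    yield b"world!"

headers, stream = encode_request(content=byte_chunks())
assert headers == {"Transfer-Encoding": "chunked"}
assert b"".join(stream) == b"Hello, world!"
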
diff --git a/httpx/_models.py b/httpx/_models.py
index ade5a3192598ba12f03a02f6e04c33dd9a84d384..a3b6ff1f01119c4d39190d2cc8226d3406f865dd 100644 (file)
@@ -11,7 +11,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode
 import rfc3986
 import rfc3986.exceptions
 
-from ._content import PlainByteStream, encode_request, encode_response
+from ._content import ByteStream, encode_request, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
     ByteChunker,
@@ -33,8 +33,8 @@ from ._exceptions import (
     request_context,
 )
 from ._status_codes import codes
+from ._transports.base import AsyncByteStream, SyncByteStream
 from ._types import (
-    ByteStream,
     CookieTypes,
     HeaderTypes,
     PrimitiveData,
@@ -798,7 +798,7 @@ class Request:
         data: RequestData = None,
         files: RequestFiles = None,
         json: typing.Any = None,
-        stream: ByteStream = None,
+        stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
     ):
         if isinstance(method, bytes):
             self.method = method.decode("ascii").upper()
@@ -872,7 +872,7 @@ class Request:
             # If a streaming request has been read entirely into memory, then
             # we can replace the stream with a raw bytes implementation,
             # to ensure that any non-replayable streams can still be used.
-            self.stream = PlainByteStream(self._content)
+            self.stream = ByteStream(self._content)
         return self._content
 
     async def aread(self) -> bytes:
@@ -885,7 +885,7 @@ class Request:
             # If a streaming request has been read entirely into memory, then
             # we can replace the stream with a raw bytes implementation,
             # to ensure that any non-replayable streams can still be used.
-            self.stream = PlainByteStream(self._content)
+            self.stream = ByteStream(self._content)
         return self._content
 
     def __repr__(self) -> str:
@@ -904,7 +904,7 @@ class Response:
         text: str = None,
         html: str = None,
         json: typing.Any = None,
-        stream: ByteStream = None,
+        stream: typing.Union[SyncByteStream, AsyncByteStream] = None,
         request: Request = None,
         extensions: dict = None,
         history: typing.List["Response"] = None,
@@ -1222,7 +1222,7 @@ class Response:
             raise StreamConsumed()
         if self.is_closed:
             raise ResponseClosed()
-        if not isinstance(self.stream, typing.Iterable):
+        if not isinstance(self.stream, SyncByteStream):
             raise RuntimeError("Attempted to call a sync iterator on an async stream.")
 
         self.is_stream_consumed = True
@@ -1318,8 +1318,8 @@ class Response:
             raise StreamConsumed()
         if self.is_closed:
             raise ResponseClosed()
-        if not isinstance(self.stream, typing.AsyncIterable):
-            raise RuntimeError("Attempted to call a async iterator on a sync stream.")
+        if not isinstance(self.stream, AsyncByteStream):
+            raise RuntimeError("Attempted to call an async iterator on an sync stream.")
 
         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
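
A sketch of the behaviour touched above: once a streaming request body has been read into memory, its stream attribute is swapped for a replayable httpx.ByteStream (URL illustrative):

import httpx

request = httpx.Request(
    "POST", "https://www.example.com", content=iter([b"Hello, ", b"world!"])
)
assert isinstance(request.stream, httpx.SyncByteStream)

# Reading consumes the iterator and replaces the stream with plain bytes, so
# the body can be replayed, e.g. when following an HTTP 307 or 308 redirect.
content = request.read()
assert content == b"Hello, world!"
assert isinstance(request.stream, httpx.ByteStream)
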
diff --git a/httpx/_multipart.py b/httpx/_multipart.py
index b5f8fb48f83b171cedc89b2861159e1af9bf30a9..cb23d0cfa5604b8623bac5a81eee083258768c52 100644 (file)
@@ -3,6 +3,7 @@ import os
 import typing
 from pathlib import Path
 
+from ._transports.base import AsyncByteStream, SyncByteStream
 from ._types import FileContent, FileTypes, RequestFiles
 from ._utils import (
     format_form_param,
@@ -141,7 +142,7 @@ class FileField:
         yield from self.render_data()
 
 
-class MultipartStream:
+class MultipartStream(SyncByteStream, AsyncByteStream):
     """
     Request content as streaming multipart encoded form data.
     """
diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py
index ef0a3ef29ab43843d9847f9b35accd6df8ef6219..24c5452dc921ea3ea5a37619c6e591c24942872e 100644 (file)
@@ -3,7 +3,7 @@ from urllib.parse import unquote
 
 import sniffio
 
-from .base import AsyncBaseTransport
+from .base import AsyncBaseTransport, AsyncByteStream
 
 if typing.TYPE_CHECKING:  # pragma: no cover
     import asyncio
@@ -24,6 +24,14 @@ def create_event() -> "Event":
         return asyncio.Event()
 
 
+class ASGIResponseStream(AsyncByteStream):
+    def __init__(self, body: typing.List[bytes]) -> None:
+        self._body = body
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield b"".join(self._body)
+
+
 class ASGITransport(AsyncBaseTransport):
     """
     A custom AsyncTransport that handles sending requests directly to an ASGI app.
@@ -74,10 +82,10 @@ class ASGITransport(AsyncBaseTransport):
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.AsyncIterable[bytes],
+        stream: AsyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
     ]:
         # ASGI scope.
         scheme, host, port, full_path = url
@@ -155,9 +163,7 @@ class ASGITransport(AsyncBaseTransport):
         assert status_code is not None
         assert response_headers is not None
 
-        async def response_stream() -> typing.AsyncIterator[bytes]:
-            yield b"".join(body_parts)
-
+        stream = ASGIResponseStream(body_parts)
         extensions = {}
 
-        return (status_code, response_headers, response_stream(), extensions)
+        return (status_code, response_headers, stream, extensions)
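
A sketch of exercising the ASGI transport end-to-end; the app and URL below are illustrative and not part of the change:

import asyncio
import httpx

async def app(scope, receive, send):
    assert scope["type"] == "http"
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": b"Hello, world!"})

async def main():
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport) as client:
        # The response body arrives via the ASGIResponseStream defined above.
        response = await client.get("http://testserver/")
        print(response.status_code, response.text)

asyncio.run(main())
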
diff --git a/httpx/_transports/base.py b/httpx/_transports/base.py
index e26938f94b0bbe5681088a22db4bc2d29035f38d..eb519269704882d424f6e7c4c2d64bef5d6b668b 100644 (file)
@@ -5,6 +5,63 @@ T = typing.TypeVar("T", bound="BaseTransport")
 A = typing.TypeVar("A", bound="AsyncBaseTransport")
 
 
+class SyncByteStream:
+    def __iter__(self) -> typing.Iterator[bytes]:
+        raise NotImplementedError(
+            "The '__iter__' method must be implemented."
+        )  # pragma: nocover
+        yield b""  # pragma: nocover
+
+    def close(self) -> None:
+        """
+        Subclasses can override this method to release any network resources
+        after a request/response cycle is complete.
+
+        Streaming cases should use a `try...finally` block to ensure that
+        the stream `close()` method is always called.
+
+        Example:
+
+            status_code, headers, stream, extensions = transport.handle_request(...)
+            try:
+                ...
+            finally:
+                stream.close()
+        """
+
+    def read(self) -> bytes:
+        """
+        Simple cases can use `.read()` as a convenience method for consuming
+        the entire stream and then closing it.
+
+        Example:
+
+            status_code, headers, stream, extensions = transport.handle_request(...)
+            body = stream.read()
+        """
+        try:
+            return b"".join([part for part in self])
+        finally:
+            self.close()
+
+
+class AsyncByteStream:
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        raise NotImplementedError(
+            "The '__aiter__' method must be implemented."
+        )  # pragma: nocover
+        yield b""  # pragma: nocover
+
+    async def aclose(self) -> None:
+        pass
+
+    async def aread(self) -> bytes:
+        try:
+            return b"".join([part async for part in self])
+        finally:
+            await self.aclose()
+
+
 class BaseTransport:
     def __enter__(self: T) -> T:
         return self
@@ -22,10 +79,10 @@ class BaseTransport:
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.Iterable[bytes],
+        stream: SyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
     ]:
         """
         Send a single HTTP request and return a response.
@@ -39,6 +96,11 @@ class BaseTransport:
         since the Client class provides all the higher level user-facing API
         niceties.
 
+        In order to properly release any network resources, the response stream
+        should *either* be consumed immediately, with a call to `stream.read()`,
+        or else the `handle_request` call should be followed with a try/finally
+        block to ensure the stream is always closed.
+
         Example usage:
 
             with httpx.HTTPTransport() as transport:
@@ -49,11 +111,7 @@ class BaseTransport:
                     stream=[],
                     extensions={}
                 )
-                try:
-                    body = b''.join([part for part in stream])
-                finally:
-                    if 'close' in extensions:
-                        extensions['close']()
+                body = stream.read()
                 print(status_code, headers, body)
 
         Arguments:
@@ -86,10 +144,6 @@ class BaseTransport:
                     eg. the leading response bytes were b"HTTP/1.1 200 <CRLF>".
             http_version: The HTTP version, as bytes. Eg. b"HTTP/1.1".
                     When no http_version key is included, HTTP/1.1 may be assumed.
-            close:  A callback which should be invoked to release any network
-                    resources.
-            aclose: An async callback which should be invoked to release any
-                    network resources.
         """
         raise NotImplementedError(
             "The 'handle_request' method must be implemented."
@@ -116,10 +170,10 @@ class AsyncBaseTransport:
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.AsyncIterable[bytes],
+        stream: AsyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
     ]:
         raise NotImplementedError(
             "The 'handle_async_request' method must be implemented."
diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py
index 5691538d86b32c3353e8bab0c6087ff0407970ca..29d5299a161bc40efd0703245f632dfd860a33ba 100644 (file)
@@ -49,7 +49,7 @@ from .._exceptions import (
     WriteTimeout,
 )
 from .._types import CertTypes, VerifyTypes
-from .base import AsyncBaseTransport, BaseTransport
+from .base import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream
 
 T = typing.TypeVar("T", bound="HTTPTransport")
 A = typing.TypeVar("A", bound="AsyncHTTPTransport")
@@ -110,6 +110,20 @@ HTTPCORE_EXC_MAP = {
 }
 
 
+class ResponseStream(SyncByteStream):
+    def __init__(self, httpcore_stream: httpcore.SyncByteStream):
+        self._httpcore_stream = httpcore_stream
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        with map_httpcore_exceptions():
+            for part in self._httpcore_stream:
+                yield part
+
+    def close(self) -> None:
+        with map_httpcore_exceptions():
+            self._httpcore_stream.close()
+
+
 class HTTPTransport(BaseTransport):
     def __init__(
         self,
@@ -168,10 +182,10 @@ class HTTPTransport(BaseTransport):
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.Iterable[bytes],
+        stream: SyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
     ]:
         with map_httpcore_exceptions():
             status_code, headers, byte_stream, extensions = self._pool.request(
@@ -182,24 +196,29 @@ class HTTPTransport(BaseTransport):
                 ext=extensions,
             )
 
-        def response_stream() -> typing.Iterator[bytes]:
-            with map_httpcore_exceptions():
-                for part in byte_stream:
-                    yield part
-
-        def close() -> None:
-            with map_httpcore_exceptions():
-                byte_stream.close()
-
         ensure_http_version_reason_phrase_as_bytes(extensions)
-        extensions["close"] = close
+        stream = ResponseStream(byte_stream)
 
-        return status_code, headers, response_stream(), extensions
+        return status_code, headers, stream, extensions
 
     def close(self) -> None:
         self._pool.close()
 
 
+class AsyncResponseStream(AsyncByteStream):
+    def __init__(self, httpcore_stream: httpcore.AsyncByteStream):
+        self._httpcore_stream = httpcore_stream
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        with map_httpcore_exceptions():
+            async for part in self._httpcore_stream:
+                yield part
+
+    async def aclose(self) -> None:
+        with map_httpcore_exceptions():
+            await self._httpcore_stream.aclose()
+
+
 class AsyncHTTPTransport(AsyncBaseTransport):
     def __init__(
         self,
@@ -258,10 +277,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.AsyncIterable[bytes],
+        stream: AsyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
     ]:
         with map_httpcore_exceptions():
             status_code, headers, byte_stream, extensions = await self._pool.arequest(
@@ -272,19 +291,10 @@ class AsyncHTTPTransport(AsyncBaseTransport):
                 ext=extensions,
             )
 
-        async def response_stream() -> typing.AsyncIterator[bytes]:
-            with map_httpcore_exceptions():
-                async for part in byte_stream:
-                    yield part
-
-        async def aclose() -> None:
-            with map_httpcore_exceptions():
-                await byte_stream.aclose()
-
         ensure_http_version_reason_phrase_as_bytes(extensions)
-        extensions["aclose"] = aclose
+        stream = AsyncResponseStream(byte_stream)
 
-        return status_code, headers, response_stream(), extensions
+        return status_code, headers, stream, extensions
 
     async def aclose(self) -> None:
         await self._pool.aclose()
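
And the async counterpart of the earlier orientation sketch, again with placeholder request values: handle_async_request() now returns an httpx.AsyncByteStream (the AsyncResponseStream wrapper above), which aread() drains and closes.

import asyncio
import httpx

async def main():
    async with httpx.AsyncHTTPTransport() as transport:
        status_code, headers, stream, extensions = await transport.handle_async_request(
            method=b"GET",
            url=(b"https", b"example.org", 443, b"/"),
            headers=[(b"host", b"example.org")],
            stream=httpx.ByteStream(b""),
            extensions={},
        )
        body = await stream.aread()
        print(status_code, body)

asyncio.run(main())
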
diff --git a/httpx/_transports/mock.py b/httpx/_transports/mock.py
index b6ca353a315800214b1a46b34edd1f4f9a0849a6..8d59b73820d87a85ee7ebfef2097e1e1e71a7cdc 100644 (file)
@@ -2,7 +2,7 @@ import asyncio
 import typing
 
 from .._models import Request
-from .base import AsyncBaseTransport, BaseTransport
+from .base import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream
 
 
 class MockTransport(AsyncBaseTransport, BaseTransport):
@@ -14,10 +14,10 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.Iterable[bytes],
+        stream: SyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
     ]:
         request = Request(
             method=method,
@@ -39,10 +39,10 @@ class MockTransport(AsyncBaseTransport, BaseTransport):
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.AsyncIterable[bytes],
+        stream: AsyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], AsyncByteStream, dict
     ]:
         request = Request(
             method=method,
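
A sketch of MockTransport with the updated interface: the handler still returns an httpx.Response, and the transport adapts its stream for either client (handler and URL are illustrative):

import httpx

def handler(request: httpx.Request) -> httpx.Response:
    return httpx.Response(200, json={"text": "Hello, world!"})

client = httpx.Client(transport=httpx.MockTransport(handler))
response = client.get("https://example.org/")
assert response.json() == {"text": "Hello, world!"}
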
diff --git a/httpx/_transports/wsgi.py b/httpx/_transports/wsgi.py
index 3b7651fba71a000f53c50984fd4b7ea1be1e6632..c8266c73925183d5249fb5c00a4d6e81dc144144 100644 (file)
@@ -3,7 +3,7 @@ import itertools
 import typing
 from urllib.parse import unquote
 
-from .base import BaseTransport
+from .base import BaseTransport, SyncByteStream
 
 
 def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
@@ -14,6 +14,15 @@ def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable:
     return []
 
 
+class WSGIByteStream(SyncByteStream):
+    def __init__(self, result: typing.Iterable[bytes]) -> None:
+        self._result = _skip_leading_empty_chunks(result)
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        for part in self._result:
+            yield part
+
+
 class WSGITransport(BaseTransport):
     """
     A custom transport that handles sending requests directly to a WSGI app.
@@ -64,10 +73,10 @@ class WSGITransport(BaseTransport):
         method: bytes,
         url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
         headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: typing.Iterable[bytes],
+        stream: SyncByteStream,
         extensions: dict,
     ) -> typing.Tuple[
-        int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict
+        int, typing.List[typing.Tuple[bytes, bytes]], SyncByteStream, dict
     ]:
         wsgi_input = io.BytesIO(b"".join(stream))
 
@@ -111,9 +120,8 @@ class WSGITransport(BaseTransport):
             seen_exc_info = exc_info
 
         result = self.app(environ, start_response)
-        # This is needed because the status returned by start_response
-        # shouldn't be used until the first non-empty chunk has been served.
-        result = _skip_leading_empty_chunks(result)
+
+        stream = WSGIByteStream(result)
 
         assert seen_status is not None
         assert seen_response_headers is not None
@@ -127,4 +135,4 @@ class WSGITransport(BaseTransport):
         ]
         extensions = {}
 
-        return (status_code, headers, result, extensions)
+        return (status_code, headers, stream, extensions)
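
A sketch of the WSGI transport in use; the app and URL are illustrative. The returned iterable is now wrapped in the WSGIByteStream defined above:

import httpx

def app(environ, start_response):
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello, world!"]

client = httpx.Client(transport=httpx.WSGITransport(app=app))
response = client.get("http://testserver/")
print(response.status_code, response.text)
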
diff --git a/httpx/_types.py b/httpx/_types.py
index 385f89ddb9024b71d5d847f58ed5fdeb966e8371..75bb9006c8705229b7d908a9881f6b4a36f187d1 100644 (file)
@@ -74,9 +74,8 @@ AuthTypes = Union[
     None,
 ]
 
-ByteStream = Union[Iterable[bytes], AsyncIterable[bytes]]
-RequestContent = Union[str, bytes, ByteStream]
-ResponseContent = Union[str, bytes, ByteStream]
+RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
+ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
 
 RequestData = dict
 
diff --git a/tests/test_content.py b/tests/test_content.py
index 1dda02863231d959e010e44250a7b65c05779e4f..d692a3036161163b502396955d97d0f3902dff2c 100644 (file)
@@ -3,18 +3,18 @@ import typing
 
 import pytest
 
-from httpx import StreamConsumed
+import httpx
 from httpx._content import encode_request, encode_response
 
 
 @pytest.mark.asyncio
 async def test_empty_content():
     headers, stream = encode_request()
-    assert isinstance(stream, typing.Iterable)
-    assert isinstance(stream, typing.AsyncIterable)
+    assert isinstance(stream, httpx.SyncByteStream)
+    assert isinstance(stream, httpx.AsyncByteStream)
 
-    sync_content = b"".join([part for part in stream])
-    async_content = b"".join([part async for part in stream])
+    sync_content = stream.read()
+    async_content = await stream.aread()
 
     assert headers == {}
     assert sync_content == b""
@@ -62,7 +62,7 @@ async def test_iterator_content():
     assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(StreamConsumed):
+    with pytest.raises(httpx.StreamConsumed):
         [part for part in stream]
 
     # Support 'data' for compat with requests.
@@ -91,7 +91,7 @@ async def test_aiterator_content():
     assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(StreamConsumed):
+    with pytest.raises(httpx.StreamConsumed):
         [part async for part in stream]
 
     # Support 'data' for compat with requests.
@@ -382,7 +382,7 @@ async def test_response_iterator_content():
     assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(StreamConsumed):
+    with pytest.raises(httpx.StreamConsumed):
         [part for part in stream]
 
 
@@ -401,7 +401,7 @@ async def test_response_aiterator_content():
     assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(StreamConsumed):
+    with pytest.raises(httpx.StreamConsumed):
         [part async for part in stream]