]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop ContentStream (#1295)
authorTom Christie <tom@tomchristie.com>
Fri, 18 Sep 2020 07:41:09 +0000 (08:41 +0100)
committerGitHub <noreply@github.com>
Fri, 18 Sep 2020 07:41:09 +0000 (08:41 +0100)
* Drop ContentStream

15 files changed:
docs/exceptions.md
httpx/__init__.py
httpx/_auth.py
httpx/_client.py
httpx/_content_streams.py
httpx/_exceptions.py
httpx/_models.py
httpx/_types.py
tests/client/test_auth.py
tests/client/test_redirects.py
tests/models/test_requests.py
tests/models/test_responses.py
tests/test_content_streams.py
tests/test_multipart.py
tests/utils.py

index 7cbddf5dbaf26e3adcd93b697005aab13811e0d3..db0d36c444bab1b5bb7fac346a510ebd8642d5d4 100644 (file)
@@ -74,7 +74,6 @@ except httpx.HTTPStatusError as exc:
             * UnsupportedProtocol
         * DecodingError
         * TooManyRedirects
-        * RequestBodyUnavailable
     * HTTPStatusError
 * InvalidURL
 * NotRedirectResponse
@@ -149,9 +148,6 @@ except httpx.HTTPStatusError as exc:
 ::: httpx.TooManyRedirects
     :docstring:
 
-::: httpx.RequestBodyUnavailable
-    :docstring:
-
 ::: httpx.HTTPStatusError
     :docstring:
 
index bfd52806efdf0269ef40b6e94fd95555199a2aa5..842532142cd2a9f3ae5f902882aa08f855424899 100644 (file)
@@ -21,7 +21,6 @@ from ._exceptions import (
     ReadError,
     ReadTimeout,
     RemoteProtocolError,
-    RequestBodyUnavailable,
     RequestError,
     RequestNotRead,
     ResponseClosed,
@@ -84,7 +83,6 @@ __all__ = [
     "RemoteProtocolError",
     "request",
     "Request",
-    "RequestBodyUnavailable",
     "RequestError",
     "RequestNotRead",
     "Response",
index 439f337fbfe5eb5141e2d59c3a91a34e58110a5a..fdbda9fa97457a39f5cfaa610822c92915791d01 100644 (file)
@@ -6,7 +6,7 @@ import typing
 from base64 import b64encode
 from urllib.request import parse_http_list
 
-from ._exceptions import ProtocolError, RequestBodyUnavailable
+from ._exceptions import ProtocolError
 from ._models import Request, Response
 from ._utils import to_bytes, to_str, unquote
 
@@ -157,13 +157,6 @@ class DigestAuth(Auth):
         self._password = to_bytes(password)
 
     def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
-        if not request.stream.can_replay():
-            raise RequestBodyUnavailable(
-                "Cannot use digest auth with streaming requests that are unable "
-                "to replay the request body if a second request is required.",
-                request=request,
-            )
-
         response = yield request
 
         if response.status_code != 401 or "www-authenticate" not in response.headers:
index 197ee498753c89d645c0faad216efb0566921f16..e5b53d4f5fdd537646829ae464b5fbfd38fa9bab 100644 (file)
@@ -1,3 +1,4 @@
+import datetime
 import functools
 import typing
 import warnings
@@ -18,13 +19,11 @@ from ._config import (
     UnsetType,
     create_ssl_context,
 )
-from ._content_streams import ContentStream
 from ._decoders import SUPPORTED_DECODERS
 from ._exceptions import (
     HTTPCORE_EXC_MAP,
     InvalidURL,
     RemoteProtocolError,
-    RequestBodyUnavailable,
     TooManyRedirects,
     map_exceptions,
 )
@@ -34,6 +33,7 @@ from ._transports.asgi import ASGITransport
 from ._transports.wsgi import WSGITransport
 from ._types import (
     AuthTypes,
+    ByteStream,
     CertTypes,
     CookieTypes,
     HeaderTypes,
@@ -480,20 +480,13 @@ class BaseClient:
 
     def _redirect_stream(
         self, request: Request, method: str
-    ) -> typing.Optional[ContentStream]:
+    ) -> typing.Optional[ByteStream]:
         """
         Return the body that should be used for the redirect request.
         """
         if method != request.method and method == "GET":
             return None
 
-        if not request.stream.can_replay():
-            raise RequestBodyUnavailable(
-                "Got a redirect response, but the request body was streaming "
-                "and is no longer available.",
-                request=request,
-            )
-
         return request.stream
 
 
@@ -864,16 +857,22 @@ class Client(BaseClient):
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
-                stream=request.stream,
+                stream=request.stream,  # type: ignore
                 timeout=timeout.as_dict(),
             )
+
+        def on_close(response: Response) -> None:
+            response.elapsed = datetime.timedelta(timer.sync_elapsed())
+            if hasattr(stream, "close"):
+                stream.close()
+
         response = Response(
             status_code,
             http_version=http_version.decode("ascii"),
             headers=headers,
             stream=stream,  # type: ignore
             request=request,
-            elapsed_func=timer.sync_elapsed,
+            on_close=on_close,
         )
 
         self.cookies.extract_cookies(response)
@@ -1509,16 +1508,22 @@ class AsyncClient(BaseClient):
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
-                stream=request.stream,
+                stream=request.stream,  # type: ignore
                 timeout=timeout.as_dict(),
             )
+
+        async def on_close(response: Response) -> None:
+            response.elapsed = datetime.timedelta(await timer.async_elapsed())
+            if hasattr(stream, "close"):
+                await stream.aclose()
+
         response = Response(
             status_code,
             http_version=http_version.decode("ascii"),
             headers=headers,
             stream=stream,  # type: ignore
             request=request,
-            elapsed_func=timer.async_elapsed,
+            on_close=on_close,
         )
 
         self.cookies.extract_cookies(response)
index cb08a3f598528bc44612fd3e014deab81d3d2b0f..eacb077b8de5191d893852e5ec8d54a0eb306376 100644 (file)
@@ -1,14 +1,14 @@
 import binascii
+import inspect
 import os
 import typing
 from json import dumps as json_dumps
 from pathlib import Path
 from urllib.parse import urlencode
 
-import httpcore
-
 from ._exceptions import StreamConsumed
 from ._types import (
+    ByteStream,
     FileContent,
     FileTypes,
     RequestContent,
@@ -24,36 +24,7 @@ from ._utils import (
 )
 
 
-class ContentStream(httpcore.AsyncByteStream, httpcore.SyncByteStream):
-    def get_headers(self) -> typing.Dict[str, str]:
-        """
-        Return a dictionary of headers that are implied by the encoding.
-        """
-        return {}
-
-    def can_replay(self) -> bool:
-        """
-        Return `True` if `__aiter__` can be called multiple times.
-
-        We need this in cases such determining if we can re-issue a request
-        body when we receive a redirect response.
-        """
-        return True
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        yield b""
-
-    def close(self) -> None:
-        pass
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield b""
-
-    async def aclose(self) -> None:
-        pass
-
-
-class ByteStream(ContentStream):
+class PlainByteStream:
     """
     Request content encoded as plain bytes.
     """
@@ -74,59 +45,41 @@ class ByteStream(ContentStream):
         yield self.body
 
 
-class IteratorStream(ContentStream):
+class GeneratorStream:
     """
-    Request content encoded as plain bytes, using an byte iterator.
+    Request content encoded as plain bytes, using an byte generator.
     """
 
-    def __init__(self, iterator: typing.Iterator[bytes]) -> None:
-        self.iterator = iterator
-        self.is_stream_consumed = False
-
-    def can_replay(self) -> bool:
-        return False
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        return {"Transfer-Encoding": "chunked"}
+    def __init__(self, generator: typing.Iterable[bytes]) -> None:
+        self._generator = generator
+        self._is_stream_consumed = False
 
     def __iter__(self) -> typing.Iterator[bytes]:
-        if self.is_stream_consumed:
+        if self._is_stream_consumed:
             raise StreamConsumed()
-        self.is_stream_consumed = True
-        for part in self.iterator:
+        self._is_stream_consumed = True
+        for part in self._generator:
             yield part
 
-    def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        raise RuntimeError("Attempted to call a async iterator on an sync stream.")
-
 
-class AsyncIteratorStream(ContentStream):
+class AsyncGeneratorStream:
     """
     Request content encoded as plain bytes, using an async byte iterator.
     """
 
-    def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None:
-        self.aiterator = aiterator
-        self.is_stream_consumed = False
-
-    def can_replay(self) -> bool:
-        return False
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        return {"Transfer-Encoding": "chunked"}
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        raise RuntimeError("Attempted to call a sync iterator on an async stream.")
+    def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
+        self._agenerator = agenerator
+        self._is_stream_consumed = False
 
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        if self.is_stream_consumed:
+        if self._is_stream_consumed:
             raise StreamConsumed()
-        self.is_stream_consumed = True
-        async for part in self.aiterator:
+        self._is_stream_consumed = True
+        async for part in self._agenerator:
             yield part
 
 
-class JSONStream(ContentStream):
+class JSONStream:
     """
     Request content encoded as JSON.
     """
@@ -146,7 +99,7 @@ class JSONStream(ContentStream):
         yield self.body
 
 
-class URLEncodedStream(ContentStream):
+class URLEncodedStream:
     """
     Request content as URL encoded form data.
     """
@@ -166,7 +119,7 @@ class URLEncodedStream(ContentStream):
         yield self.body
 
 
-class MultipartStream(ContentStream):
+class MultipartStream:
     """
     Request content as streaming multipart encoded form data.
     """
@@ -208,9 +161,6 @@ class MultipartStream(ContentStream):
             data = self.render_data()
             return len(headers) + len(data)
 
-        def can_replay(self) -> bool:
-            return True
-
         def render(self) -> typing.Iterator[bytes]:
             yield self.render_headers()
             yield self.render_data()
@@ -239,6 +189,7 @@ class MultipartStream(ContentStream):
             self.filename = filename
             self.file = fileobj
             self.content_type = content_type
+            self._consumed = False
 
         def get_length(self) -> int:
             headers = self.render_headers()
@@ -284,16 +235,12 @@ class MultipartStream(ContentStream):
                 yield self._data
                 return
 
-            for chunk in self.file:
-                yield to_bytes(chunk)
-
-            # Get ready for the next replay, if possible.
-            if self.can_replay():
-                assert self.file.seekable()
+            if self._consumed:  # pragma: nocover
                 self.file.seek(0)
+            self._consumed = True
 
-        def can_replay(self) -> bool:
-            return True if isinstance(self.file, (str, bytes)) else self.file.seekable()
+            for chunk in self.file:
+                yield to_bytes(chunk)
 
         def render(self) -> typing.Iterator[bytes]:
             yield self.render_headers()
@@ -346,9 +293,6 @@ class MultipartStream(ContentStream):
 
     # Content stream interface.
 
-    def can_replay(self) -> bool:
-        return all(field.can_replay() for field in self.fields)
-
     def get_headers(self) -> typing.Dict[str, str]:
         content_length = str(self.get_content_length())
         content_type = self.content_type
@@ -369,10 +313,10 @@ def encode_request(
     files: RequestFiles = None,
     json: typing.Any = None,
     boundary: bytes = None,
-) -> ContentStream:
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
     """
     Handles encoding the given `content`, `data`, `files`, and `json`,
-    returning a `ContentStream` implementation.
+    returning a two-tuple of (<headers>, <stream>).
     """
     if data is not None and not isinstance(data, dict):
         # We prefer to seperate `content=<bytes|byte iterator|bytes aiterator>`
@@ -387,39 +331,65 @@ def encode_request(
 
     if content is not None:
         if isinstance(content, (str, bytes)):
-            return ByteStream(body=content)
-        elif hasattr(content, "__aiter__"):
-            content = typing.cast(typing.AsyncIterator[bytes], content)
-            return AsyncIteratorStream(aiterator=content)
-        elif hasattr(content, "__iter__"):
-            content = typing.cast(typing.Iterator[bytes], content)
-            return IteratorStream(iterator=content)
+            byte_stream = PlainByteStream(body=content)
+            headers = byte_stream.get_headers()
+            return headers, byte_stream
+        elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
+            if inspect.isgenerator(content):
+                generator_stream = GeneratorStream(content)  # type: ignore
+                return {"Transfer-Encoding": "chunked"}, generator_stream
+            if inspect.isasyncgen(content):
+                agenerator_stream = AsyncGeneratorStream(content)  # type: ignore
+                return {"Transfer-Encoding": "chunked"}, agenerator_stream
+            return {"Transfer-Encoding": "chunked"}, content  # type: ignore
         else:
             raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
 
     elif data:
         if files:
-            return MultipartStream(data=data, files=files, boundary=boundary)
+            multipart_stream = MultipartStream(
+                data=data, files=files, boundary=boundary
+            )
+            headers = multipart_stream.get_headers()
+            return headers, multipart_stream
         else:
-            return URLEncodedStream(data=data)
+            urlencoded_stream = URLEncodedStream(data=data)
+            headers = urlencoded_stream.get_headers()
+            return headers, urlencoded_stream
 
     elif files:
-        return MultipartStream(data={}, files=files, boundary=boundary)
+        multipart_stream = MultipartStream(data={}, files=files, boundary=boundary)
+        headers = multipart_stream.get_headers()
+        return headers, multipart_stream
 
     elif json is not None:
-        return JSONStream(json=json)
+        json_stream = JSONStream(json=json)
+        headers = json_stream.get_headers()
+        return headers, json_stream
 
-    return ByteStream(body=b"")
+    byte_stream = PlainByteStream(body=b"")
+    headers = byte_stream.get_headers()
+    return headers, byte_stream
 
 
-def encode_response(content: ResponseContent = None) -> ContentStream:
+def encode_response(
+    content: ResponseContent = None,
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
     if content is None:
-        return ByteStream(b"")
+        byte_stream = PlainByteStream(b"")
+        headers = byte_stream.get_headers()
+        return headers, byte_stream
     elif isinstance(content, bytes):
-        return ByteStream(body=content)
-    elif isinstance(content, typing.AsyncIterator):
-        return AsyncIteratorStream(aiterator=content)
-    elif isinstance(content, typing.Iterator):
-        return IteratorStream(iterator=content)
+        byte_stream = PlainByteStream(body=content)
+        headers = byte_stream.get_headers()
+        return headers, byte_stream
+    elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
+        if inspect.isgenerator(content):
+            generator_stream = GeneratorStream(content)  # type: ignore
+            return {"Transfer-Encoding": "chunked"}, generator_stream
+        elif inspect.isasyncgen(content):
+            agenerator_stream = AsyncGeneratorStream(content)  # type: ignore
+            return {"Transfer-Encoding": "chunked"}, agenerator_stream
+        return {"Transfer-Encoding": "chunked"}, content  # type: ignore
 
     raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
index 260d14ee5faba903b7353e826999f5028d5d87b8..3aabb00c9b6367cc3d4b97ac7261ed109c7efa72 100644 (file)
@@ -207,13 +207,6 @@ class TooManyRedirects(RequestError):
     """
 
 
-class RequestBodyUnavailable(RequestError):
-    """
-    Had to send the request again, but the request body was streaming, and is
-    no longer available.
-    """
-
-
 # Client errors
 
 
@@ -283,14 +276,18 @@ class StreamError(Exception):
 
 class StreamConsumed(StreamError):
     """
-    Attempted to read or stream response content, but the content has already
+    Attempted to read or stream content, but the content has already
     been streamed.
     """
 
     def __init__(self) -> None:
         message = (
-            "Attempted to read or stream response content, but the content has "
-            "already been streamed."
+            "Attempted to read or stream some content, but the content has "
+            "already been streamed. For requests, this could be due to passing "
+            "a generator as request content, and then receiving a redirect "
+            "response or a secondary request as part of an authentication flow."
+            "For responses, this could be due to attempting to stream the response "
+            "content more than once."
         )
         super().__init__(message)
 
index 7141e0a87d5aba7df9122cb393b9afde1ee1818b..e750c10f0fd7fb7974a89cb50f73021daeeb7d40 100644 (file)
@@ -13,7 +13,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode
 import rfc3986
 import rfc3986.exceptions
 
-from ._content_streams import ByteStream, ContentStream, encode_request, encode_response
+from ._content_streams import PlainByteStream, encode_request, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
     ContentDecoder,
@@ -37,6 +37,7 @@ from ._exceptions import (
 )
 from ._status_codes import codes
 from ._types import (
+    ByteStream,
     CookieTypes,
     HeaderTypes,
     PrimitiveData,
@@ -606,7 +607,7 @@ class Request:
         data: RequestData = None,
         files: RequestFiles = None,
         json: typing.Any = None,
-        stream: ContentStream = None,
+        stream: ByteStream = None,
     ):
         if isinstance(method, bytes):
             self.method = method.decode("ascii").upper()
@@ -618,14 +619,28 @@ class Request:
             Cookies(cookies).set_cookie_header(self)
 
         if stream is not None:
+            # There's an important distinction between `Request(content=...)`,
+            # and `Request(stream=...)`.
+            #
+            # Using `content=...` implies automatically populated content headers,
+            # of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
+            #
+            # Using `stream=...` will not automatically include any content headers.
+            #
+            # As an end-user you don't really need `stream=...`. It's only
+            # useful when:
+            #
+            # * Preserving the request stream when copying requests, eg for redirects.
+            # * Creating request instances on the *server-side* of the transport API.
             self.stream = stream
+            self._prepare({})
         else:
-            self.stream = encode_request(content, data, files, json)
-
-        self._prepare()
+            headers, stream = encode_request(content, data, files, json)
+            self._prepare(headers)
+            self.stream = stream
 
-    def _prepare(self) -> None:
-        for key, value in self.stream.get_headers().items():
+    def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
+        for key, value in default_headers.items():
             # Ignore Transfer-Encoding if the Content-Length has been set explicitly.
             if key.lower() == "transfer-encoding" and "content-length" in self.headers:
                 continue
@@ -657,11 +672,12 @@ class Request:
         Read and return the request content.
         """
         if not hasattr(self, "_content"):
+            assert isinstance(self.stream, typing.Iterable)
             self._content = b"".join(self.stream)
             # If a streaming request has been read entirely into memory, then
             # we can replace the stream with a raw bytes implementation,
             # to ensure that any non-replayable streams can still be used.
-            self.stream = ByteStream(self._content)
+            self.stream = PlainByteStream(self._content)
         return self._content
 
     async def aread(self) -> bytes:
@@ -669,11 +685,12 @@ class Request:
         Read and return the request content.
         """
         if not hasattr(self, "_content"):
+            assert isinstance(self.stream, typing.AsyncIterable)
             self._content = b"".join([part async for part in self.stream])
             # If a streaming request has been read entirely into memory, then
             # we can replace the stream with a raw bytes implementation,
             # to ensure that any non-replayable streams can still be used.
-            self.stream = ByteStream(self._content)
+            self.stream = PlainByteStream(self._content)
         return self._content
 
     def __repr__(self) -> str:
@@ -690,10 +707,10 @@ class Response:
         request: Request = None,
         http_version: str = None,
         headers: HeaderTypes = None,
-        stream: ContentStream = None,
         content: ResponseContent = None,
+        stream: ByteStream = None,
         history: typing.List["Response"] = None,
-        elapsed_func: typing.Callable = None,
+        on_close: typing.Callable = None,
     ):
         self.status_code = status_code
         self.http_version = http_version
@@ -704,20 +721,41 @@ class Response:
         self.call_next: typing.Optional[typing.Callable] = None
 
         self.history = [] if history is None else list(history)
-        self._elapsed_func = elapsed_func
+        self._on_close = on_close
 
         self.is_closed = False
         self.is_stream_consumed = False
+
         if stream is not None:
-            self._raw_stream = stream
+            # There's an important distinction between `Response(content=...)`,
+            # and `Response(stream=...)`.
+            #
+            # Using `content=...` implies automatically populated content headers,
+            # of either `Content-Length: ...` or `Transfer-Encoding: chunked`.
+            #
+            # Using `stream=...` will not automatically include any content headers.
+            #
+            # As an end-user you don't really need `stream=...`. It's only
+            # useful when creating response instances having received a stream
+            # from the transport API.
+            self.stream = stream
         else:
-            self._raw_stream = encode_response(content)
+            headers, stream = encode_response(content)
+            self._prepare(headers)
+            self.stream = stream
             if content is None or isinstance(content, bytes):
                 # Load the response body, except for streaming content.
                 self.read()
 
         self._num_bytes_downloaded = 0
 
+    def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
+        for key, value in default_headers.items():
+            # Ignore Transfer-Encoding if the Content-Length has been set explicitly.
+            if key.lower() == "transfer-encoding" and "content-length" in self.headers:
+                continue
+            self.headers.setdefault(key, value)
+
     @property
     def elapsed(self) -> datetime.timedelta:
         """
@@ -729,7 +767,11 @@ class Response:
                 "'.elapsed' may only be accessed after the response "
                 "has been read or closed."
             )
-        return datetime.timedelta(seconds=self._elapsed)
+        return self._elapsed
+
+    @elapsed.setter
+    def elapsed(self, elapsed: datetime.timedelta) -> None:
+        self._elapsed = elapsed
 
     @property
     def request(self) -> Request:
@@ -963,11 +1005,13 @@ class Response:
             raise StreamConsumed()
         if self.is_closed:
             raise ResponseClosed()
+        if not isinstance(self.stream, typing.Iterable):
+            raise RuntimeError("Attempted to call a sync iterator on an async stream.")
 
         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
         with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
-            for part in self._raw_stream:
+            for part in self.stream:
                 self._num_bytes_downloaded += len(part)
                 yield part
         self.close()
@@ -992,9 +1036,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            if self._elapsed_func is not None:
-                self._elapsed = self._elapsed_func()
-            self._raw_stream.close()
+            if self._on_close is not None:
+                self._on_close(self)
 
     async def aread(self) -> bytes:
         """
@@ -1047,11 +1090,13 @@ class Response:
             raise StreamConsumed()
         if self.is_closed:
             raise ResponseClosed()
+        if not isinstance(self.stream, typing.AsyncIterable):
+            raise RuntimeError("Attempted to call a async iterator on a sync stream.")
 
         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
         with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
-            async for part in self._raw_stream:
+            async for part in self.stream:
                 self._num_bytes_downloaded += len(part)
                 yield part
         await self.aclose()
@@ -1075,9 +1120,8 @@ class Response:
         """
         if not self.is_closed:
             self.is_closed = True
-            if self._elapsed_func is not None:
-                self._elapsed = await self._elapsed_func()
-            await self._raw_stream.aclose()
+            if self._on_close is not None:
+                await self._on_close(self)
 
 
 class Cookies(MutableMapping):
index bc334ec3bcaed50a4423affb1f9e34abb07a3dba..1b1d3b7817a991994ca5b6654878b7d7e9f9318a 100644 (file)
@@ -7,10 +7,10 @@ from http.cookiejar import CookieJar
 from typing import (
     IO,
     TYPE_CHECKING,
-    AsyncIterator,
+    AsyncIterable,
     Callable,
     Dict,
-    Iterator,
+    Iterable,
     List,
     Mapping,
     Optional,
@@ -66,8 +66,9 @@ AuthTypes = Union[
     None,
 ]
 
-RequestContent = Union[str, bytes, Iterator[bytes], AsyncIterator[bytes]]
-ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]]
+RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
+ResponseContent = Union[bytes, Iterable[bytes], AsyncIterable[bytes]]
+ByteStream = Union[Iterable[bytes], AsyncIterable[bytes]]
 
 RequestData = dict
 
index cc6fd69c000ab5deabcd4928cc05e0033f8084be..e71fe906b0b15a6ddba0867c0b385d757f925257 100644 (file)
@@ -13,16 +13,7 @@ import typing
 import pytest
 
 import httpx
-from httpx import (
-    URL,
-    Auth,
-    BasicAuth,
-    DigestAuth,
-    ProtocolError,
-    Request,
-    RequestBodyUnavailable,
-    Response,
-)
+from httpx import URL, Auth, BasicAuth, DigestAuth, ProtocolError, Request, Response
 from tests.utils import AsyncMockTransport, MockTransport
 
 from ..common import FIXTURES_DIR
@@ -617,13 +608,13 @@ def test_sync_auth_history() -> None:
 async def test_digest_auth_unavailable_streaming_body():
     url = "https://example.org/"
     auth = DigestAuth(username="tomchristie", password="password123")
-    app = App()
+    app = DigestApp()
 
     async def streaming_body():
         yield b"Example request body"  # pragma: nocover
 
     async with httpx.AsyncClient(transport=AsyncMockTransport(app)) as client:
-        with pytest.raises(RequestBodyUnavailable):
+        with pytest.raises(httpx.StreamConsumed):
             await client.post(url, data=streaming_body(), auth=auth)
 
 
index 898fefd7e9368e4815086adc4533b1d032a5b4e7..0d51717a0507a82a2f7929666a933ddd98131051 100644 (file)
@@ -334,7 +334,7 @@ def test_cannot_redirect_streaming_body():
     def streaming_body():
         yield b"Example request body"  # pragma: nocover
 
-    with pytest.raises(httpx.RequestBodyUnavailable):
+    with pytest.raises(httpx.StreamConsumed):
         client.post(url, content=streaming_body())
 
 
index f265a4788e800dda6705337b133f0763503b86f3..66ba887626c9000aad4f3ce103d96c5c0e8eba05 100644 (file)
@@ -1,3 +1,5 @@
+import typing
+
 import pytest
 
 import httpx
@@ -18,6 +20,40 @@ def test_content_length_header():
     assert request.headers["Content-Length"] == "8"
 
 
+def test_iterable_content():
+    class Content:
+        def __iter__(self):
+            yield b"test 123"  # pragma: nocover
+
+    request = httpx.Request("POST", "http://example.org", content=Content())
+    assert request.headers == httpx.Headers(
+        {"Host": "example.org", "Transfer-Encoding": "chunked"}
+    )
+
+
+def test_generator_with_transfer_encoding_header():
+    def content():
+        yield b"test 123"  # pragma: nocover
+
+    request = httpx.Request("POST", "http://example.org", content=content())
+    assert request.headers == httpx.Headers(
+        {"Host": "example.org", "Transfer-Encoding": "chunked"}
+    )
+
+
+def test_generator_with_content_length_header():
+    def content():
+        yield b"test 123"  # pragma: nocover
+
+    headers = {"Content-Length": "8"}
+    request = httpx.Request(
+        "POST", "http://example.org", content=content(), headers=headers
+    )
+    assert request.headers == httpx.Headers(
+        {"Host": "example.org", "Content-Length": "8"}
+    )
+
+
 def test_url_encoded_data():
     request = httpx.Request("POST", "http://example.org", data={"test": "123"})
     request.read()
@@ -51,6 +87,8 @@ def test_read_and_stream_data():
     # Needed for cases such as authentication classes that read the request body.
     request = httpx.Request("POST", "http://example.org", json={"test": 123})
     request.read()
+    assert request.stream is not None
+    assert isinstance(request.stream, typing.Iterable)
     content = b"".join([part for part in request.stream])
     assert content == request.content
 
@@ -61,6 +99,8 @@ async def test_aread_and_stream_data():
     # Needed for cases such as authentication classes that read the request body.
     request = httpx.Request("POST", "http://example.org", json={"test": 123})
     await request.aread()
+    assert request.stream is not None
+    assert isinstance(request.stream, typing.AsyncIterable)
     content = b"".join([part async for part in request.stream])
     assert content == request.content
 
@@ -68,7 +108,7 @@ async def test_aread_and_stream_data():
 @pytest.mark.asyncio
 async def test_cannot_access_content_without_read():
     # Ensure a request may still be streamed if it has been read.
-    # Â Needed for cases such as authentication classes that read the request body.
+    # Needed for cases such as authentication classes that read the request body.
     request = httpx.Request("POST", "http://example.org", json={"test": 123})
     with pytest.raises(httpx.RequestNotRead):
         request.content
index 1e033deba1758bcfe8575bdfde0b71b0dcd27a41..e1568b324dd8f74105962bb3b8dd6ab0034441ad 100644 (file)
@@ -7,6 +7,12 @@ import pytest
 import httpx
 
 
+class StreamingBody:
+    def __iter__(self):
+        yield b"Hello, "
+        yield b"world!"
+
+
 def streaming_body():
     yield b"Hello, "
     yield b"world!"
@@ -230,6 +236,21 @@ def test_read():
     assert response.is_closed
 
 
+def test_empty_read():
+    response = httpx.Response(200)
+
+    assert response.status_code == 200
+    assert response.text == ""
+    assert response.encoding is None
+    assert response.is_closed
+
+    content = response.read()
+
+    assert content == b""
+    assert response.content == b""
+    assert response.is_closed
+
+
 @pytest.mark.asyncio
 async def test_aread():
     response = httpx.Response(
@@ -249,6 +270,22 @@ async def test_aread():
     assert response.is_closed
 
 
+@pytest.mark.asyncio
+async def test_empty_aread():
+    response = httpx.Response(200)
+
+    assert response.status_code == 200
+    assert response.text == ""
+    assert response.encoding is None
+    assert response.is_closed
+
+    content = await response.aread()
+
+    assert content == b""
+    assert response.content == b""
+    assert response.is_closed
+
+
 def test_iter_raw():
     response = httpx.Response(
         200,
@@ -261,6 +298,28 @@ def test_iter_raw():
     assert raw == b"Hello, world!"
 
 
+def test_iter_raw_on_iterable():
+    response = httpx.Response(
+        200,
+        content=StreamingBody(),
+    )
+
+    raw = b""
+    for part in response.iter_raw():
+        raw += part
+    assert raw == b"Hello, world!"
+
+
+def test_iter_raw_on_async():
+    response = httpx.Response(
+        200,
+        content=async_streaming_body(),
+    )
+
+    with pytest.raises(RuntimeError):
+        [part for part in response.iter_raw()]
+
+
 def test_iter_raw_increments_updates_counter():
     response = httpx.Response(200, content=streaming_body())
 
@@ -280,6 +339,17 @@ async def test_aiter_raw():
     assert raw == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_aiter_raw_on_sync():
+    response = httpx.Response(
+        200,
+        content=streaming_body(),
+    )
+
+    with pytest.raises(RuntimeError):
+        [part async for part in response.aiter_raw()]
+
+
 @pytest.mark.asyncio
 async def test_aiter_raw_increments_updates_counter():
     response = httpx.Response(200, content=async_streaming_body())
@@ -610,3 +680,20 @@ def test_cannot_access_unset_request():
 
     with pytest.raises(RuntimeError):
         response.request
+
+
+def test_generator_with_transfer_encoding_header():
+    def content():
+        yield b"test 123"  # pragma: nocover
+
+    response = httpx.Response(200, content=content())
+    assert response.headers == httpx.Headers({"Transfer-Encoding": "chunked"})
+
+
+def test_generator_with_content_length_header():
+    def content():
+        yield b"test 123"  # pragma: nocover
+
+    headers = {"Content-Length": "8"}
+    response = httpx.Response(200, content=content(), headers=headers)
+    assert response.headers == httpx.Headers({"Content-Length": "8"})
index 6801c714f430083769862fd25c40fae3e640eaff..6bc40f405cb481a2933ac142185dba283825663e 100644 (file)
@@ -1,53 +1,48 @@
 import io
+import typing
 
 import pytest
 
 from httpx import StreamConsumed
-from httpx._content_streams import ContentStream, encode_request, encode_response
-
-
-@pytest.mark.asyncio
-async def test_base_content():
-    stream = ContentStream()
-    sync_content = b"".join([part for part in stream])
-    async_content = b"".join([part async for part in stream])
-
-    assert stream.can_replay()
-    assert stream.get_headers() == {}
-    assert sync_content == b""
-    assert async_content == b""
+from httpx._content_streams import encode_request, encode_response
 
 
 @pytest.mark.asyncio
 async def test_empty_content():
-    stream = encode_request()
+    headers, stream = encode_request()
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {}
+    assert headers == {}
     assert sync_content == b""
     assert async_content == b""
 
 
 @pytest.mark.asyncio
 async def test_bytes_content():
-    stream = encode_request(content=b"Hello, world!")
+    headers, stream = encode_request(content=b"Hello, world!")
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {"Content-Length": "13"}
+    assert headers == {"Content-Length": "13"}
     assert sync_content == b"Hello, world!"
     assert async_content == b"Hello, world!"
 
     # Support 'data' for compat with requests.
-    stream = encode_request(data=b"Hello, world!")  # type: ignore
+    headers, stream = encode_request(data=b"Hello, world!")  # type: ignore
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {"Content-Length": "13"}
+    assert headers == {"Content-Length": "13"}
     assert sync_content == b"Hello, world!"
     assert async_content == b"Hello, world!"
 
@@ -58,25 +53,26 @@ async def test_iterator_content():
         yield b"Hello, "
         yield b"world!"
 
-    stream = encode_request(content=hello_world())
+    headers, stream = encode_request(content=hello_world())
+    assert isinstance(stream, typing.Iterable)
+    assert not isinstance(stream, typing.AsyncIterable)
+
     content = b"".join([part for part in stream])
 
-    assert not stream.can_replay()
-    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(RuntimeError):
-        [part async for part in stream]
-
     with pytest.raises(StreamConsumed):
         [part for part in stream]
 
     # Support 'data' for compat with requests.
-    stream = encode_request(data=hello_world())  # type: ignore
+    headers, stream = encode_request(data=hello_world())  # type: ignore
+    assert isinstance(stream, typing.Iterable)
+    assert not isinstance(stream, typing.AsyncIterable)
+
     content = b"".join([part for part in stream])
 
-    assert not stream.can_replay()
-    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
 
@@ -86,36 +82,39 @@ async def test_aiterator_content():
         yield b"Hello, "
         yield b"world!"
 
-    stream = encode_request(content=hello_world())
+    headers, stream = encode_request(content=hello_world())
+    assert not isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     content = b"".join([part async for part in stream])
 
-    assert not stream.can_replay()
-    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(RuntimeError):
-        [part for part in stream]
-
     with pytest.raises(StreamConsumed):
         [part async for part in stream]
 
     # Support 'data' for compat with requests.
-    stream = encode_request(data=hello_world())  # type: ignore
+    headers, stream = encode_request(data=hello_world())  # type: ignore
+    assert not isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     content = b"".join([part async for part in stream])
 
-    assert not stream.can_replay()
-    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_json_content():
-    stream = encode_request(json={"Hello": "world!"})
+    headers, stream = encode_request(json={"Hello": "world!"})
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {
+    assert headers == {
         "Content-Length": "19",
         "Content-Type": "application/json",
     }
@@ -125,12 +124,14 @@ async def test_json_content():
 
 @pytest.mark.asyncio
 async def test_urlencoded_content():
-    stream = encode_request(data={"Hello": "world!"})
+    headers, stream = encode_request(data={"Hello": "world!"})
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {
+    assert headers == {
         "Content-Length": "14",
         "Content-Type": "application/x-www-form-urlencoded",
     }
@@ -141,12 +142,14 @@ async def test_urlencoded_content():
 @pytest.mark.asyncio
 async def test_multipart_files_content():
     files = {"file": io.BytesIO(b"<file content>")}
-    stream = encode_request(files=files, boundary=b"+++")
+    headers, stream = encode_request(files=files, boundary=b"+++")
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {
+    assert headers == {
         "Content-Length": "138",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
@@ -176,12 +179,14 @@ async def test_multipart_files_content():
 async def test_multipart_data_and_files_content():
     data = {"message": "Hello, world!"}
     files = {"file": io.BytesIO(b"<file content>")}
-    stream = encode_request(data=data, files=files, boundary=b"+++")
+    headers, stream = encode_request(data=data, files=files, boundary=b"+++")
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {
+    assert headers == {
         "Content-Length": "210",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
@@ -217,12 +222,14 @@ async def test_multipart_data_and_files_content():
 
 @pytest.mark.asyncio
 async def test_empty_request():
-    stream = encode_request(data={}, files={})
+    headers, stream = encode_request(data={}, files={})
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {}
+    assert headers == {}
     assert sync_content == b""
     assert async_content == b""
 
@@ -238,12 +245,14 @@ async def test_multipart_multiple_files_single_input_content():
         ("file", io.BytesIO(b"<file content 1>")),
         ("file", io.BytesIO(b"<file content 2>")),
     ]
-    stream = encode_request(files=files, boundary=b"+++")
+    headers, stream = encode_request(files=files, boundary=b"+++")
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {
+    assert headers == {
         "Content-Length": "271",
         "Content-Type": "multipart/form-data; boundary=+++",
     }
@@ -281,24 +290,28 @@ async def test_multipart_multiple_files_single_input_content():
 
 @pytest.mark.asyncio
 async def test_response_empty_content():
-    stream = encode_response()
+    headers, stream = encode_response()
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {}
+    assert headers == {}
     assert sync_content == b""
     assert async_content == b""
 
 
 @pytest.mark.asyncio
 async def test_response_bytes_content():
-    stream = encode_response(content=b"Hello, world!")
+    headers, stream = encode_response(content=b"Hello, world!")
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     sync_content = b"".join([part for part in stream])
     async_content = b"".join([part async for part in stream])
 
-    assert stream.can_replay()
-    assert stream.get_headers() == {"Content-Length": "13"}
+    assert headers == {"Content-Length": "13"}
     assert sync_content == b"Hello, world!"
     assert async_content == b"Hello, world!"
 
@@ -309,16 +322,15 @@ async def test_response_iterator_content():
         yield b"Hello, "
         yield b"world!"
 
-    stream = encode_response(content=hello_world())
+    headers, stream = encode_response(content=hello_world())
+    assert isinstance(stream, typing.Iterable)
+    assert not isinstance(stream, typing.AsyncIterable)
+
     content = b"".join([part for part in stream])
 
-    assert not stream.can_replay()
-    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(RuntimeError):
-        [part async for part in stream]
-
     with pytest.raises(StreamConsumed):
         [part for part in stream]
 
@@ -329,16 +341,15 @@ async def test_response_aiterator_content():
         yield b"Hello, "
         yield b"world!"
 
-    stream = encode_response(content=hello_world())
+    headers, stream = encode_response(content=hello_world())
+    assert not isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
     content = b"".join([part async for part in stream])
 
-    assert not stream.can_replay()
-    assert stream.get_headers() == {"Transfer-Encoding": "chunked"}
+    assert headers == {"Transfer-Encoding": "chunked"}
     assert content == b"Hello, world!"
 
-    with pytest.raises(RuntimeError):
-        [part for part in stream]
-
     with pytest.raises(StreamConsumed):
         [part async for part in stream]
 
index 4403946520902c386719e4053e3eebe07ffe2900..9f63faa189804e2ecb0301db1456104248c7cf5b 100644 (file)
@@ -110,9 +110,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = os.urandom(16).hex()
 
-        stream = encode_request(data=data, files=files)
+        headers, stream = encode_request(data=data, files=files)
         assert isinstance(stream, MultipartStream)
-        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
         content = (
@@ -128,7 +127,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
             "--{0}--\r\n"
             "".format(boundary).encode("ascii")
         )
-        assert stream.get_headers()["Content-Length"] == str(len(content))
+        assert headers["Content-Length"] == str(len(content))
         assert b"".join(stream) == content
 
 
@@ -137,9 +136,8 @@ def test_multipart_encode_files_allows_filenames_as_none() -> None:
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = os.urandom(16).hex()
 
-        stream = encode_request(data={}, files=files)
+        headers, stream = encode_request(data={}, files=files)
         assert isinstance(stream, MultipartStream)
-        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
         assert b"".join(stream) == (
@@ -164,9 +162,8 @@ def test_multipart_encode_files_guesses_correct_content_type(
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = os.urandom(16).hex()
 
-        stream = encode_request(data={}, files=files)
+        headers, stream = encode_request(data={}, files=files)
         assert isinstance(stream, MultipartStream)
-        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
         assert b"".join(stream) == (
@@ -188,9 +185,8 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = os.urandom(16).hex()
 
-        stream = encode_request(data={}, files=files)
+        headers, stream = encode_request(data={}, files=files)
         assert isinstance(stream, MultipartStream)
-        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
         content = (
@@ -200,7 +196,7 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
             "--{0}--\r\n"
             "".format(boundary, output).encode("ascii")
         )
-        assert stream.get_headers()["Content-Length"] == str(len(content))
+        assert headers["Content-Length"] == str(len(content))
         assert b"".join(stream) == content
 
 
@@ -214,9 +210,6 @@ def test_multipart_encode_non_seekable_filelike() -> None:
         def __init__(self, iterator: typing.Iterator[bytes]) -> None:
             self._iterator = iterator
 
-        def seekable(self) -> bool:
-            return False
-
         def read(self, *args: typing.Any) -> bytes:
             return b"".join(self._iterator)
 
@@ -226,8 +219,8 @@ def test_multipart_encode_non_seekable_filelike() -> None:
 
     fileobj: typing.Any = IteratorIO(data())
     files = {"file": fileobj}
-    stream = encode_request(files=files, boundary=b"+++")
-    assert not stream.can_replay()
+    headers, stream = encode_request(files=files, boundary=b"+++")
+    assert isinstance(stream, typing.Iterable)
 
     content = (
         b"--+++\r\n"
@@ -237,7 +230,7 @@ def test_multipart_encode_non_seekable_filelike() -> None:
         b"HelloWorld\r\n"
         b"--+++--\r\n"
     )
-    assert stream.get_headers() == {
+    assert headers == {
         "Content-Type": "multipart/form-data; boundary=+++",
         "Content-Length": str(len(content)),
     }
index 87a1a85515de54d238761d8d767256c15cb8426e..ba8a188e78b9d662e55bd167509425af477bb834 100644 (file)
@@ -36,22 +36,11 @@ class MockTransport(httpcore.SyncHTTPTransport):
         stream: httpcore.SyncByteStream = None,
         timeout: Mapping[str, Optional[float]] = None,
     ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
-        request_headers = httpx.Headers(headers)
-        content = (
-            (item for item in stream)
-            if stream
-            and (
-                "Content-Length" in request_headers
-                or "Transfer-Encoding" in request_headers
-            )
-            else None
-        )
-
         request = httpx.Request(
             method=method,
             url=url,
-            headers=request_headers,
-            content=content,
+            headers=headers,
+            stream=stream,
         )
         request.read()
         response = self.handler(request)
@@ -60,13 +49,13 @@ class MockTransport(httpcore.SyncHTTPTransport):
             response.status_code,
             response.reason_phrase.encode("ascii"),
             response.headers.raw,
-            response._raw_stream,
+            response.stream,
         )
 
 
 class AsyncMockTransport(httpcore.AsyncHTTPTransport):
     def __init__(self, handler: Callable) -> None:
-        self.impl = MockTransport(handler)
+        self.handler = handler
 
     async def request(
         self,
@@ -76,28 +65,18 @@ class AsyncMockTransport(httpcore.AsyncHTTPTransport):
         stream: httpcore.AsyncByteStream = None,
         timeout: Mapping[str, Optional[float]] = None,
     ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
-        content = (
-            httpcore.PlainByteStream(b"".join([part async for part in stream]))
-            if stream
-            else httpcore.PlainByteStream(b"")
-        )
-
-        (
-            http_version,
-            status_code,
-            reason_phrase,
-            headers,
-            response_stream,
-        ) = self.impl.request(
-            method, url, headers=headers, stream=content, timeout=timeout
+        request = httpx.Request(
+            method=method,
+            url=url,
+            headers=headers,
+            stream=stream,
         )
-
-        content = httpcore.PlainByteStream(b"".join([part for part in response_stream]))
-
+        await request.aread()
+        response = self.handler(request)
         return (
-            http_version,
-            status_code,
-            reason_phrase,
-            headers,
-            content,
+            (response.http_version or "HTTP/1.1").encode("ascii"),
+            response.status_code,
+            response.reason_phrase.encode("ascii"),
+            response.headers.raw,
+            response.stream,
         )