Refactor content streams (#1296)
author     Tom Christie <tom@tomchristie.com>
           Fri, 18 Sep 2020 09:50:15 +0000 (10:50 +0100)
committer  GitHub <noreply@github.com>
           Fri, 18 Sep 2020 09:50:15 +0000 (10:50 +0100)
* Refactor content_streams internally

* Tidy up multipart

* Use ByteStream annotation internally

httpx/_content.py [new file with mode: 0644]
httpx/_content_streams.py [deleted file]
httpx/_models.py
httpx/_multipart.py [new file with mode: 0644]
httpx/_types.py
tests/test_content.py [moved from tests/test_content_streams.py with 99% similarity]
tests/test_multipart.py

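For orientation, a minimal usage sketch (illustrative only, not part of the commit) of the (headers, stream) contract exposed by the refactored helpers, assuming the `httpx._content.encode_request` function added in this diff:

    from httpx._content import encode_request

    # Raw bytes: a Content-Length header plus a PlainByteStream.
    headers, stream = encode_request(content=b"Hello, world!")
    assert headers == {"Content-Length": "13"}
    assert b"".join(stream) == b"Hello, world!"

    # Form data: URL-encoded body with Content-Type and Content-Length set.
    headers, stream = encode_request(data={"key": "value"})
    assert headers["Content-Type"] == "application/x-www-form-urlencoded"
    assert b"".join(stream) == b"key=value"

    # JSON body.
    headers, stream = encode_request(json={"key": "value"})
    assert headers["Content-Type"] == "application/json"
    assert b"".join(stream) == b'{"key": "value"}'
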
diff --git a/httpx/_content.py b/httpx/_content.py
new file mode 100644 (file)
index 0000000..8e5d5e9
--- /dev/null
@@ -0,0 +1,166 @@
+import inspect
+import typing
+from json import dumps as json_dumps
+from urllib.parse import urlencode
+
+from ._exceptions import StreamConsumed
+from ._multipart import MultipartStream
+from ._types import (
+    ByteStream,
+    RequestContent,
+    RequestData,
+    RequestFiles,
+    ResponseContent,
+)
+
+
+class PlainByteStream:
+    """
+    Request content encoded as plain bytes.
+    """
+
+    def __init__(self, body: bytes) -> None:
+        self._body = body
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        yield self._body
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        yield self._body
+
+
+class GeneratorStream:
+    """
+    Request content encoded as plain bytes, using a byte generator.
+    """
+
+    def __init__(self, generator: typing.Iterable[bytes]) -> None:
+        self._generator = generator
+        self._is_stream_consumed = False
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        if self._is_stream_consumed:
+            raise StreamConsumed()
+
+        self._is_stream_consumed = True
+        for part in self._generator:
+            yield part
+
+
+class AsyncGeneratorStream:
+    """
+    Request content encoded as plain bytes, using an async byte iterator.
+    """
+
+    def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
+        self._agenerator = agenerator
+        self._is_stream_consumed = False
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        if self._is_stream_consumed:
+            raise StreamConsumed()
+
+        self._is_stream_consumed = True
+        async for part in self._agenerator:
+            yield part
+
+
+def encode_content(
+    content: typing.Union[str, bytes, ByteStream]
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    if isinstance(content, (str, bytes)):
+        body = content.encode("utf-8") if isinstance(content, str) else content
+        content_length = str(len(body))
+        headers = {"Content-Length": content_length} if body else {}
+        stream = PlainByteStream(body)
+        return headers, stream
+
+    elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
+        headers = {"Transfer-Encoding": "chunked"}
+
+        # Generators should be wrapped in GeneratorStream/AsyncGeneratorStream,
+        # which will raise `StreamConsumed` if the stream is accessed more
+        # than once (e.g. when following HTTP 307 or HTTP 308 redirects).
+        if inspect.isgenerator(content):
+            generator_stream = GeneratorStream(content)  # type: ignore
+            return headers, generator_stream
+        if inspect.isasyncgen(content):
+            agenerator_stream = AsyncGeneratorStream(content)  # type: ignore
+            return headers, agenerator_stream
+
+        # Other iterables may be passed through as-is.
+        return headers, content  # type: ignore
+
+    raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
+
+
+def encode_urlencoded_data(
+    data: dict,
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    body = urlencode(data, doseq=True).encode("utf-8")
+    content_length = str(len(body))
+    content_type = "application/x-www-form-urlencoded"
+    headers = {"Content-Length": content_length, "Content-Type": content_type}
+    return headers, PlainByteStream(body)
+
+
+def encode_multipart_data(
+    data: dict, files: RequestFiles, boundary: bytes = None
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    stream = MultipartStream(data=data, files=files, boundary=boundary)
+    headers = stream.get_headers()
+    return headers, stream
+
+
+def encode_json(json: typing.Any) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    body = json_dumps(json).encode("utf-8")
+    content_length = str(len(body))
+    content_type = "application/json"
+    headers = {"Content-Length": content_length, "Content-Type": content_type}
+    return headers, PlainByteStream(body)
+
+
+def encode_request(
+    content: RequestContent = None,
+    data: RequestData = None,
+    files: RequestFiles = None,
+    json: typing.Any = None,
+    boundary: bytes = None,
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    """
+    Handles encoding the given `content`, `data`, `files`, and `json`,
+    returning a two-tuple of (<headers>, <stream>).
+    """
+    if data is not None and not isinstance(data, dict):
+        # We prefer to separate `content=<bytes|str|byte iterator|bytes aiterator>`
+        # for raw request content, and `data=<form data>` for url encoded or
+        # multipart form content.
+        #
+        # However for compat with requests, we *do* still support
+        # `data=<bytes...>` usages. We deal with that case here, treating it
+        # as if `content=<...>` had been supplied instead.
+        return encode_content(data)
+
+    if content is not None:
+        return encode_content(content)
+    elif files:
+        return encode_multipart_data(data or {}, files, boundary)
+    elif data:
+        return encode_urlencoded_data(data)
+    elif json is not None:
+        return encode_json(json)
+
+    return {}, PlainByteStream(b"")
+
+
+def encode_response(
+    content: ResponseContent = None,
+) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
+    """
+    Handles encoding the given `content`, returning a two-tuple of
+    (<headers>, <stream>).
+    """
+    if content is not None:
+        return encode_content(content)
+
+    return {}, PlainByteStream(b"")
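As the comments in encode_content() above note, plain generators are wrapped so that a request body can only be consumed once. A small sketch (illustrative only, not part of the commit) of that behaviour, which is what protects against silently re-sending an exhausted generator on a 307/308 redirect:

    from httpx import StreamConsumed
    from httpx._content import encode_request

    def upload():
        yield b"chunk-1"
        yield b"chunk-2"

    # Generator content is sent chunked, wrapped in GeneratorStream.
    headers, stream = encode_request(content=upload())
    assert headers == {"Transfer-Encoding": "chunked"}
    assert b"".join(stream) == b"chunk-1chunk-2"

    # A second pass raises instead of silently sending an empty body.
    try:
        b"".join(stream)
    except StreamConsumed:
        pass
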
diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py
deleted file mode 100644 (file)
index eacb077..0000000
+++ /dev/null
@@ -1,395 +0,0 @@
-import binascii
-import inspect
-import os
-import typing
-from json import dumps as json_dumps
-from pathlib import Path
-from urllib.parse import urlencode
-
-from ._exceptions import StreamConsumed
-from ._types import (
-    ByteStream,
-    FileContent,
-    FileTypes,
-    RequestContent,
-    RequestData,
-    RequestFiles,
-    ResponseContent,
-)
-from ._utils import (
-    format_form_param,
-    guess_content_type,
-    peek_filelike_length,
-    to_bytes,
-)
-
-
-class PlainByteStream:
-    """
-    Request content encoded as plain bytes.
-    """
-
-    def __init__(self, body: typing.Union[str, bytes]) -> None:
-        self.body = body.encode("utf-8") if isinstance(body, str) else body
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        if not self.body:
-            return {}
-        content_length = str(len(self.body))
-        return {"Content-Length": content_length}
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        yield self.body
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-class GeneratorStream:
-    """
-    Request content encoded as plain bytes, using an byte generator.
-    """
-
-    def __init__(self, generator: typing.Iterable[bytes]) -> None:
-        self._generator = generator
-        self._is_stream_consumed = False
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        if self._is_stream_consumed:
-            raise StreamConsumed()
-        self._is_stream_consumed = True
-        for part in self._generator:
-            yield part
-
-
-class AsyncGeneratorStream:
-    """
-    Request content encoded as plain bytes, using an async byte iterator.
-    """
-
-    def __init__(self, agenerator: typing.AsyncIterable[bytes]) -> None:
-        self._agenerator = agenerator
-        self._is_stream_consumed = False
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        if self._is_stream_consumed:
-            raise StreamConsumed()
-        self._is_stream_consumed = True
-        async for part in self._agenerator:
-            yield part
-
-
-class JSONStream:
-    """
-    Request content encoded as JSON.
-    """
-
-    def __init__(self, json: typing.Any) -> None:
-        self.body = json_dumps(json).encode("utf-8")
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
-        content_type = "application/json"
-        return {"Content-Length": content_length, "Content-Type": content_type}
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        yield self.body
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-class URLEncodedStream:
-    """
-    Request content as URL encoded form data.
-    """
-
-    def __init__(self, data: dict) -> None:
-        self.body = urlencode(data, doseq=True).encode("utf-8")
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
-        content_type = "application/x-www-form-urlencoded"
-        return {"Content-Length": content_length, "Content-Type": content_type}
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        yield self.body
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
-
-
-class MultipartStream:
-    """
-    Request content as streaming multipart encoded form data.
-    """
-
-    class DataField:
-        """
-        A single form field item, within a multipart form field.
-        """
-
-        def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
-            if not isinstance(name, str):
-                raise TypeError("Invalid type for name. Expected str.")
-            if not isinstance(value, (str, bytes)):
-                raise TypeError("Invalid type for value. Expected str or bytes.")
-            self.name = name
-            self.value = value
-
-        def render_headers(self) -> bytes:
-            if not hasattr(self, "_headers"):
-                name = format_form_param("name", self.name)
-                self._headers = b"".join(
-                    [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
-                )
-
-            return self._headers
-
-        def render_data(self) -> bytes:
-            if not hasattr(self, "_data"):
-                self._data = (
-                    self.value
-                    if isinstance(self.value, bytes)
-                    else self.value.encode("utf-8")
-                )
-
-            return self._data
-
-        def get_length(self) -> int:
-            headers = self.render_headers()
-            data = self.render_data()
-            return len(headers) + len(data)
-
-        def render(self) -> typing.Iterator[bytes]:
-            yield self.render_headers()
-            yield self.render_data()
-
-    class FileField:
-        """
-        A single file field item, within a multipart form field.
-        """
-
-        def __init__(self, name: str, value: FileTypes) -> None:
-            self.name = name
-
-            fileobj: FileContent
-
-            if isinstance(value, tuple):
-                try:
-                    filename, fileobj, content_type = value  # type: ignore
-                except ValueError:
-                    filename, fileobj = value  # type: ignore
-                    content_type = guess_content_type(filename)
-            else:
-                filename = Path(str(getattr(value, "name", "upload"))).name
-                fileobj = value
-                content_type = guess_content_type(filename)
-
-            self.filename = filename
-            self.file = fileobj
-            self.content_type = content_type
-            self._consumed = False
-
-        def get_length(self) -> int:
-            headers = self.render_headers()
-
-            if isinstance(self.file, (str, bytes)):
-                return len(headers) + len(self.file)
-
-            # Let's do our best not to read `file` into memory.
-            try:
-                file_length = peek_filelike_length(self.file)
-            except OSError:
-                # As a last resort, read file and cache contents for later.
-                assert not hasattr(self, "_data")
-                self._data = to_bytes(self.file.read())
-                file_length = len(self._data)
-
-            return len(headers) + file_length
-
-        def render_headers(self) -> bytes:
-            if not hasattr(self, "_headers"):
-                parts = [
-                    b"Content-Disposition: form-data; ",
-                    format_form_param("name", self.name),
-                ]
-                if self.filename:
-                    filename = format_form_param("filename", self.filename)
-                    parts.extend([b"; ", filename])
-                if self.content_type is not None:
-                    content_type = self.content_type.encode()
-                    parts.extend([b"\r\nContent-Type: ", content_type])
-                parts.append(b"\r\n\r\n")
-                self._headers = b"".join(parts)
-
-            return self._headers
-
-        def render_data(self) -> typing.Iterator[bytes]:
-            if isinstance(self.file, (str, bytes)):
-                yield to_bytes(self.file)
-                return
-
-            if hasattr(self, "_data"):
-                # Already rendered.
-                yield self._data
-                return
-
-            if self._consumed:  # pragma: nocover
-                self.file.seek(0)
-            self._consumed = True
-
-            for chunk in self.file:
-                yield to_bytes(chunk)
-
-        def render(self) -> typing.Iterator[bytes]:
-            yield self.render_headers()
-            yield from self.render_data()
-
-    def __init__(
-        self, data: typing.Mapping, files: RequestFiles, boundary: bytes = None
-    ) -> None:
-        if boundary is None:
-            boundary = binascii.hexlify(os.urandom(16))
-
-        self.boundary = boundary
-        self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
-            "ascii"
-        )
-        self.fields = list(self._iter_fields(data, files))
-
-    def _iter_fields(
-        self, data: typing.Mapping, files: RequestFiles
-    ) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
-        for name, value in data.items():
-            if isinstance(value, list):
-                for item in value:
-                    yield self.DataField(name=name, value=item)
-            else:
-                yield self.DataField(name=name, value=value)
-
-        file_items = files.items() if isinstance(files, typing.Mapping) else files
-        for name, value in file_items:
-            yield self.FileField(name=name, value=value)
-
-    def iter_chunks(self) -> typing.Iterator[bytes]:
-        for field in self.fields:
-            yield b"--%s\r\n" % self.boundary
-            yield from field.render()
-            yield b"\r\n"
-        yield b"--%s--\r\n" % self.boundary
-
-    def iter_chunks_lengths(self) -> typing.Iterator[int]:
-        boundary_length = len(self.boundary)
-        # Follow closely what `.iter_chunks()` does.
-        for field in self.fields:
-            yield 2 + boundary_length + 2
-            yield field.get_length()
-            yield 2
-        yield 2 + boundary_length + 4
-
-    def get_content_length(self) -> int:
-        return sum(self.iter_chunks_lengths())
-
-    # Content stream interface.
-
-    def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(self.get_content_length())
-        content_type = self.content_type
-        return {"Content-Length": content_length, "Content-Type": content_type}
-
-    def __iter__(self) -> typing.Iterator[bytes]:
-        for chunk in self.iter_chunks():
-            yield chunk
-
-    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        for chunk in self.iter_chunks():
-            yield chunk
-
-
-def encode_request(
-    content: RequestContent = None,
-    data: RequestData = None,
-    files: RequestFiles = None,
-    json: typing.Any = None,
-    boundary: bytes = None,
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
-    """
-    Handles encoding the given `content`, `data`, `files`, and `json`,
-    returning a two-tuple of (<headers>, <stream>).
-    """
-    if data is not None and not isinstance(data, dict):
-        # We prefer to seperate `content=<bytes|byte iterator|bytes aiterator>`
-        # for raw request content, and `data=<form data>` for url encoded or
-        # multipart form content.
-        #
-        # However for compat with requests, we *do* still support
-        # `data=<bytes...>` usages. We deal with that case here, treating it
-        # as if `content=<...>` had been supplied instead.
-        content = data
-        data = None
-
-    if content is not None:
-        if isinstance(content, (str, bytes)):
-            byte_stream = PlainByteStream(body=content)
-            headers = byte_stream.get_headers()
-            return headers, byte_stream
-        elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
-            if inspect.isgenerator(content):
-                generator_stream = GeneratorStream(content)  # type: ignore
-                return {"Transfer-Encoding": "chunked"}, generator_stream
-            if inspect.isasyncgen(content):
-                agenerator_stream = AsyncGeneratorStream(content)  # type: ignore
-                return {"Transfer-Encoding": "chunked"}, agenerator_stream
-            return {"Transfer-Encoding": "chunked"}, content  # type: ignore
-        else:
-            raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
-
-    elif data:
-        if files:
-            multipart_stream = MultipartStream(
-                data=data, files=files, boundary=boundary
-            )
-            headers = multipart_stream.get_headers()
-            return headers, multipart_stream
-        else:
-            urlencoded_stream = URLEncodedStream(data=data)
-            headers = urlencoded_stream.get_headers()
-            return headers, urlencoded_stream
-
-    elif files:
-        multipart_stream = MultipartStream(data={}, files=files, boundary=boundary)
-        headers = multipart_stream.get_headers()
-        return headers, multipart_stream
-
-    elif json is not None:
-        json_stream = JSONStream(json=json)
-        headers = json_stream.get_headers()
-        return headers, json_stream
-
-    byte_stream = PlainByteStream(body=b"")
-    headers = byte_stream.get_headers()
-    return headers, byte_stream
-
-
-def encode_response(
-    content: ResponseContent = None,
-) -> typing.Tuple[typing.Dict[str, str], ByteStream]:
-    if content is None:
-        byte_stream = PlainByteStream(b"")
-        headers = byte_stream.get_headers()
-        return headers, byte_stream
-    elif isinstance(content, bytes):
-        byte_stream = PlainByteStream(body=content)
-        headers = byte_stream.get_headers()
-        return headers, byte_stream
-    elif isinstance(content, (typing.Iterable, typing.AsyncIterable)):
-        if inspect.isgenerator(content):
-            generator_stream = GeneratorStream(content)  # type: ignore
-            return {"Transfer-Encoding": "chunked"}, generator_stream
-        elif inspect.isasyncgen(content):
-            agenerator_stream = AsyncGeneratorStream(content)  # type: ignore
-            return {"Transfer-Encoding": "chunked"}, agenerator_stream
-        return {"Transfer-Encoding": "chunked"}, content  # type: ignore
-
-    raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
diff --git a/httpx/_models.py b/httpx/_models.py
index e750c10f0fd7fb7974a89cb50f73021daeeb7d40..9dd3e0b0d4912e455ad28d8763bb12969eaa8096 100644 (file)
@@ -13,7 +13,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode
 import rfc3986
 import rfc3986.exceptions
 
-from ._content_streams import PlainByteStream, encode_request, encode_response
+from ._content import PlainByteStream, encode_request, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
     ContentDecoder,
diff --git a/httpx/_multipart.py b/httpx/_multipart.py
new file mode 100644 (file)
index 0000000..fc81e6f
--- /dev/null
@@ -0,0 +1,200 @@
+import binascii
+import os
+import typing
+from pathlib import Path
+
+from ._types import FileContent, FileTypes, RequestFiles
+from ._utils import (
+    format_form_param,
+    guess_content_type,
+    peek_filelike_length,
+    to_bytes,
+)
+
+
+class DataField:
+    """
+    A single form field item, within a multipart form field.
+    """
+
+    def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
+        if not isinstance(name, str):
+            raise TypeError("Invalid type for name. Expected str.")
+        if not isinstance(value, (str, bytes)):
+            raise TypeError("Invalid type for value. Expected str or bytes.")
+        self.name = name
+        self.value = value
+
+    def render_headers(self) -> bytes:
+        if not hasattr(self, "_headers"):
+            name = format_form_param("name", self.name)
+            self._headers = b"".join(
+                [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
+            )
+
+        return self._headers
+
+    def render_data(self) -> bytes:
+        if not hasattr(self, "_data"):
+            self._data = (
+                self.value
+                if isinstance(self.value, bytes)
+                else self.value.encode("utf-8")
+            )
+
+        return self._data
+
+    def get_length(self) -> int:
+        headers = self.render_headers()
+        data = self.render_data()
+        return len(headers) + len(data)
+
+    def render(self) -> typing.Iterator[bytes]:
+        yield self.render_headers()
+        yield self.render_data()
+
+
+class FileField:
+    """
+    A single file field item, within a multipart form field.
+    """
+
+    def __init__(self, name: str, value: FileTypes) -> None:
+        self.name = name
+
+        fileobj: FileContent
+
+        if isinstance(value, tuple):
+            try:
+                filename, fileobj, content_type = value  # type: ignore
+            except ValueError:
+                filename, fileobj = value  # type: ignore
+                content_type = guess_content_type(filename)
+        else:
+            filename = Path(str(getattr(value, "name", "upload"))).name
+            fileobj = value
+            content_type = guess_content_type(filename)
+
+        self.filename = filename
+        self.file = fileobj
+        self.content_type = content_type
+        self._consumed = False
+
+    def get_length(self) -> int:
+        headers = self.render_headers()
+
+        if isinstance(self.file, (str, bytes)):
+            return len(headers) + len(self.file)
+
+        # Let's do our best not to read `file` into memory.
+        try:
+            file_length = peek_filelike_length(self.file)
+        except OSError:
+            # As a last resort, read file and cache contents for later.
+            assert not hasattr(self, "_data")
+            self._data = to_bytes(self.file.read())
+            file_length = len(self._data)
+
+        return len(headers) + file_length
+
+    def render_headers(self) -> bytes:
+        if not hasattr(self, "_headers"):
+            parts = [
+                b"Content-Disposition: form-data; ",
+                format_form_param("name", self.name),
+            ]
+            if self.filename:
+                filename = format_form_param("filename", self.filename)
+                parts.extend([b"; ", filename])
+            if self.content_type is not None:
+                content_type = self.content_type.encode()
+                parts.extend([b"\r\nContent-Type: ", content_type])
+            parts.append(b"\r\n\r\n")
+            self._headers = b"".join(parts)
+
+        return self._headers
+
+    def render_data(self) -> typing.Iterator[bytes]:
+        if isinstance(self.file, (str, bytes)):
+            yield to_bytes(self.file)
+            return
+
+        if hasattr(self, "_data"):
+            # Already rendered.
+            yield self._data
+            return
+
+        if self._consumed:  # pragma: nocover
+            self.file.seek(0)
+        self._consumed = True
+
+        for chunk in self.file:
+            yield to_bytes(chunk)
+
+    def render(self) -> typing.Iterator[bytes]:
+        yield self.render_headers()
+        yield from self.render_data()
+
+
+class MultipartStream:
+    """
+    Request content as streaming multipart encoded form data.
+    """
+
+    def __init__(self, data: dict, files: RequestFiles, boundary: bytes = None) -> None:
+        if boundary is None:
+            boundary = binascii.hexlify(os.urandom(16))
+
+        self.boundary = boundary
+        self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
+            "ascii"
+        )
+        self.fields = list(self._iter_fields(data, files))
+
+    def _iter_fields(
+        self, data: dict, files: RequestFiles
+    ) -> typing.Iterator[typing.Union[FileField, DataField]]:
+        for name, value in data.items():
+            if isinstance(value, list):
+                for item in value:
+                    yield DataField(name=name, value=item)
+            else:
+                yield DataField(name=name, value=value)
+
+        file_items = files.items() if isinstance(files, typing.Mapping) else files
+        for name, value in file_items:
+            yield FileField(name=name, value=value)
+
+    def iter_chunks(self) -> typing.Iterator[bytes]:
+        for field in self.fields:
+            yield b"--%s\r\n" % self.boundary
+            yield from field.render()
+            yield b"\r\n"
+        yield b"--%s--\r\n" % self.boundary
+
+    def iter_chunks_lengths(self) -> typing.Iterator[int]:
+        boundary_length = len(self.boundary)
+        # Follow closely what `.iter_chunks()` does.
+        for field in self.fields:
+            yield 2 + boundary_length + 2
+            yield field.get_length()
+            yield 2
+        yield 2 + boundary_length + 4
+
+    def get_content_length(self) -> int:
+        return sum(self.iter_chunks_lengths())
+
+    # Content stream interface.
+
+    def get_headers(self) -> typing.Dict[str, str]:
+        content_length = str(self.get_content_length())
+        content_type = self.content_type
+        return {"Content-Length": content_length, "Content-Type": content_type}
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        for chunk in self.iter_chunks():
+            yield chunk
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        for chunk in self.iter_chunks():
+            yield chunk
diff --git a/httpx/_types.py b/httpx/_types.py
index 1b1d3b7817a991994ca5b6654878b7d7e9f9318a..776df1d8dc9b46313ff87ba204a680425f235466 100644 (file)
@@ -66,9 +66,9 @@ AuthTypes = Union[
     None,
 ]
 
-RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
-ResponseContent = Union[bytes, Iterable[bytes], AsyncIterable[bytes]]
 ByteStream = Union[Iterable[bytes], AsyncIterable[bytes]]
+RequestContent = Union[str, bytes, ByteStream]
+ResponseContent = Union[str, bytes, ByteStream]
 
 RequestData = dict
 
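With ByteStream now spelled as Union[Iterable[bytes], AsyncIterable[bytes]], any plain iterable of byte chunks qualifies and is passed through encode_content() untouched, using chunked transfer encoding (illustrative sketch, not part of the commit):

    from httpx._content import encode_request

    chunks = [b"part-1", b"part-2"]
    headers, stream = encode_request(content=chunks)
    assert headers == {"Transfer-Encoding": "chunked"}
    assert stream is chunks  # non-generator iterables are not wrapped
    assert b"".join(stream) == b"part-1part-2"
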
diff --git a/tests/test_content_streams.py b/tests/test_content.py
similarity index 99%
rename from tests/test_content_streams.py
rename to tests/test_content.py
index 6bc40f405cb481a2933ac142185dba283825663e..384f9f228754fafad02bf9c398c236b4c3aa0806 100644 (file)
@@ -4,7 +4,7 @@ import typing
 import pytest
 
 from httpx import StreamConsumed
-from httpx._content_streams import encode_request, encode_response
+from httpx._content import encode_request, encode_response
 
 
 @pytest.mark.asyncio
diff --git a/tests/test_multipart.py b/tests/test_multipart.py
index 9f63faa189804e2ecb0301db1456104248c7cf5b..d2932ee0986220dc0d2c2ca6a3f3e6065ffe239d 100644 (file)
@@ -7,7 +7,7 @@ from unittest import mock
 import pytest
 
 import httpx
-from httpx._content_streams import MultipartStream, encode_request
+from httpx._content import encode_request
 from httpx._utils import format_form_param
 from tests.utils import MockTransport
 
@@ -111,9 +111,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
         boundary = os.urandom(16).hex()
 
         headers, stream = encode_request(data=data, files=files)
-        assert isinstance(stream, MultipartStream)
+        assert isinstance(stream, typing.Iterable)
 
-        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
         content = (
             '--{0}\r\nContent-Disposition: form-data; name="a"\r\n\r\n1\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="b"\r\n\r\nC\r\n'
@@ -127,8 +126,11 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
             "--{0}--\r\n"
             "".format(boundary).encode("ascii")
         )
-        assert headers["Content-Length"] == str(len(content))
-        assert b"".join(stream) == content
+        assert headers == {
+            "Content-Type": f"multipart/form-data; boundary={boundary}",
+            "Content-Length": str(len(content)),
+        }
+        assert content == b"".join(stream)
 
 
 def test_multipart_encode_files_allows_filenames_as_none() -> None:
@@ -137,14 +139,18 @@ def test_multipart_encode_files_allows_filenames_as_none() -> None:
         boundary = os.urandom(16).hex()
 
         headers, stream = encode_request(data={}, files=files)
-        assert isinstance(stream, MultipartStream)
+        assert isinstance(stream, typing.Iterable)
 
-        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
-        assert b"".join(stream) == (
+        content = (
             '--{0}\r\nContent-Disposition: form-data; name="file"\r\n\r\n'
             "<file content>\r\n--{0}--\r\n"
             "".format(boundary).encode("ascii")
         )
+        assert headers == {
+            "Content-Type": f"multipart/form-data; boundary={boundary}",
+            "Content-Length": str(len(content)),
+        }
+        assert content == b"".join(stream)
 
 
 @pytest.mark.parametrize(
@@ -163,15 +169,19 @@ def test_multipart_encode_files_guesses_correct_content_type(
         boundary = os.urandom(16).hex()
 
         headers, stream = encode_request(data={}, files=files)
-        assert isinstance(stream, MultipartStream)
+        assert isinstance(stream, typing.Iterable)
 
-        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
-        assert b"".join(stream) == (
+        content = (
             f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
             f'filename="{file_name}"\r\nContent-Type: '
             f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
             "".encode("ascii")
         )
+        assert headers == {
+            "Content-Type": f"multipart/form-data; boundary={boundary}",
+            "Content-Length": str(len(content)),
+        }
+        assert content == b"".join(stream)
 
 
 @pytest.mark.parametrize(
@@ -186,9 +196,8 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
         boundary = os.urandom(16).hex()
 
         headers, stream = encode_request(data={}, files=files)
-        assert isinstance(stream, MultipartStream)
+        assert isinstance(stream, typing.Iterable)
 
-        assert stream.content_type == f"multipart/form-data; boundary={boundary}"
         content = (
             '--{0}\r\nContent-Disposition: form-data; name="file"; '
             'filename="test.txt"\r\n'
@@ -196,8 +205,11 @@ def test_multipart_encode_files_allows_bytes_or_str_content(
             "--{0}--\r\n"
             "".format(boundary, output).encode("ascii")
         )
-        assert headers["Content-Length"] == str(len(content))
-        assert b"".join(stream) == content
+        assert headers == {
+            "Content-Type": f"multipart/form-data; boundary={boundary}",
+            "Content-Length": str(len(content)),
+        }
+        assert content == b"".join(stream)
 
 
 def test_multipart_encode_non_seekable_filelike() -> None:
@@ -234,7 +246,7 @@ def test_multipart_encode_non_seekable_filelike() -> None:
         "Content-Type": "multipart/form-data; boundary=+++",
         "Content-Length": str(len(content)),
     }
-    assert b"".join(stream) == content
+    assert content == b"".join(stream)
 
 
 class TestHeaderParamHTML5Formatting: