Implement streaming multipart uploads (#857)
author Florimond Manca <florimond.manca@gmail.com>
Fri, 10 Apr 2020 18:40:04 +0000 (20:40 +0200)
committer GitHub <noreply@github.com>
Fri, 10 Apr 2020 18:40:04 +0000 (20:40 +0200)
* Implement streaming multipart uploads

* Tweak seekable check

* Don't handle duplicate computation yet

* Add memory test for multipart streaming

* Lint

* 1 pct is enough

* Tweak lazy computations, fallback to non-streaming for broken filelikes

* Reduce diff size

* Drop memory test

* Cleanup

httpx/_content_streams.py
httpx/_utils.py
tests/test_multipart.py
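
From the caller's perspective, this change means an open file handle passed via files= is sent chunk by chunk instead of being buffered into memory first. A minimal usage sketch (the URL and file name below are hypothetical, not part of this commit):

import httpx

# The open binary file is streamed as one part of the multipart body.
# Content-Length is still computed up front, via the new length-peeking logic.
with open("large-upload.bin", "rb") as f:
    response = httpx.post(
        "http://localhost:8000/upload",
        data={"description": "nightly build"},
        files={"file": ("large-upload.bin", f, "application/octet-stream")},
    )
print(response.status_code)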

httpx/_content_streams.py
index 362fdd58a0be7ff5610998dba2a5c55bbd4a5071..fa8616af20a97e4fc5b1bb7cb1f21ccd492745c2 100644 (file)
@@ -1,8 +1,6 @@
 import binascii
-import mimetypes
 import os
 import typing
-from io import BytesIO
 from json import dumps as json_dumps
 from pathlib import Path
 from urllib.parse import urlencode
@@ -11,7 +9,12 @@ import httpcore
 
 from ._exceptions import StreamConsumed
 from ._types import StrOrBytes
-from ._utils import format_form_param
+from ._utils import (
+    format_form_param,
+    guess_content_type,
+    peek_filelike_length,
+    to_bytes,
+)
 
 RequestData = typing.Union[
     dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]
@@ -195,7 +198,7 @@ class URLEncodedStream(ContentStream):
 
 class MultipartStream(ContentStream):
     """
-    Request content as multipart encoded form data.
+    Request content as streaming multipart encoded form data.
     """
 
     class DataField:
@@ -212,15 +215,35 @@ class MultipartStream(ContentStream):
             self.value = value
 
         def render_headers(self) -> bytes:
-            name = format_form_param("name", self.name)
-            return b"".join([b"Content-Disposition: form-data; ", name, b"\r\n\r\n"])
+            if not hasattr(self, "_headers"):
+                name = format_form_param("name", self.name)
+                self._headers = b"".join(
+                    [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
+                )
+
+            return self._headers
 
         def render_data(self) -> bytes:
-            return (
-                self.value
-                if isinstance(self.value, bytes)
-                else self.value.encode("utf-8")
-            )
+            if not hasattr(self, "_data"):
+                self._data = (
+                    self.value
+                    if isinstance(self.value, bytes)
+                    else self.value.encode("utf-8")
+                )
+
+            return self._data
+
+        def get_length(self) -> int:
+            headers = self.render_headers()
+            data = self.render_data()
+            return len(headers) + len(data)
+
+        def can_replay(self) -> bool:
+            return True
+
+        def render(self) -> typing.Iterator[bytes]:
+            yield self.render_headers()
+            yield self.render_data()
 
     class FileField:
         """
@@ -235,67 +258,88 @@ class MultipartStream(ContentStream):
             self.name = name
             if not isinstance(value, tuple):
                 self.filename = Path(str(getattr(value, "name", "upload"))).name
-                self.file = (
-                    value
-                )  # type: typing.Union[typing.IO[str], typing.IO[bytes]]
-                self.content_type = self.guess_content_type()
+                self.file: typing.Union[typing.IO[str], typing.IO[bytes]] = value
+                self.content_type = guess_content_type(self.filename)
             else:
                 self.filename = value[0]
                 self.file = value[1]
                 self.content_type = (
-                    value[2] if len(value) > 2 else self.guess_content_type()
+                    value[2] if len(value) > 2 else guess_content_type(self.filename)
                 )
 
-        def guess_content_type(self) -> typing.Optional[str]:
-            if self.filename:
-                return (
-                    mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
-                )
-            else:
-                return None
+        def get_length(self) -> int:
+            headers = self.render_headers()
 
-        def render_headers(self) -> bytes:
-            parts = [
-                b"Content-Disposition: form-data; ",
-                format_form_param("name", self.name),
-            ]
-            if self.filename:
-                filename = format_form_param("filename", self.filename)
-                parts.extend([b"; ", filename])
-            if self.content_type is not None:
-                content_type = self.content_type.encode()
-                parts.extend([b"\r\nContent-Type: ", content_type])
-            parts.append(b"\r\n\r\n")
-            return b"".join(parts)
+            if isinstance(self.file, str):
+                return len(headers) + len(self.file)
 
-        def render_data(self) -> bytes:
-            content: typing.Union[str, bytes]
+            # Let's do our best not to read `file` into memory.
+            try:
+                file_length = peek_filelike_length(self.file)
+            except OSError:
+                # As a last resort, read file and cache contents for later.
+                assert not hasattr(self, "_data")
+                self._data = to_bytes(self.file.read())
+                file_length = len(self._data)
+
+            return len(headers) + file_length
+
+        def render_headers(self) -> bytes:
+            if not hasattr(self, "_headers"):
+                parts = [
+                    b"Content-Disposition: form-data; ",
+                    format_form_param("name", self.name),
+                ]
+                if self.filename:
+                    filename = format_form_param("filename", self.filename)
+                    parts.extend([b"; ", filename])
+                if self.content_type is not None:
+                    content_type = self.content_type.encode()
+                    parts.extend([b"\r\nContent-Type: ", content_type])
+                parts.append(b"\r\n\r\n")
+                self._headers = b"".join(parts)
+
+            return self._headers
+
+        def render_data(self) -> typing.Iterator[bytes]:
             if isinstance(self.file, str):
-                content = self.file
-            else:
-                content = self.file.read()
-            return content.encode("utf-8") if isinstance(content, str) else content
+                yield to_bytes(self.file)
+                return
 
-    def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
-        body = BytesIO()
-        if boundary is None:
-            boundary = binascii.hexlify(os.urandom(16))
+            if hasattr(self, "_data"):
+                # Already rendered.
+                yield self._data
+                return
+
+            for chunk in self.file:
+                yield to_bytes(chunk)
+
+            # Get ready for the next replay, if possible.
+            if self.can_replay():
+                assert self.file.seekable()
+                self.file.seek(0)
 
-        for field in self.iter_fields(data, files):
-            body.write(b"--%s\r\n" % boundary)
-            body.write(field.render_headers())
-            body.write(field.render_data())
-            body.write(b"\r\n")
+        def can_replay(self) -> bool:
+            return True if isinstance(self.file, str) else self.file.seekable()
 
-        body.write(b"--%s--\r\n" % boundary)
+        def render(self) -> typing.Iterator[bytes]:
+            yield self.render_headers()
+            yield from self.render_data()
 
+    def __init__(
+        self, data: typing.Mapping, files: typing.Mapping, boundary: bytes = None
+    ) -> None:
+        if boundary is None:
+            boundary = binascii.hexlify(os.urandom(16))
+
+        self.boundary = boundary
         self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
             "ascii"
         )
-        self.body = body.getvalue()
+        self.fields = list(self._iter_fields(data, files))
 
-    def iter_fields(
-        self, data: dict, files: dict
+    def _iter_fields(
+        self, data: typing.Mapping, files: typing.Mapping
     ) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
         for name, value in data.items():
             if isinstance(value, list):
@@ -307,16 +351,42 @@ class MultipartStream(ContentStream):
         for name, value in files.items():
             yield self.FileField(name=name, value=value)
 
+    def iter_chunks(self) -> typing.Iterator[bytes]:
+        for field in self.fields:
+            yield b"--%s\r\n" % self.boundary
+            yield from field.render()
+            yield b"\r\n"
+        yield b"--%s--\r\n" % self.boundary
+
+    def iter_chunks_lengths(self) -> typing.Iterator[int]:
+        boundary_length = len(self.boundary)
+        # Follow closely what `.iter_chunks()` does.
+        for field in self.fields:
+            yield 2 + boundary_length + 2
+            yield field.get_length()
+            yield 2
+        yield 2 + boundary_length + 4
+
+    def get_content_length(self) -> int:
+        return sum(self.iter_chunks_lengths())
+
+    # Content stream interface.
+
+    def can_replay(self) -> bool:
+        return all(field.can_replay() for field in self.fields)
+
     def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(len(self.body))
+        content_length = str(self.get_content_length())
         content_type = self.content_type
         return {"Content-Length": content_length, "Content-Type": content_type}
 
     def __iter__(self) -> typing.Iterator[bytes]:
-        yield self.body
+        for chunk in self.iter_chunks():
+            yield chunk
 
     async def __aiter__(self) -> typing.AsyncIterator[bytes]:
-        yield self.body
+        for chunk in self.iter_chunks():
+            yield chunk
 
 
 def encode(
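
The diff above replaces the eagerly built BytesIO body with a list of fields that render themselves lazily. A small sketch of the resulting behaviour, using the internal encode() helper the same way the tests below do (internal API, shown only for illustration):

import io

from httpx._content_streams import encode

# Build a multipart stream with a fixed boundary, as the tests do.
stream = encode(
    data={"a": "1"},
    files={"file": ("name.txt", io.BytesIO(b"<file content>"))},
    boundary=b"+++",
)

# Content-Length is derived from iter_chunks_lengths(), i.e. from the rendered
# field headers plus the peeked file lengths, without rendering the whole body.
print(stream.get_headers())

# BytesIO is seekable, so the file field can be rewound and replayed.
assert stream.can_replay()

# Iterating the stream yields the boundary, headers and file chunks lazily.
body = b"".join(stream)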
httpx/_utils.py
index 83312fd60b4e380a309333168063b1be5b2f8306..b2be9f552815903e93237fa940552e7bc3d0d9e1 100644 (file)
@@ -2,6 +2,7 @@ import codecs
 import collections
 import contextlib
 import logging
+import mimetypes
 import netrc
 import os
 import re
@@ -310,6 +311,38 @@ def unquote(value: str) -> str:
     return value[1:-1] if value[0] == value[-1] == '"' else value
 
 
+def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
+    if filename:
+        return mimetypes.guess_type(filename)[0] or "application/octet-stream"
+    return None
+
+
+def peek_filelike_length(stream: typing.IO) -> int:
+    """
+    Given a file-like stream object, return its length in number of bytes
+    without reading it into memory.
+    """
+    try:
+        # Is it an actual file?
+        fd = stream.fileno()
+    except OSError:
+        # No... Maybe it's something that supports random access, like `io.BytesIO`?
+        try:
+            # Assuming so, go to end of stream to figure out its length,
+            # then put it back in place.
+            offset = stream.tell()
+            length = stream.seek(0, os.SEEK_END)
+            stream.seek(offset)
+        except OSError:
+            # Not even that? Sorry, we're doomed...
+            raise
+        else:
+            return length
+    else:
+        # Yup, seems to be an actual file.
+        return os.fstat(fd).st_size
+
+
 def flatten_queryparams(
     queryparams: typing.Mapping[
         str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
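
A short sketch of how the new helper behaves in the three cases it distinguishes (illustrative only; peek_filelike_length lives in the private httpx._utils module):

import io
import os

from httpx._utils import peek_filelike_length

# Case 1: a real file exposes a file descriptor, so the size comes from os.fstat().
with open(__file__, "rb") as f:
    assert peek_filelike_length(f) == os.path.getsize(__file__)

# Case 2: io.BytesIO has no usable fileno(), but supports random access, so the
# length is found by seeking to the end and restoring the original position.
buffer = io.BytesIO(b"hello world")
buffer.seek(3)
assert peek_filelike_length(buffer) == 11
assert buffer.tell() == 3

# Case 3: non-seekable objects raise OSError, which FileField.get_length()
# handles by reading the content into memory once and caching it.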
tests/test_multipart.py
index 3ebddc1d11e7e8dd455b8f7dfc8beccbc596a5aa..67f6a30697f25a76f93954a289537aa4c7b0386b 100644 (file)
@@ -101,20 +101,27 @@ async def test_multipart_file_tuple():
     assert multipart["file"] == [b"<file content>"]
 
 
-def test_multipart_encode():
+def test_multipart_encode(tmp_path: typing.Any) -> None:
+    path = str(tmp_path / "name.txt")
+    with open(path, "wb") as f:
+        f.write(b"<file content>")
+
     data = {
         "a": "1",
         "b": b"C",
         "c": ["11", "22", "33"],
         "d": "",
     }
-    files = {"file": ("name.txt", io.BytesIO(b"<file content>"))}
+    files = {"file": ("name.txt", open(path, "rb"))}
 
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
+
         stream = encode(data=data, files=files)
+        assert stream.can_replay()
+
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
-        assert stream.body == (
+        content = (
             '--{0}\r\nContent-Disposition: form-data; name="a"\r\n\r\n1\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="b"\r\n\r\nC\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n11\r\n'
@@ -127,17 +134,20 @@ def test_multipart_encode():
             "--{0}--\r\n"
             "".format(boundary).encode("ascii")
         )
+        assert stream.get_headers()["Content-Length"] == str(len(content))
+        assert b"".join(stream) == content
 
 
-def test_multipart_encode_files_allows_filenames_as_none():
+def test_multipart_encode_files_allows_filenames_as_none() -> None:
     files = {"file": (None, io.BytesIO(b"<file content>"))}
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
         stream = encode(data={}, files=files)
+        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
-        assert stream.body == (
+        assert b"".join(stream) == (
             '--{0}\r\nContent-Disposition: form-data; name="file"\r\n\r\n'
             "<file content>\r\n--{0}--\r\n"
             "".format(boundary).encode("ascii")
@@ -153,16 +163,17 @@ def test_multipart_encode_files_allows_filenames_as_none():
     ],
 )
 def test_multipart_encode_files_guesses_correct_content_type(
-    file_name, expected_content_type
-):
+    file_name: str, expected_content_type: str
+) -> None:
     files = {"file": (file_name, io.BytesIO(b"<file content>"))}
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
         stream = encode(data={}, files=files)
+        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
-        assert stream.body == (
+        assert b"".join(stream) == (
             f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
             f'filename="{file_name}"\r\nContent-Type: '
             f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
@@ -170,21 +181,64 @@ def test_multipart_encode_files_guesses_correct_content_type(
         )
 
 
-def test_multipart_encode_files_allows_str_content():
+def test_multipart_encode_files_allows_str_content() -> None:
     files = {"file": ("test.txt", "<string content>", "text/plain")}
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
         stream = encode(data={}, files=files)
+        assert stream.can_replay()
 
         assert stream.content_type == f"multipart/form-data; boundary={boundary}"
-        assert stream.body == (
+        content = (
             '--{0}\r\nContent-Disposition: form-data; name="file"; '
             'filename="test.txt"\r\n'
             "Content-Type: text/plain\r\n\r\n<string content>\r\n"
             "--{0}--\r\n"
             "".format(boundary).encode("ascii")
         )
+        assert stream.get_headers()["Content-Length"] == str(len(content))
+        assert b"".join(stream) == content
+
+
+def test_multipart_encode_non_seekable_filelike() -> None:
+    """
+    Test that special readable but non-seekable filelike objects are supported,
+    at the cost of reading them into memory at most once.
+    """
+
+    class IteratorIO(io.IOBase):
+        def __init__(self, iterator: typing.Iterator[bytes]) -> None:
+            self._iterator = iterator
+
+        def seekable(self) -> bool:
+            return False
+
+        def read(self, *args: typing.Any) -> bytes:
+            return b"".join(self._iterator)
+
+    def data() -> typing.Iterator[bytes]:
+        yield b"Hello"
+        yield b"World"
+
+    fileobj = IteratorIO(data())
+    files = {"file": fileobj}
+    stream = encode(files=files, boundary=b"+++")
+    assert not stream.can_replay()
+
+    content = (
+        b"--+++\r\n"
+        b'Content-Disposition: form-data; name="file"; filename="upload"\r\n'
+        b"Content-Type: application/octet-stream\r\n"
+        b"\r\n"
+        b"HelloWorld\r\n"
+        b"--+++--\r\n"
+    )
+    assert stream.get_headers() == {
+        "Content-Type": "multipart/form-data; boundary=+++",
+        "Content-Length": str(len(content)),
+    }
+    assert b"".join(stream) == content
 
 
 class TestHeaderParamHTML5Formatting: