git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Streaming multipart support (#2382)
author Tom Christie <tom@tomchristie.com>
Mon, 12 Dec 2022 16:14:56 +0000 (16:14 +0000)
committer GitHub <noreply@github.com>
Mon, 12 Dec 2022 16:14:56 +0000 (16:14 +0000)
* Streaming multipart support

* Update tests for streaming multipart

httpx/_multipart.py
tests/test_multipart.py

index 2c08776f4b956c1ad245bbca3d5ef8f23ef1c49f..1d46d96a989b1b7269253046039350d10efe1172 100644 (file)
@@ -135,19 +135,18 @@ class FileField:
         self.file = fileobj
         self.headers = headers
 
-    def get_length(self) -> int:
+    def get_length(self) -> typing.Optional[int]:
         headers = self.render_headers()
 
         if isinstance(self.file, (str, bytes)):
             return len(headers) + len(to_bytes(self.file))
 
-        # Let's do our best not to read `file` into memory.
         file_length = peek_filelike_length(self.file)
+
+        # If we can't determine the filesize without reading it into memory,
+        # then return `None` here, to indicate an unknown file length.
         if file_length is None:
-            # As a last resort, read file and cache contents for later.
-            assert not hasattr(self, "_data")
-            self._data = to_bytes(self.file.read())
-            file_length = len(self._data)
+            return None
 
         return len(headers) + file_length
 
@@ -173,13 +172,11 @@ class FileField:
             yield to_bytes(self.file)
             return
 
-        if hasattr(self, "_data"):
-            # Already rendered.
-            yield self._data
-            return
-
         if hasattr(self.file, "seek"):
-            self.file.seek(0)
+            try:
+                self.file.seek(0)
+            except io.UnsupportedOperation:
+                pass
 
         chunk = self.file.read(self.CHUNK_SIZE)
         while chunk:
@@ -232,24 +229,34 @@ class MultipartStream(SyncByteStream, AsyncByteStream):
             yield b"\r\n"
         yield b"--%s--\r\n" % self.boundary
 
-    def iter_chunks_lengths(self) -> typing.Iterator[int]:
+    def get_content_length(self) -> typing.Optional[int]:
+        """
+        Return the length of the multipart encoded content, or `None` if
+        any of the files have a length that cannot be determined upfront.
+        """
         boundary_length = len(self.boundary)
-        # Follow closely what `.iter_chunks()` does.
+        length = 0
+
         for field in self.fields:
-            yield 2 + boundary_length + 2
-            yield field.get_length()
-            yield 2
-        yield 2 + boundary_length + 4
+            field_length = field.get_length()
+            if field_length is None:
+                return None
+
+            length += 2 + boundary_length + 2  # b"--{boundary}\r\n"
+            length += field_length
+            length += 2  # b"\r\n"
 
-    def get_content_length(self) -> int:
-        return sum(self.iter_chunks_lengths())
+        length += 2 + boundary_length + 4  # b"--{boundary}--\r\n"
+        return length
 
     # Content stream interface.
 
     def get_headers(self) -> typing.Dict[str, str]:
-        content_length = str(self.get_content_length())
+        content_length = self.get_content_length()
         content_type = self.content_type
-        return {"Content-Length": content_length, "Content-Type": content_type}
+        if content_length is None:
+            return {"Transfer-Encoding": "chunked", "Content-Type": content_type}
+        return {"Content-Length": str(content_length), "Content-Type": content_type}
 
     def __iter__(self) -> typing.Iterator[bytes]:
         for chunk in self.iter_chunks():
index e9ce928a16855e7b9dcce55011abe4df067946cc..6d281ed7d09955fbaa90c1548236b9579fa79d7b 100644 (file)
@@ -380,8 +380,9 @@ def test_multipart_encode_files_raises_exception_with_text_mode_file() -> None:
 
 def test_multipart_encode_non_seekable_filelike() -> None:
     """
-    Test that special readable but non-seekable filelike objects are supported,
-    at the cost of reading them into memory at most once.
+    Test that special readable but non-seekable filelike objects are supported.
+    In this case uploads will use 'Transfer-Encoding: chunked', instead of
+    a 'Content-Length' header.
     """
 
     class IteratorIO(io.IOBase):
@@ -410,7 +411,7 @@ def test_multipart_encode_non_seekable_filelike() -> None:
     )
     assert headers == {
         "Content-Type": "multipart/form-data; boundary=+++",
-        "Content-Length": str(len(content)),
+        "Transfer-Encoding": "chunked",
     }
     assert content == b"".join(stream)