]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Allow custom headers in multipart/form-data requests (#1936)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Thu, 13 Jan 2022 08:49:14 +0000 (00:49 -0800)
committerGitHub <noreply@github.com>
Thu, 13 Jan 2022 08:49:14 +0000 (08:49 +0000)
* feat: allow passing multipart headers

* Add test for including content-type in headers

* lint

* override content_type with headers

* compare tuples based on length

* incorporate suggestion

* remove .title() on headers

httpx/_multipart.py
httpx/_types.py
tests/test_multipart.py

index 51ba556a776b57a8a37a4fe20f2f4f071ae28b5b..34ee631557b197b96849ab27f9ff6971b3f96399 100644 (file)
@@ -78,23 +78,41 @@ class FileField:
 
         fileobj: FileContent
 
+        headers: typing.Dict[str, str] = {}
+        content_type: typing.Optional[str] = None
+
+        # This large tuple based API largely mirror's requests' API
+        # It would be good to think of better APIs for this that we could include in httpx 2.0
+        # since variable length tuples (especially of 4 elements) are quite unwieldly
         if isinstance(value, tuple):
-            try:
-                filename, fileobj, content_type = value  # type: ignore
-            except ValueError:
+            if len(value) == 2:
+                # neither the 3rd parameter (content_type) nor the 4th (headers) was included
                 filename, fileobj = value  # type: ignore
-                content_type = guess_content_type(filename)
+            elif len(value) == 3:
+                filename, fileobj, content_type = value  # type: ignore
+            else:
+                # all 4 parameters included
+                filename, fileobj, content_type, headers = value  # type: ignore
         else:
             filename = Path(str(getattr(value, "name", "upload"))).name
             fileobj = value
+
+        if content_type is None:
             content_type = guess_content_type(filename)
 
+        has_content_type_header = any("content-type" in key.lower() for key in headers)
+        if content_type is not None and not has_content_type_header:
+            # note that unlike requests, we ignore the content_type
+            # provided in the 3rd tuple element if it is also included in the headers
+            # requests does the opposite (it overwrites the header with the 3rd tuple element)
+            headers["Content-Type"] = content_type
+
         if isinstance(fileobj, (str, io.StringIO)):
             raise TypeError(f"Expected bytes or bytes-like object got: {type(fileobj)}")
 
         self.filename = filename
         self.file = fileobj
-        self.content_type = content_type
+        self.headers = headers
         self._consumed = False
 
     def get_length(self) -> int:
@@ -122,9 +140,9 @@ class FileField:
             if self.filename:
                 filename = format_form_param("filename", self.filename)
                 parts.extend([b"; ", filename])
-            if self.content_type is not None:
-                content_type = self.content_type.encode()
-                parts.extend([b"\r\nContent-Type: ", content_type])
+            for header_name, header_value in self.headers.items():
+                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
+                parts.extend([key, val])
             parts.append(b"\r\n\r\n")
             self._headers = b"".join(parts)
 
index 8cd85cd93307dc38292b0e3999c38028c12832ea..f7ba4486cc84ce548c0c65127af851a68a8f6010 100644 (file)
@@ -89,6 +89,8 @@ FileTypes = Union[
     Tuple[Optional[str], FileContent],
     # (filename, file (or bytes), content_type)
     Tuple[Optional[str], FileContent, Optional[str]],
+    # (filename, file (or bytes), content_type, headers)
+    Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
 ]
 RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
 
index cd71a246b38874f2fc76877ee69c527dbffdf918..9980cb5b4eb1c3b78c989fb068e6de45a6cca698 100644 (file)
@@ -94,6 +94,58 @@ def test_multipart_file_tuple():
     assert multipart["file"] == [b"<file content>"]
 
 
+@pytest.mark.parametrize("content_type", [None, "text/plain"])
+def test_multipart_file_tuple_headers(content_type: typing.Optional[str]):
+    file_name = "test.txt"
+    expected_content_type = "text/plain"
+    headers = {"Expires": "0"}
+
+    files = {"file": (file_name, io.BytesIO(b"<file content>"), content_type, headers)}
+    with mock.patch("os.urandom", return_value=os.urandom(16)):
+        boundary = os.urandom(16).hex()
+
+        headers, stream = encode_request(data={}, files=files)
+        assert isinstance(stream, typing.Iterable)
+
+        content = (
+            f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
+            f'filename="{file_name}"\r\nExpires: 0\r\nContent-Type: '
+            f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
+            "".encode("ascii")
+        )
+        assert headers == {
+            "Content-Type": f"multipart/form-data; boundary={boundary}",
+            "Content-Length": str(len(content)),
+        }
+        assert content == b"".join(stream)
+
+
+def test_multipart_headers_include_content_type() -> None:
+    """Content-Type from 4th tuple parameter (headers) should override the 3rd parameter (content_type)"""
+    file_name = "test.txt"
+    expected_content_type = "image/png"
+    headers = {"Content-Type": "image/png"}
+
+    files = {"file": (file_name, io.BytesIO(b"<file content>"), "text_plain", headers)}
+    with mock.patch("os.urandom", return_value=os.urandom(16)):
+        boundary = os.urandom(16).hex()
+
+        headers, stream = encode_request(data={}, files=files)
+        assert isinstance(stream, typing.Iterable)
+
+        content = (
+            f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
+            f'filename="{file_name}"\r\nContent-Type: '
+            f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
+            "".encode("ascii")
+        )
+        assert headers == {
+            "Content-Type": f"multipart/form-data; boundary={boundary}",
+            "Content-Length": str(len(content)),
+        }
+        assert content == b"".join(stream)
+
+
 def test_multipart_encode(tmp_path: typing.Any) -> None:
     path = str(tmp_path / "name.txt")
     with open(path, "wb") as f: