git.ipfire.org Git - thirdparty/httpx.git/commitdiff
allow setting an explicit multipart boundary via headers (#2278)
author Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Mon, 15 Aug 2022 15:20:07 +0000 (10:20 -0500)
committer GitHub <noreply@github.com>
Mon, 15 Aug 2022 15:20:07 +0000 (10:20 -0500)
httpx/_content.py
httpx/_models.py
httpx/_multipart.py
tests/test_multipart.py

index 24a967d50609658a8a496e5b088d4d820ac7ca96..eb7a7aef17cb173624d21da0b9a340626c73d05e 100644 (file)
@@ -150,7 +150,7 @@ def encode_urlencoded_data(
 
 
 def encode_multipart_data(
-    data: dict, files: RequestFiles, boundary: Optional[bytes] = None
+    data: dict, files: RequestFiles, boundary: Optional[bytes]
 ) -> Tuple[Dict[str, str], MultipartStream]:
     multipart = MultipartStream(data=data, files=files, boundary=boundary)
     headers = multipart.get_headers()
index 8879532c8164a5e8e725d9bc54f80828ecfbfedd..fd1d7fe9a14b54dd3f23bb0bb21442a685c2f320 100644 (file)
@@ -27,6 +27,7 @@ from ._exceptions import (
     StreamConsumed,
     request_context,
 )
+from ._multipart import get_multipart_boundary_from_content_type
 from ._status_codes import codes
 from ._types import (
     AsyncByteStream,
@@ -332,7 +333,18 @@ class Request:
             Cookies(cookies).set_cookie_header(self)
 
         if stream is None:
-            headers, stream = encode_request(content, data, files, json)
+            content_type: typing.Optional[str] = self.headers.get("content-type")
+            headers, stream = encode_request(
+                content=content,
+                data=data,
+                files=files,
+                json=json,
+                boundary=get_multipart_boundary_from_content_type(
+                    content_type=content_type.encode(self.headers.encoding)
+                    if content_type
+                    else None
+                ),
+            )
             self._prepare(headers)
             self.stream = stream
             # Load the request body, except for streaming content.
index d42f5cb31b2af7884eef1d6a0b16955e75656e67..8bd7a17c9b145353b20cee8cafacc15630203481 100644 (file)
@@ -20,6 +20,20 @@ from ._utils import (
 )
 
 
+def get_multipart_boundary_from_content_type(
+    content_type: typing.Optional[bytes],
+) -> typing.Optional[bytes]:
+    """Return the `boundary` parameter of a multipart/form-data header, or None."""
+    # See https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1; names ignore case.
+    media_type, *params = (content_type or b"").split(b";")
+    if media_type.strip().lower() != b"multipart/form-data":
+        return None
+    for param in map(bytes.strip, params):
+        if param.lower().startswith(b"boundary="):
+            return param[len(b"boundary=") :].strip(b'"')
+    return None
+
+
 class DataField:
     """
     A single form field item, within a multipart form field.
index 46ad0e01a10a2058ab7eb1db3caa20bc32174cbf..dc93d26505f7a767eaa855ff920cb879e9f6e09b 100644 (file)
@@ -42,6 +42,58 @@ def test_multipart(value, output):
     assert multipart["file"] == [b"<file content>"]
 
 
+@pytest.mark.parametrize(
+    "header",
+    [
+        "multipart/form-data; boundary=+++; charset=utf-8",
+        "multipart/form-data; charset=utf-8; boundary=+++",
+        "multipart/form-data; boundary=+++",
+        "multipart/form-data; boundary=+++ ;",
+        'multipart/form-data; boundary="+++"; charset=utf-8',
+        'multipart/form-data; charset=utf-8; boundary="+++"',
+        'multipart/form-data; boundary="+++"',
+        'multipart/form-data; boundary="+++" ;',
+    ],
+)
+def test_multipart_explicit_boundary(header: str) -> None:
+    """An explicit boundary in Content-Type is kept and used for the body."""
+    transport = httpx.MockTransport(echo_request_content)
+    client = httpx.Client(transport=transport)
+    response = client.post(
+        "http://127.0.0.1:8000/",
+        files={"file": io.BytesIO(b"<file content>")},
+        headers={"content-type": header},
+    )
+    assert response.status_code == 200
+    sent_headers = response.request.headers
+    assert sent_headers["Content-Type"] == header
+
+    # Decode the echoed body with the cgi module, telling it the boundary we
+    # asked for; crude, but enough to prove the body was framed with it.
+    length = sent_headers["Content-Length"]
+    pdict: dict = {"boundary": b"+++", "CONTENT-LENGTH": length}
+    parsed = cgi.parse_multipart(io.BytesIO(response.content), pdict)
+    assert parsed["file"] == [b"<file content>"]
+
+
+@pytest.mark.parametrize(
+    "header",
+    [
+        "multipart/form-data; charset=utf-8",
+        "multipart/form-data; charset=utf-8; ",
+    ],
+)
+def test_multipart_header_without_boundary(header: str) -> None:
+    """A boundary-less multipart Content-Type header is sent through unchanged."""
+    client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
+    upload = {"file": io.BytesIO(b"<file content>")}
+    response = client.post(
+        "http://127.0.0.1:8000/", files=upload, headers={"content-type": header}
+    )
+    assert response.status_code == 200
+    assert response.request.headers["Content-Type"] == header
+
+
 @pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))
 def test_multipart_invalid_key(key):
     client = httpx.Client(transport=httpx.MockTransport(echo_request_content))