def encode_multipart_data(
- data: dict, files: RequestFiles, boundary: Optional[bytes] = None
+ data: dict, files: RequestFiles, boundary: Optional[bytes]
) -> Tuple[Dict[str, str], MultipartStream]:
multipart = MultipartStream(data=data, files=files, boundary=boundary)
headers = multipart.get_headers()
StreamConsumed,
request_context,
)
+from ._multipart import get_multipart_boundary_from_content_type
from ._status_codes import codes
from ._types import (
AsyncByteStream,
Cookies(cookies).set_cookie_header(self)
if stream is None:
- headers, stream = encode_request(content, data, files, json)
+ content_type: typing.Optional[str] = self.headers.get("content-type")
+ headers, stream = encode_request(
+ content=content,
+ data=data,
+ files=files,
+ json=json,
+ boundary=get_multipart_boundary_from_content_type(
+ content_type=content_type.encode(self.headers.encoding)
+ if content_type
+ else None
+ ),
+ )
self._prepare(headers)
self.stream = stream
# Load the request body, except for streaming content.
)
+def get_multipart_boundary_from_content_type(
+ content_type: typing.Optional[bytes],
+) -> typing.Optional[bytes]:
+ if not content_type or not content_type.startswith(b"multipart/form-data"):
+ return None
+ # parse boundary according to
+ # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
+ if b";" in content_type:
+ for section in content_type.split(b";"):
+ if section.strip().lower().startswith(b"boundary="):
+ return section.strip()[len(b"boundary=") :].strip(b'"')
+ return None
+
+
class DataField:
"""
A single form field item, within a multipart form field.
assert multipart["file"] == [b"<file content>"]
+@pytest.mark.parametrize(
+ "header",
+ [
+ "multipart/form-data; boundary=+++; charset=utf-8",
+ "multipart/form-data; charset=utf-8; boundary=+++",
+ "multipart/form-data; boundary=+++",
+ "multipart/form-data; boundary=+++ ;",
+ 'multipart/form-data; boundary="+++"; charset=utf-8',
+ 'multipart/form-data; charset=utf-8; boundary="+++"',
+ 'multipart/form-data; boundary="+++"',
+ 'multipart/form-data; boundary="+++" ;',
+ ],
+)
+def test_multipart_explicit_boundary(header: str) -> None:
+ client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
+
+ files = {"file": io.BytesIO(b"<file content>")}
+ headers = {"content-type": header}
+ response = client.post("http://127.0.0.1:8000/", files=files, headers=headers)
+ assert response.status_code == 200
+
+ # We're using the cgi module to verify the behavior here, which is a
+ # bit grungy, but sufficient just for our testing purposes.
+ assert response.request.headers["Content-Type"] == header
+ content_length = response.request.headers["Content-Length"]
+ pdict: dict = {
+ "boundary": b"+++",
+ "CONTENT-LENGTH": content_length,
+ }
+ multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)
+
+ assert multipart["file"] == [b"<file content>"]
+
+
+@pytest.mark.parametrize(
+ "header",
+ [
+ "multipart/form-data; charset=utf-8",
+ "multipart/form-data; charset=utf-8; ",
+ ],
+)
+def test_multipart_header_without_boundary(header: str) -> None:
+ client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
+
+ files = {"file": io.BytesIO(b"<file content>")}
+ headers = {"content-type": header}
+ response = client.post("http://127.0.0.1:8000/", files=files, headers=headers)
+
+ assert response.status_code == 200
+ assert response.request.headers["Content-Type"] == header
+
+
@pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))
def test_multipart_invalid_key(key):
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))