import binascii
-import mimetypes
import os
import typing
-from io import BytesIO
from json import dumps as json_dumps
from pathlib import Path
from urllib.parse import urlencode
from ._exceptions import StreamConsumed
from ._types import StrOrBytes
-from ._utils import format_form_param
+from ._utils import (
+ format_form_param,
+ guess_content_type,
+ peek_filelike_length,
+ to_bytes,
+)
RequestData = typing.Union[
dict, str, bytes, typing.Iterator[bytes], typing.AsyncIterator[bytes]
class MultipartStream(ContentStream):
"""
- Request content as multipart encoded form data.
+ Request content as streaming multipart encoded form data.
"""
class DataField:
self.value = value
def render_headers(self) -> bytes:
- name = format_form_param("name", self.name)
- return b"".join([b"Content-Disposition: form-data; ", name, b"\r\n\r\n"])
+ # Build the part header once and memoize it on the instance as
+ # `_headers`; later calls return the stored bytes without re-formatting.
+ if not hasattr(self, "_headers"):
+ name = format_form_param("name", self.name)
+ self._headers = b"".join(
+ [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
+ )
+
+ return self._headers
def render_data(self) -> bytes:
- return (
- self.value
- if isinstance(self.value, bytes)
- else self.value.encode("utf-8")
- )
+ # Encode the value once and memoize it as `_data`: bytes pass through
+ # unchanged, any other value is encoded with `.encode("utf-8")`.
+ if not hasattr(self, "_data"):
+ self._data = (
+ self.value
+ if isinstance(self.value, bytes)
+ else self.value.encode("utf-8")
+ )
+
+ return self._data
+
+ def get_length(self) -> int:
+ # Size in bytes of this part: rendered headers plus rendered value.
+ headers = self.render_headers()
+ data = self.render_data()
+ return len(headers) + len(data)
+
+ def can_replay(self) -> bool:
+ # Unconditionally True: there is no state in which this reports False.
+ return True
+
+ def render(self) -> typing.Iterator[bytes]:
+ # Yield the part as two chunks: the headers first, then the encoded value.
+ yield self.render_headers()
+ yield self.render_data()
class FileField:
"""
self.name = name
if not isinstance(value, tuple):
self.filename = Path(str(getattr(value, "name", "upload"))).name
- self.file = (
- value
- ) # type: typing.Union[typing.IO[str], typing.IO[bytes]]
- self.content_type = self.guess_content_type()
+ self.file: typing.Union[typing.IO[str], typing.IO[bytes]] = value
+ self.content_type = guess_content_type(self.filename)
else:
self.filename = value[0]
self.file = value[1]
self.content_type = (
- value[2] if len(value) > 2 else self.guess_content_type()
+ value[2] if len(value) > 2 else guess_content_type(self.filename)
)
- def guess_content_type(self) -> typing.Optional[str]:
- if self.filename:
- return (
- mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
- )
- else:
- return None
+ def get_length(self) -> int:
+ # Return the number of bytes this part contributes to the request body:
+ # rendered headers plus file content. The result feeds `Content-Length`,
+ # so it must count encoded bytes (not characters) and must be safe to
+ # call more than once.
+ headers = self.render_headers()
- def render_headers(self) -> bytes:
- parts = [
- b"Content-Disposition: form-data; ",
- format_form_param("name", self.name),
- ]
- if self.filename:
- filename = format_form_param("filename", self.filename)
- parts.extend([b"; ", filename])
- if self.content_type is not None:
- content_type = self.content_type.encode()
- parts.extend([b"\r\nContent-Type: ", content_type])
- parts.append(b"\r\n\r\n")
- return b"".join(parts)
+ if isinstance(self.file, str):
+ # `render_data()` sends `to_bytes(self.file)`; measure those same
+ # bytes so non-ASCII text does not under-report the length.
+ return len(headers) + len(to_bytes(self.file))
- def render_data(self) -> bytes:
- content: typing.Union[str, bytes]
+ if hasattr(self, "_data"):
+ # A previous call already had to read the file into memory; the
+ # file may now be exhausted, so reuse the cached bytes.
+ return len(headers) + len(self._data)
+
+ # Let's do our best not to read `file` into memory.
+ try:
+ file_length = peek_filelike_length(self.file)
+ except OSError:
+ # As a last resort, read file and cache contents for later.
+ self._data = to_bytes(self.file.read())
+ file_length = len(self._data)
+
+ return len(headers) + file_length
+
+ def render_headers(self) -> bytes:
+ # Build the part header once and memoize it as `_headers`. The
+ # `filename` parameter is emitted only when `self.filename` is truthy,
+ # and the `Content-Type` line only when `self.content_type` is not None.
+ if not hasattr(self, "_headers"):
+ parts = [
+ b"Content-Disposition: form-data; ",
+ format_form_param("name", self.name),
+ ]
+ if self.filename:
+ filename = format_form_param("filename", self.filename)
+ parts.extend([b"; ", filename])
+ if self.content_type is not None:
+ content_type = self.content_type.encode()
+ parts.extend([b"\r\nContent-Type: ", content_type])
+ parts.append(b"\r\n\r\n")
+ self._headers = b"".join(parts)
+
+ return self._headers
+
+ def render_data(self) -> typing.Iterator[bytes]:
+ # Yield the file content as bytes chunks. Three sources, checked in
+ # order: in-memory `str` content, bytes cached in `_data`, and finally
+ # the file object itself.
if isinstance(self.file, str):
- content = self.file
- else:
- content = self.file.read()
- return content.encode("utf-8") if isinstance(content, str) else content
+ yield to_bytes(self.file)
+ return
- def __init__(self, data: dict, files: dict, boundary: bytes = None) -> None:
- body = BytesIO()
- if boundary is None:
- boundary = binascii.hexlify(os.urandom(16))
+ # `_data` is set by `get_length()` when the file's length could not be
+ # peeked and the content had to be read into memory instead.
+ if hasattr(self, "_data"):
+ # Already rendered.
+ yield self._data
+ return
+
+ # NOTE(review): iterating a file object presumably yields it line by
+ # line, so chunk sizes depend on the content - confirm this is intended.
+ for chunk in self.file:
+ yield to_bytes(chunk)
+
+ # Get ready for the next replay, if possible.
+ if self.can_replay():
+ # NOTE(review): for non-str files `can_replay()` already returned
+ # `self.file.seekable()`, so this assert re-checks the same condition.
+ assert self.file.seekable()
+ self.file.seek(0)
- for field in self.iter_fields(data, files):
- body.write(b"--%s\r\n" % boundary)
- body.write(field.render_headers())
- body.write(field.render_data())
- body.write(b"\r\n")
+ def can_replay(self) -> bool:
+ # `str` content is reported as replayable unconditionally; a file object
+ # is replayable only if it reports itself as seekable.
+ return True if isinstance(self.file, str) else self.file.seekable()
- body.write(b"--%s--\r\n" % boundary)
+ def render(self) -> typing.Iterator[bytes]:
+ # Yield the headers as one chunk, then every chunk from `render_data()`.
+ yield self.render_headers()
+ yield from self.render_data()
+ def __init__(
+ self, data: typing.Mapping, files: typing.Mapping, boundary: bytes = None
+ ) -> None:
+ # When no boundary is supplied, generate one: 16 random bytes, hex-encoded.
+ if boundary is None:
+ boundary = binascii.hexlify(os.urandom(16))
+
+ self.boundary = boundary
self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
"ascii"
)
- self.body = body.getvalue()
+ # The field objects are collected into a list up front; the body bytes
+ # themselves are produced later, when `iter_chunks()` renders each field.
+ self.fields = list(self._iter_fields(data, files))
- def iter_fields(
- self, data: dict, files: dict
+ def _iter_fields(
+ self, data: typing.Mapping, files: typing.Mapping
) -> typing.Iterator[typing.Union["FileField", "DataField"]]:
for name, value in data.items():
if isinstance(value, list):
for name, value in files.items():
yield self.FileField(name=name, value=value)
+ def iter_chunks(self) -> typing.Iterator[bytes]:
+ # Yield the encoded body piece by piece. Each field is framed by an
+ # opening `--<boundary>\r\n` delimiter and a trailing `\r\n`; the body
+ # ends with the closing `--<boundary>--\r\n` delimiter.
+ for field in self.fields:
+ yield b"--%s\r\n" % self.boundary
+ yield from field.render()
+ yield b"\r\n"
+ yield b"--%s--\r\n" % self.boundary
+
+ def iter_chunks_lengths(self) -> typing.Iterator[int]:
+ # Yield the byte length of each piece `iter_chunks()` emits, without
+ # rendering the fields' data.
+ boundary_length = len(self.boundary)
+ # Follow closely what `.iter_chunks()` does.
+ for field in self.fields:
+ # b"--" + boundary + b"\r\n"
+ yield 2 + boundary_length + 2
+ yield field.get_length()
+ # b"\r\n" after the field
+ yield 2
+ # b"--" + boundary + b"--\r\n"
+ yield 2 + boundary_length + 4
+
+ def get_content_length(self) -> int:
+ # Total body size in bytes: the sum of all per-chunk lengths.
+ return sum(self.iter_chunks_lengths())
+
+ # Content stream interface.
+
+ def can_replay(self) -> bool:
+ # Replayable only if every field is; also True when there are no fields.
+ return all(field.can_replay() for field in self.fields)
+
def get_headers(self) -> typing.Dict[str, str]:
- content_length = str(len(self.body))
+ # Content-Length is derived from the per-chunk lengths rather than from
+ # a fully rendered body.
+ content_length = str(self.get_content_length())
content_type = self.content_type
return {"Content-Length": content_length, "Content-Type": content_type}
def __iter__(self) -> typing.Iterator[bytes]:
- yield self.body
+ # Synchronous iteration: hand out the chunks from `iter_chunks()` as-is.
+ yield from self.iter_chunks()
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
- yield self.body
+ # Async counterpart of `__iter__`; the chunks themselves still come from
+ # the synchronous `iter_chunks()` generator.
+ for chunk in self.iter_chunks():
+ yield chunk
def encode(
assert multipart["file"] == [b"<file content>"]
-def test_multipart_encode():
+def test_multipart_encode(tmp_path: typing.Any) -> None:
+ # Write the upload to a real file on disk so the stream is exercised with
+ # an actual file object rather than an in-memory buffer.
+ path = str(tmp_path / "name.txt")
+ with open(path, "wb") as f:
+ f.write(b"<file content>")
+
data = {
"a": "1",
"b": b"C",
"c": ["11", "22", "33"],
"d": "",
}
- files = {"file": ("name.txt", io.BytesIO(b"<file content>"))}
+ # NOTE(review): this handle is never closed in the lines shown here;
+ # consider opening it in a `with` block that spans the assertions.
+ files = {"file": ("name.txt", open(path, "rb"))}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
+
stream = encode(data=data, files=files)
+ assert stream.can_replay()
+
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
- assert stream.body == (
+ content = (
'--{0}\r\nContent-Disposition: form-data; name="a"\r\n\r\n1\r\n'
'--{0}\r\nContent-Disposition: form-data; name="b"\r\n\r\nC\r\n'
'--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n11\r\n'
"--{0}--\r\n"
"".format(boundary).encode("ascii")
)
+ # The advertised Content-Length must match the bytes actually streamed.
+ assert stream.get_headers()["Content-Length"] == str(len(content))
+ assert b"".join(stream) == content
-def test_multipart_encode_files_allows_filenames_as_none():
+def test_multipart_encode_files_allows_filenames_as_none() -> None:
files = {"file": (None, io.BytesIO(b"<file content>"))}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
stream = encode(data={}, files=files)
+ assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
- assert stream.body == (
+ assert b"".join(stream) == (
'--{0}\r\nContent-Disposition: form-data; name="file"\r\n\r\n'
"<file content>\r\n--{0}--\r\n"
"".format(boundary).encode("ascii")
],
)
def test_multipart_encode_files_guesses_correct_content_type(
- file_name, expected_content_type
-):
+ file_name: str, expected_content_type: str
+) -> None:
files = {"file": (file_name, io.BytesIO(b"<file content>"))}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
stream = encode(data={}, files=files)
+ assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
- assert stream.body == (
+ assert b"".join(stream) == (
f'--{boundary}\r\nContent-Disposition: form-data; name="file"; '
f'filename="{file_name}"\r\nContent-Type: '
f"{expected_content_type}\r\n\r\n<file content>\r\n--{boundary}--\r\n"
)
-def test_multipart_encode_files_allows_str_content():
+def test_multipart_encode_files_allows_str_content() -> None:
+ # A file entry may carry its content as a plain `str` together with an
+ # explicit content type, instead of a file object.
files = {"file": ("test.txt", "<string content>", "text/plain")}
with mock.patch("os.urandom", return_value=os.urandom(16)):
boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
stream = encode(data={}, files=files)
+ assert stream.can_replay()
assert stream.content_type == f"multipart/form-data; boundary={boundary}"
- assert stream.body == (
+ content = (
'--{0}\r\nContent-Disposition: form-data; name="file"; '
'filename="test.txt"\r\n'
"Content-Type: text/plain\r\n\r\n<string content>\r\n"
"--{0}--\r\n"
"".format(boundary).encode("ascii")
)
+ # The advertised Content-Length must match the bytes actually streamed.
+ assert stream.get_headers()["Content-Length"] == str(len(content))
+ assert b"".join(stream) == content
+
+
+def test_multipart_encode_non_seekable_filelike() -> None:
+ """
+ Test that special readable but non-seekable filelike objects are supported,
+ at the cost of reading them into memory at most once.
+ """
+
+ # Minimal file-like object: reports itself as non-seekable and hands out
+ # everything left in the wrapped iterator from a single `read()` call.
+ class IteratorIO(io.IOBase):
+ def __init__(self, iterator: typing.Iterator[bytes]) -> None:
+ self._iterator = iterator
+
+ def seekable(self) -> bool:
+ return False
+
+ def read(self, *args: typing.Any) -> bytes:
+ # Size arguments are ignored; the iterator is drained in one go.
+ return b"".join(self._iterator)
+
+ def data() -> typing.Iterator[bytes]:
+ yield b"Hello"
+ yield b"World"
+
+ fileobj = IteratorIO(data())
+ files = {"file": fileobj}
+ stream = encode(files=files, boundary=b"+++")
+ assert not stream.can_replay()
+
+ content = (
+ b"--+++\r\n"
+ b'Content-Disposition: form-data; name="file"; filename="upload"\r\n'
+ b"Content-Type: application/octet-stream\r\n"
+ b"\r\n"
+ b"HelloWorld\r\n"
+ b"--+++--\r\n"
+ )
+ # Headers are requested before the body is iterated; the body must still
+ # come out complete afterwards.
+ assert stream.get_headers() == {
+ "Content-Type": "multipart/form-data; boundary=+++",
+ "Content-Length": str(len(content)),
+ }
+ assert b"".join(stream) == content
class TestHeaderParamHTML5Formatting: