From: Florimond Manca Date: Thu, 21 May 2020 14:25:31 +0000 (+0200) Subject: Fix bytes support in multipart uploads (#974) X-Git-Tag: 0.13.0~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ab9ace274931b4d47f9c52a8b8625027c348661c;p=thirdparty%2Fhttpx.git Fix bytes support in multipart uploads (#974) --- diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 72c5c52d..2484aa44 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -247,7 +247,7 @@ class MultipartStream(ContentStream): def get_length(self) -> int: headers = self.render_headers() - if isinstance(self.file, str): + if isinstance(self.file, (str, bytes)): return len(headers) + len(self.file) # Let's do our best not to read `file` into memory. @@ -279,7 +279,7 @@ class MultipartStream(ContentStream): return self._headers def render_data(self) -> typing.Iterator[bytes]: - if isinstance(self.file, str): + if isinstance(self.file, (str, bytes)): yield to_bytes(self.file) return @@ -297,7 +297,7 @@ class MultipartStream(ContentStream): self.file.seek(0) def can_replay(self) -> bool: - return True if isinstance(self.file, str) else self.file.seekable() + return True if isinstance(self.file, (str, bytes)) else self.file.seekable() def render(self) -> typing.Iterator[bytes]: yield self.render_headers() diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 67f6a306..aeb287a8 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -181,8 +181,14 @@ def test_multipart_encode_files_guesses_correct_content_type( ) -def test_multipart_encode_files_allows_str_content() -> None: - files = {"file": ("test.txt", "", "text/plain")} +@pytest.mark.parametrize( + "value, output", + ((b"", ""), ("", "")), +) +def test_multipart_encode_files_allows_bytes_or_str_content( + value: typing.Union[str, bytes], output: str +) -> None: + files = {"file": ("test.txt", value, "text/plain")} with mock.patch("os.urandom", return_value=os.urandom(16)): boundary = binascii.hexlify(os.urandom(16)).decode("ascii") @@ -193,9 +199,9 @@ def test_multipart_encode_files_allows_str_content() -> None: content = ( '--{0}\r\nContent-Disposition: form-data; name="file"; ' 'filename="test.txt"\r\n' - "Content-Type: text/plain\r\n\r\n\r\n" + "Content-Type: text/plain\r\n\r\n{1}\r\n" "--{0}--\r\n" - "".format(boundary).encode("ascii") + "".format(boundary, output).encode("ascii") ) assert stream.get_headers()["Content-Length"] == str(len(content)) assert b"".join(stream) == content