]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Fix bytes support in multipart uploads (#974)
authorFlorimond Manca <florimond.manca@gmail.com>
Thu, 21 May 2020 14:25:31 +0000 (16:25 +0200)
committerGitHub <noreply@github.com>
Thu, 21 May 2020 14:25:31 +0000 (16:25 +0200)
httpx/_content_streams.py
tests/test_multipart.py

index 72c5c52debbef059d335e84906a80e5382291ce2..2484aa44a6a9c16b962f1aaa43bbc0acfbc33134 100644 (file)
@@ -247,7 +247,7 @@ class MultipartStream(ContentStream):
         def get_length(self) -> int:
             headers = self.render_headers()
 
-            if isinstance(self.file, str):
+            if isinstance(self.file, (str, bytes)):
                 return len(headers) + len(self.file)
 
             # Let's do our best not to read `file` into memory.
@@ -279,7 +279,7 @@ class MultipartStream(ContentStream):
             return self._headers
 
         def render_data(self) -> typing.Iterator[bytes]:
-            if isinstance(self.file, str):
+            if isinstance(self.file, (str, bytes)):
                 yield to_bytes(self.file)
                 return
 
@@ -297,7 +297,7 @@ class MultipartStream(ContentStream):
                 self.file.seek(0)
 
         def can_replay(self) -> bool:
-            return True if isinstance(self.file, str) else self.file.seekable()
+            return True if isinstance(self.file, (str, bytes)) else self.file.seekable()
 
         def render(self) -> typing.Iterator[bytes]:
             yield self.render_headers()
index 67f6a30697f25a76f93954a289537aa4c7b0386b..aeb287a8a73388f5785ab5f98d48eefc96fb07b6 100644 (file)
@@ -181,8 +181,14 @@ def test_multipart_encode_files_guesses_correct_content_type(
         )
 
 
-def test_multipart_encode_files_allows_str_content() -> None:
-    files = {"file": ("test.txt", "<string content>", "text/plain")}
+@pytest.mark.parametrize(
+    "value, output",
+    ((b"<bytes content>", "<bytes content>"), ("<string content>", "<string content>")),
+)
+def test_multipart_encode_files_allows_bytes_or_str_content(
+    value: typing.Union[str, bytes], output: str
+) -> None:
+    files = {"file": ("test.txt", value, "text/plain")}
     with mock.patch("os.urandom", return_value=os.urandom(16)):
         boundary = binascii.hexlify(os.urandom(16)).decode("ascii")
 
@@ -193,9 +199,9 @@ def test_multipart_encode_files_allows_str_content() -> None:
         content = (
             '--{0}\r\nContent-Disposition: form-data; name="file"; '
             'filename="test.txt"\r\n'
-            "Content-Type: text/plain\r\n\r\n<string content>\r\n"
+            "Content-Type: text/plain\r\n\r\n{1}\r\n"
             "--{0}--\r\n"
-            "".format(boundary).encode("ascii")
+            "".format(boundary, output).encode("ascii")
         )
         assert stream.get_headers()["Content-Length"] == str(len(content))
         assert b"".join(stream) == content