]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Always rewind files on multipart uploads. (#2065)
authorTom Christie <tom@tomchristie.com>
Fri, 4 Feb 2022 14:48:57 +0000 (14:48 +0000)
committerGitHub <noreply@github.com>
Fri, 4 Feb 2022 14:48:57 +0000 (14:48 +0000)
* Test for multipart POST same file twice.

* Always rewind files on multipart uploads

* Linting

httpx/_multipart.py
tests/test_multipart.py

index 34ee631557b197b96849ab27f9ff6971b3f96399..3f981c85d0203440d136eea2f0b0caa53dace2f3 100644 (file)
@@ -113,7 +113,6 @@ class FileField:
         self.filename = filename
         self.file = fileobj
         self.headers = headers
-        self._consumed = False
 
     def get_length(self) -> int:
         headers = self.render_headers()
@@ -158,9 +157,8 @@ class FileField:
             yield self._data
             return
 
-        if self._consumed:  # pragma: nocover
+        if hasattr(self.file, "seek"):
             self.file.seek(0)
-        self._consumed = True
 
         chunk = self.file.read(self.CHUNK_SIZE)
         while chunk:
index 9980cb5b4eb1c3b78c989fb068e6de45a6cca698..46ad0e01a10a2058ab7eb1db3caa20bc32174cbf 100644 (file)
@@ -1,6 +1,7 @@
 import cgi
 import io
 import os
+import tempfile
 import typing
 from unittest import mock
 
@@ -339,6 +340,25 @@ def test_multipart_encode_non_seekable_filelike() -> None:
     assert content == b"".join(stream)
 
 
+def test_multipart_rewinds_files():
+    with tempfile.TemporaryFile() as upload:
+        upload.write(b"Hello, world!")
+
+        transport = httpx.MockTransport(echo_request_content)
+        client = httpx.Client(transport=transport)
+
+        files = {"file": upload}
+        response = client.post("http://127.0.0.1:8000/", files=files)
+        assert response.status_code == 200
+        assert b"\r\nHello, world!\r\n" in response.content
+
+        # POSTing the same file instance a second time should have the same content.
+        files = {"file": upload}
+        response = client.post("http://127.0.0.1:8000/", files=files)
+        assert response.status_code == 200
+        assert b"\r\nHello, world!\r\n" in response.content
+
+
 class TestHeaderParamHTML5Formatting:
     def test_unicode(self):
         param = format_form_param("filename", "n\u00e4me")