From: Tom Christie Date: Fri, 4 Feb 2022 14:48:57 +0000 (+0000) Subject: Always rewind files on multipart uploads. (#2065) X-Git-Tag: 0.23.0~47 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0088253b32c7e6541ce85f40429f133c01ca35dc;p=thirdparty%2Fhttpx.git Always rewind files on multipart uploads. (#2065) * Test for multipart POST same file twice. * Always rewind files on multipart uploads * Linting --- diff --git a/httpx/_multipart.py b/httpx/_multipart.py index 34ee6315..3f981c85 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -113,7 +113,6 @@ class FileField: self.filename = filename self.file = fileobj self.headers = headers - self._consumed = False def get_length(self) -> int: headers = self.render_headers() @@ -158,9 +157,8 @@ class FileField: yield self._data return - if self._consumed: # pragma: nocover + if hasattr(self.file, "seek"): self.file.seek(0) - self._consumed = True chunk = self.file.read(self.CHUNK_SIZE) while chunk: diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9980cb5b..46ad0e01 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,6 +1,7 @@ import cgi import io import os +import tempfile import typing from unittest import mock @@ -339,6 +340,25 @@ def test_multipart_encode_non_seekable_filelike() -> None: assert content == b"".join(stream) +def test_multipart_rewinds_files(): + with tempfile.TemporaryFile() as upload: + upload.write(b"Hello, world!") + + transport = httpx.MockTransport(echo_request_content) + client = httpx.Client(transport=transport) + + files = {"file": upload} + response = client.post("http://127.0.0.1:8000/", files=files) + assert response.status_code == 200 + assert b"\r\nHello, world!\r\n" in response.content + + # POSTing the same file instance a second time should have the same content. + files = {"file": upload} + response = client.post("http://127.0.0.1:8000/", files=files) + assert response.status_code == 200 + assert b"\r\nHello, world!\r\n" in response.content + + class TestHeaderParamHTML5Formatting: def test_unicode(self): param = format_form_param("filename", "n\u00e4me")