git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Prefer `Content-Length` over `Transfer-Encoding: chunked` for `content=<file-like>`
author: Tom Christie <tom@tomchristie.com>
Fri, 30 Apr 2021 09:40:42 +0000 (10:40 +0100)
committer: GitHub <noreply@github.com>
Fri, 30 Apr 2021 09:40:42 +0000 (10:40 +0100)
* Add failing test case for 'content=io.BytesIO(...)'

* Refactor peek_filelike_length to return an Optional[int]

* Peek filelength on file-like objects when rendering 'content=...'

httpx/_content.py
httpx/_multipart.py
httpx/_utils.py
tests/test_content.py

index 9c7c1ff2252f0211bcf5f18ee6de65f6af4db0d1..86f3c7c254109afc5e3967ca1c92ed0bc4d885b9 100644 (file)
@@ -17,7 +17,7 @@ from ._exceptions import StreamClosed, StreamConsumed
 from ._multipart import MultipartStream
 from ._transports.base import AsyncByteStream, SyncByteStream
 from ._types import RequestContent, RequestData, RequestFiles, ResponseContent
-from ._utils import primitive_value_to_str
+from ._utils import peek_filelike_length, primitive_value_to_str
 
 
 class ByteStream(AsyncByteStream, SyncByteStream):
@@ -82,12 +82,17 @@ def encode_content(
 
     if isinstance(content, (bytes, str)):
         body = content.encode("utf-8") if isinstance(content, str) else content
-        content_length = str(len(body))
-        headers = {"Content-Length": content_length} if body else {}
+        content_length = len(body)
+        headers = {"Content-Length": str(content_length)} if body else {}
         return headers, ByteStream(body)
 
     elif isinstance(content, Iterable):
-        headers = {"Transfer-Encoding": "chunked"}
+        content_length_or_none = peek_filelike_length(content)
+
+        if content_length_or_none is None:
+            headers = {"Transfer-Encoding": "chunked"}
+        else:
+            headers = {"Content-Length": str(content_length_or_none)}
         return headers, IteratorByteStream(content)  # type: ignore
 
     elif isinstance(content, AsyncIterable):
index cb23d0cfa5604b8623bac5a81eee083258768c52..36bae664e33666d91e863633cbdfc04f138fc881 100644 (file)
@@ -93,9 +93,8 @@ class FileField:
             return len(headers) + len(to_bytes(self.file))
 
         # Let's do our best not to read `file` into memory.
-        try:
-            file_length = peek_filelike_length(self.file)
-        except OSError:
+        file_length = peek_filelike_length(self.file)
+        if file_length is None:
             # As a last resort, read file and cache contents for later.
             assert not hasattr(self, "_data")
             self._data = to_bytes(self.file.read())
index dcdc5c3aa5060e4db79c5aa8e5e750afded8d9ed..30ab2ed5a686346f78e058f3c12469147b391df3 100644 (file)
@@ -342,7 +342,7 @@ def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
     return None
 
 
-def peek_filelike_length(stream: typing.IO) -> int:
+def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
     """
     Given a file-like stream object, return its length in number of bytes
     without reading it into memory.
@@ -350,7 +350,9 @@ def peek_filelike_length(stream: typing.IO) -> int:
     try:
         # Is it an actual file?
         fd = stream.fileno()
-    except OSError:
+        # Yup, seems to be an actual file.
+        length = os.fstat(fd).st_size
+    except (AttributeError, OSError):
         # No... Maybe it's something that supports random access, like `io.BytesIO`?
         try:
             # Assuming so, go to end of stream to figure out its length,
@@ -358,14 +360,11 @@ def peek_filelike_length(stream: typing.IO) -> int:
             offset = stream.tell()
             length = stream.seek(0, os.SEEK_END)
             stream.seek(offset)
-        except OSError:
+        except (AttributeError, OSError):
             # Not even that? Sorry, we're doomed...
-            raise
-        else:
-            return length
-    else:
-        # Yup, seems to be an actual file.
-        return os.fstat(fd).st_size
+            return None
+
+    return length
 
 
 class Timer:
index b10596619863ca06e97ae6b6c8032c9fcae6dc5c..2a273a133f79ea6eb7c1b063438fd37b08cb3a36 100644 (file)
@@ -48,6 +48,18 @@ async def test_bytes_content():
     assert async_content == b"Hello, world!"
 
 
+@pytest.mark.asyncio
+async def test_bytesio_content():
+    headers, stream = encode_request(content=io.BytesIO(b"Hello, world!"))
+    assert isinstance(stream, typing.Iterable)
+    assert not isinstance(stream, typing.AsyncIterable)
+
+    content = b"".join([part for part in stream])
+
+    assert headers == {"Content-Length": "13"}
+    assert content == b"Hello, world!"
+
+
 @pytest.mark.asyncio
 async def test_iterator_content():
     def hello_world():