]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Handle `data={"key": [None|int|float|bool]}` cases. (#1539)
authorTom Christie <tom@tomchristie.com>
Fri, 26 Mar 2021 12:54:04 +0000 (12:54 +0000)
committerGitHub <noreply@github.com>
Fri, 26 Mar 2021 12:54:04 +0000 (12:54 +0000)
* Fix Content-Length for unicode file contents with multipart

* Handle bool and None cases for URLEncoded data

* Handle int, float, bool, and None for multipart or urlencoded data

* Update httpx/_utils.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_content.py
httpx/_models.py
httpx/_multipart.py
httpx/_utils.py
tests/test_content.py
tests/test_multipart.py

index bf402c9e299d7df07305efb55f3c91b03056e632..0b9672be3fb94327cefe5bc88f74ede3d26b9d8b 100644 (file)
@@ -21,6 +21,7 @@ from ._types import (
     RequestFiles,
     ResponseContent,
 )
+from ._utils import primitive_value_to_str
 
 
 class PlainByteStream:
@@ -106,7 +107,13 @@ def encode_content(
 def encode_urlencoded_data(
     data: dict,
 ) -> Tuple[Dict[str, str], ByteStream]:
-    body = urlencode(data, doseq=True).encode("utf-8")
+    plain_data = []
+    for key, value in data.items():
+        if isinstance(value, (list, tuple)):
+            plain_data.extend([(key, primitive_value_to_str(item)) for item in value])
+        else:
+            plain_data.append((key, primitive_value_to_str(value)))
+    body = urlencode(plain_data, doseq=True).encode("utf-8")
     content_length = str(len(body))
     content_type = "application/x-www-form-urlencoded"
     headers = {"Content-Length": content_length, "Content-Type": content_type}
index 34fb2d388c448eb760231e4109bb44b6e1d8a761..ade5a3192598ba12f03a02f6e04c33dd9a84d384 100644 (file)
@@ -54,7 +54,7 @@ from ._utils import (
     normalize_header_value,
     obfuscate_sensitive_headers,
     parse_header_links,
-    str_query_param,
+    primitive_value_to_str,
 )
 
 
@@ -450,8 +450,8 @@ class QueryParams(typing.Mapping[str, str]):
         else:
             items = flatten_queryparams(value)
 
-        self._list = [(str(k), str_query_param(v)) for k, v in items]
-        self._dict = {str(k): str_query_param(v) for k, v in items}
+        self._list = [(str(k), primitive_value_to_str(v)) for k, v in items]
+        self._dict = {str(k): primitive_value_to_str(v) for k, v in items}
 
     def keys(self) -> typing.KeysView:
         return self._dict.keys()
index bf75a5663b75c0472c74764fe2ab27d1a101c803..b5f8fb48f83b171cedc89b2861159e1af9bf30a9 100644 (file)
@@ -8,6 +8,7 @@ from ._utils import (
     format_form_param,
     guess_content_type,
     peek_filelike_length,
+    primitive_value_to_str,
     to_bytes,
 )
 
@@ -17,17 +18,21 @@ class DataField:
     A single form field item, within a multipart form field.
     """
 
-    def __init__(self, name: str, value: typing.Union[str, bytes]) -> None:
+    def __init__(
+        self, name: str, value: typing.Union[str, bytes, int, float, None]
+    ) -> None:
         if not isinstance(name, str):
             raise TypeError(
                 f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
             )
-        if not isinstance(value, (str, bytes)):
+        if value is not None and not isinstance(value, (str, bytes, int, float)):
             raise TypeError(
-                f"Invalid type for value. Expected str or bytes, got {type(value)}: {value!r}"
+                f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
             )
         self.name = name
-        self.value = value
+        self.value: typing.Union[str, bytes] = (
+            value if isinstance(value, bytes) else primitive_value_to_str(value)
+        )
 
     def render_headers(self) -> bytes:
         if not hasattr(self, "_headers"):
index 072db3f1e8a345db11e790a1326b659168e18af1..cf136a3ba8e0612c6f180b84bfaf4c87470e7e8b 100644 (file)
@@ -56,9 +56,9 @@ def normalize_header_value(
     return value.encode(encoding or "ascii")
 
 
-def str_query_param(value: "PrimitiveData") -> str:
+def primitive_value_to_str(value: "PrimitiveData") -> str:
     """
-    Coerce a primitive data type into a string value for query params.
+    Coerce a primitive data type into a string value.
 
     Note that we prefer JSON-style 'true'/'false' for boolean values here.
     """
index 384f9f228754fafad02bf9c398c236b4c3aa0806..1dda02863231d959e010e44250a7b65c05779e4f 100644 (file)
@@ -139,6 +139,57 @@ async def test_urlencoded_content():
     assert async_content == b"Hello=world%21"
 
 
+@pytest.mark.asyncio
+async def test_urlencoded_boolean():
+    headers, stream = encode_request(data={"example": True})
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
+
+    assert headers == {
+        "Content-Length": "12",
+        "Content-Type": "application/x-www-form-urlencoded",
+    }
+    assert sync_content == b"example=true"
+    assert async_content == b"example=true"
+
+
+@pytest.mark.asyncio
+async def test_urlencoded_none():
+    headers, stream = encode_request(data={"example": None})
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
+
+    assert headers == {
+        "Content-Length": "8",
+        "Content-Type": "application/x-www-form-urlencoded",
+    }
+    assert sync_content == b"example="
+    assert async_content == b"example="
+
+
+@pytest.mark.asyncio
+async def test_urlencoded_list():
+    headers, stream = encode_request(data={"example": ["a", 1, True]})
+    assert isinstance(stream, typing.Iterable)
+    assert isinstance(stream, typing.AsyncIterable)
+
+    sync_content = b"".join([part for part in stream])
+    async_content = b"".join([part async for part in stream])
+
+    assert headers == {
+        "Content-Length": "32",
+        "Content-Type": "application/x-www-form-urlencoded",
+    }
+    assert sync_content == b"example=a&example=1&example=true"
+    assert async_content == b"example=a&example=1&example=true"
+
+
 @pytest.mark.asyncio
 async def test_multipart_files_content():
     files = {"file": io.BytesIO(b"<file content>")}
index 199af4b0a5098d737d06cf71d263b939f462ec49..9eb62f785beb49ffe486b845a27e4a0dd62daad1 100644 (file)
@@ -57,7 +57,7 @@ def test_multipart_invalid_key(key):
     assert repr(key) in str(e.value)
 
 
-@pytest.mark.parametrize(("value"), (1, 2.3, None, [None, "abc"], {None: "abc"}))
+@pytest.mark.parametrize(("value"), (object(), {"key": "value"}))
 def test_multipart_invalid_value(value):
     client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
 
@@ -104,6 +104,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
         "b": b"C",
         "c": ["11", "22", "33"],
         "d": "",
+        "e": True,
+        "f": "",
     }
     files = {"file": ("name.txt", open(path, "rb"))}
 
@@ -120,6 +122,8 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:
             '--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n22\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="c"\r\n\r\n33\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="d"\r\n\r\n\r\n'
+            '--{0}\r\nContent-Disposition: form-data; name="e"\r\n\r\ntrue\r\n'
+            '--{0}\r\nContent-Disposition: form-data; name="f"\r\n\r\n\r\n'
             '--{0}\r\nContent-Disposition: form-data; name="file";'
             ' filename="name.txt"\r\n'
             "Content-Type: text/plain\r\n\r\n<file content>\r\n"