]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Allow lists in query params (#386)
authorFlorimond Manca <florimond.manca@gmail.com>
Tue, 8 Oct 2019 20:12:04 +0000 (22:12 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Tue, 8 Oct 2019 20:12:04 +0000 (15:12 -0500)
httpx/models.py
httpx/utils.py
tests/models/test_queryparams.py

index f70fdf440a5e9d6267be85dd22467b253cd43fb7..136aa41c26338511eb0895146e3fa8259f4415b3 100644 (file)
@@ -32,6 +32,7 @@ from .exceptions import (
 from .multipart import multipart_encode
 from .status_codes import StatusCode
 from .utils import (
+    flatten_queryparams,
     guess_json_utf,
     is_known_encoding,
     normalize_header_key,
@@ -51,7 +52,7 @@ URLTypes = typing.Union["URL", str]
 
 QueryParamTypes = typing.Union[
     "QueryParams",
-    typing.Mapping[str, PrimitiveData],
+    typing.Mapping[str, typing.Union[PrimitiveData, typing.Sequence[PrimitiveData]]],
     typing.List[typing.Tuple[str, PrimitiveData]],
     str,
 ]
@@ -311,14 +312,15 @@ class QueryParams(typing.Mapping[str, str]):
 
         value = args[0] if args else kwargs
 
+        items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
         if isinstance(value, str):
             items = parse_qsl(value)
         elif isinstance(value, QueryParams):
             items = value.multi_items()
         elif isinstance(value, list):
-            items = value  # type: ignore
+            items = value
         else:
-            items = value.items()  # type: ignore
+            items = flatten_queryparams(value)
 
         self._list = [(str(k), str_query_param(v)) for k, v in items]
         self._dict = {str(k): str_query_param(v) for k, v in items}
index c8fcb1e66893333c161254568405147e51544023..8aea5e7119338f44de6c70748d3b97941989cda0 100644 (file)
@@ -1,4 +1,5 @@
 import codecs
+import collections
 import logging
 import netrc
 import os
@@ -11,6 +12,9 @@ from time import perf_counter
 from types import TracebackType
 from urllib.request import getproxies
 
+if typing.TYPE_CHECKING:  # pragma: no cover
+    from .models import PrimitiveData
+
 
 def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
     """
@@ -30,7 +34,7 @@ def normalize_header_value(value: typing.AnyStr, encoding: str = None) -> bytes:
     return value.encode(encoding or "ascii")
 
 
-def str_query_param(value: typing.Optional[typing.Union[str, int, float, bool]]) -> str:
+def str_query_param(value: "PrimitiveData") -> str:
     """
     Coerce a primitive data type into a string value for query params.
 
@@ -256,6 +260,31 @@ def unquote(value: str) -> str:
     return value[1:-1] if value[0] == value[-1] == '"' else value
 
 
+def flatten_queryparams(
+    queryparams: typing.Mapping[
+        str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
+    ]
+) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
+    """
+    Convert a mapping of query params into a flat list of two-tuples
+    representing each item.
+
+    Example:
+    >>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
+    [("q", "httpx), ("tag", "python"), ("tag", "dev")]
+    """
+    items = []
+
+    for k, v in queryparams.items():
+        if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
+            for u in v:
+                items.append((k, u))
+        else:
+            items.append((k, typing.cast("PrimitiveData", v)))
+
+    return items
+
+
 class ElapsedTimer:
     def __init__(self) -> None:
         self.start: float = perf_counter()
index 8c4df4907ffe31120a719a0a93ba0490e769e7a5..1303170d00dabfbf07e3339b32d9e19dffcb949b 100644 (file)
@@ -1,8 +1,18 @@
+import pytest
+
 from httpx import QueryParams
 
 
-def test_queryparams():
-    q = QueryParams("a=123&a=456&b=789")
+@pytest.mark.parametrize(
+    "source",
+    [
+        "a=123&a=456&b=789",
+        {"a": ["123", "456"], "b": 789},
+        {"a": ("123", "456"), "b": 789},
+    ],
+)
+def test_queryparams(source):
+    q = QueryParams(source)
     assert "a" in q
     assert "A" not in q
     assert "c" not in q