]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Immutable QueryParams (#1600)
authorTom Christie <tom@tomchristie.com>
Mon, 26 Apr 2021 13:57:02 +0000 (14:57 +0100)
committerGitHub <noreply@github.com>
Mon, 26 Apr 2021 13:57:02 +0000 (14:57 +0100)
* Tweak QueryParams implementation

* Immutable QueryParams

httpx/_client.py
httpx/_models.py
httpx/_utils.py
tests/models/test_queryparams.py

index 371dbe77f47e6680ea749b69a3125ff5348b314c..ae42e9eac62f42c6ab76e53d60aa7601ae66dd02 100644 (file)
@@ -385,7 +385,7 @@ class BaseClient:
         """
         if params or self.params:
             merged_queryparams = QueryParams(self.params)
-            merged_queryparams.update(params)
+            merged_queryparams = merged_queryparams.merge(params)
             return merged_queryparams
         return params
 
index dc2958c7d63559f909666c3e08f23046611b485e..7b749ceeb4a9a7d4e510fc10fe95666674a325d7 100644 (file)
@@ -6,7 +6,7 @@ import typing
 import urllib.request
 from collections.abc import MutableMapping
 from http.cookiejar import Cookie, CookieJar
-from urllib.parse import parse_qsl, quote, unquote, urlencode
+from urllib.parse import parse_qs, quote, unquote, urlencode
 
 import idna
 import rfc3986
@@ -48,7 +48,6 @@ from ._types import (
     URLTypes,
 )
 from ._utils import (
-    flatten_queryparams,
     guess_json_utf,
     is_known_encoding,
     normalize_header_key,
@@ -148,8 +147,7 @@ class URL:
         # Add any query parameters, merging with any in the URL if needed.
         if params:
             if self._uri_reference.query:
-                url_params = QueryParams(self._uri_reference.query)
-                url_params.update(params)
+                url_params = QueryParams(self._uri_reference.query).merge(params)
                 query_string = str(url_params)
             else:
                 query_string = str(QueryParams(params))
@@ -450,7 +448,7 @@ class URL:
 
         url = httpx.URL("https://www.example.com/test")
         url = url.join("/new/path")
-        assert url == "https://www.example.com/test/new/path"
+        assert url == "https://www.example.com/new/path"
         """
         if self.is_relative_url:
             # Workaround to handle relative URLs, which otherwise raise
@@ -504,38 +502,79 @@ class QueryParams(typing.Mapping[str, str]):
         items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
         if value is None or isinstance(value, (str, bytes)):
             value = value.decode("ascii") if isinstance(value, bytes) else value
-            items = parse_qsl(value)
+            self._dict = parse_qs(value)
         elif isinstance(value, QueryParams):
-            items = value.multi_items()
-        elif isinstance(value, (list, tuple)):
-            items = value
+            self._dict = {k: list(v) for k, v in value._dict.items()}
         else:
-            items = flatten_queryparams(value)
-
-        self._dict: typing.Dict[str, typing.List[str]] = {}
-        for item in items:
-            k, v = item
-            if str(k) not in self._dict:
-                self._dict[str(k)] = [primitive_value_to_str(v)]
+            dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
+            if isinstance(value, (list, tuple)):
+                # Convert list inputs like:
+                #     [("a", "123"), ("a", "456"), ("b", "789")]
+                # To a dict representation, like:
+                #     {"a": ["123", "456"], "b": ["789"]}
+                for item in value:
+                    dict_value.setdefault(item[0], []).append(item[1])
             else:
-                self._dict[str(k)].append(primitive_value_to_str(v))
+                # Convert dict inputs like:
+                #    {"a": "123", "b": ["456", "789"]}
+                # To dict inputs where values are always lists, like:
+                #    {"a": ["123"], "b": ["456", "789"]}
+                dict_value = {
+                    k: list(v) if isinstance(v, (list, tuple)) else [v]
+                    for k, v in value.items()
+                }
+
+            # Ensure that keys and values are neatly coerced to strings.
+            # We coerce values `True` and `False` to JSON-like "true" and "false"
+            # representations, and coerce `None` values to the empty string.
+            self._dict = {
+                str(k): [primitive_value_to_str(item) for item in v]
+                for k, v in dict_value.items()
+            }
 
     def keys(self) -> typing.KeysView:
+        """
+        Return all the keys in the query params.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123&a=456&b=789")
+        assert list(q.keys()) == ["a", "b"]
+        """
         return self._dict.keys()
 
     def values(self) -> typing.ValuesView:
+        """
+        Return all the values in the query params. If a key occurs more than once
+        only the first item for that key is returned.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123&a=456&b=789")
+        assert list(q.values()) == ["123", "789"]
+        """
         return {k: v[0] for k, v in self._dict.items()}.values()
 
     def items(self) -> typing.ItemsView:
         """
         Return all items in the query params. If a key occurs more than once
         only the first item for that key is returned.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123&a=456&b=789")
+        assert list(q.items()) == [("a", "123"), ("b", "789")]
         """
         return {k: v[0] for k, v in self._dict.items()}.items()
 
     def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
         """
         Return all items in the query params. Allow duplicate keys to occur.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123&a=456&b=789")
+        assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
         """
         multi_items: typing.List[typing.Tuple[str, str]] = []
         for k, v in self._dict.items():
@@ -546,31 +585,93 @@ class QueryParams(typing.Mapping[str, str]):
         """
         Get a value from the query param for a given key. If the key occurs
         more than once, then only the first value is returned.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123&a=456&b=789")
+        assert q.get("a") == "123"
         """
         if key in self._dict:
-            return self._dict[key][0]
+            return self._dict[str(key)][0]
         return default
 
     def get_list(self, key: typing.Any) -> typing.List[str]:
         """
         Get all values from the query param for a given key.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123&a=456&b=789")
+        assert q.get_list("a") == ["123", "456"]
         """
-        return list(self._dict.get(key, []))
+        return list(self._dict.get(str(key), []))
 
-    def update(self, params: QueryParamTypes = None) -> None:
-        if not params:
-            return
+    def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
+        """
+        Return a new QueryParams instance, setting the value of a key.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123")
+        q = q.set("a", "456")
+        assert q == httpx.QueryParams("a=456")
+        """
+        q = QueryParams()
+        q._dict = dict(self._dict)
+        q._dict[str(key)] = [primitive_value_to_str(value)]
+        return q
+
+    def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
+        """
+        Return a new QueryParams instance, setting or appending the value of a key.
 
-        params = QueryParams(params)
-        for k in params.keys():
-            self._dict[k] = params.get_list(k)
+        Usage:
+
+        q = httpx.QueryParams("a=123")
+        q = q.add("a", "456")
+        assert q == httpx.QueryParams("a=123&a=456")
+        """
+        q = QueryParams()
+        q._dict = dict(self._dict)
+        q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
+        return q
+
+    def remove(self, key: typing.Any) -> "QueryParams":
+        """
+        Return a new QueryParams instance, removing the value of a key.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123")
+        q = q.remove("a")
+        assert q == httpx.QueryParams("")
+        """
+        q = QueryParams()
+        q._dict = dict(self._dict)
+        q._dict.pop(str(key), None)
+        return q
+
+    def merge(self, params: QueryParamTypes = None) -> "QueryParams":
+        """
+        Return a new QueryParams instance, updated with.
+
+        Usage:
+
+        q = httpx.QueryParams("a=123")
+        q = q.merge({"b": "456"})
+        assert q == httpx.QueryParams("a=123&b=456")
+
+        q = httpx.QueryParams("a=123")
+        q = q.merge({"a": "456", "b": "789"})
+        assert q == httpx.QueryParams("a=456&b=789")
+        """
+        q = QueryParams(params)
+        q._dict = {**self._dict, **q._dict}
+        return q
 
     def __getitem__(self, key: typing.Any) -> str:
         return self._dict[key][0]
 
-    def __setitem__(self, key: str, value: str) -> None:
-        self._dict[key] = [value]
-
     def __contains__(self, key: typing.Any) -> bool:
         return key in self._dict
 
@@ -580,6 +681,9 @@ class QueryParams(typing.Mapping[str, str]):
     def __len__(self) -> int:
         return len(self._dict)
 
+    def __hash__(self) -> int:
+        return hash(str(self))
+
     def __eq__(self, other: typing.Any) -> bool:
         if not isinstance(other, self.__class__):
             return False
@@ -593,6 +697,18 @@ class QueryParams(typing.Mapping[str, str]):
         query_string = str(self)
         return f"{class_name}({query_string!r})"
 
+    def update(self, params: QueryParamTypes = None) -> None:
+        raise RuntimeError(
+            "QueryParams are immutable since 0.18.0. "
+            "Use `q = q.merge(...)` to create an updated copy."
+        )
+
+    def __setitem__(self, key: str, value: str) -> None:
+        raise RuntimeError(
+            "QueryParams are immutable since 0.18.0. "
+            "Use `q = q.set(key, value)` to create an updated copy."
+        )
+
 
 class Headers(typing.MutableMapping[str, str]):
     """
index 06995ad508f246aa3dadcf07766382d93cc2e3f7..dcdc5c3aa5060e4db79c5aa8e5e750afded8d9ed 100644 (file)
@@ -1,5 +1,4 @@
 import codecs
-import collections
 import logging
 import mimetypes
 import netrc
@@ -369,31 +368,6 @@ def peek_filelike_length(stream: typing.IO) -> int:
         return os.fstat(fd).st_size
 
 
-def flatten_queryparams(
-    queryparams: typing.Mapping[
-        str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
-    ]
-) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
-    """
-    Convert a mapping of query params into a flat list of two-tuples
-    representing each item.
-
-    Example:
-    >>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
-    [("q", "httpx), ("tag", "python"), ("tag", "dev")]
-    """
-    items = []
-
-    for k, v in queryparams.items():
-        if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
-            for u in v:
-                items.append((k, u))
-        else:
-            items.append((k, typing.cast("PrimitiveData", v)))
-
-    return items
-
-
 class Timer:
     async def _get_time(self) -> float:
         library = sniffio.current_async_library()
index d7f7c9d9b080bcc405678bf06bf4992576af1f6d..ba200f146d8cd8bda5d8b60786ad785a18468759 100644 (file)
@@ -76,19 +76,50 @@ def test_queryparam_types():
     assert str(q) == "a=1&a=2"
 
 
-def test_queryparam_setters():
-    q = httpx.QueryParams({"a": 1})
-    q.update([])
+def test_queryparam_update_is_hard_deprecated():
+    q = httpx.QueryParams("a=123")
+    with pytest.raises(RuntimeError):
+        q.update({"a": "456"})
 
-    assert str(q) == "a=1"
 
-    q = httpx.QueryParams([("a", 1), ("a", 2)])
-    q["a"] = "3"
-    assert str(q) == "a=3"
+def test_queryparam_setter_is_hard_deprecated():
+    q = httpx.QueryParams("a=123")
+    with pytest.raises(RuntimeError):
+        q["a"] = "456"
 
-    q = httpx.QueryParams([("a", 1), ("b", 1)])
-    u = httpx.QueryParams([("b", 2), ("b", 3)])
-    q.update(u)
 
-    assert str(q) == "a=1&b=2&b=3"
-    assert q["b"] == u["b"]
+def test_queryparam_set():
+    q = httpx.QueryParams("a=123")
+    q = q.set("a", "456")
+    assert q == httpx.QueryParams("a=456")
+
+
+def test_queryparam_add():
+    q = httpx.QueryParams("a=123")
+    q = q.add("a", "456")
+    assert q == httpx.QueryParams("a=123&a=456")
+
+
+def test_queryparam_remove():
+    q = httpx.QueryParams("a=123")
+    q = q.remove("a")
+    assert q == httpx.QueryParams("")
+
+
+def test_queryparam_merge():
+    q = httpx.QueryParams("a=123")
+    q = q.merge({"b": "456"})
+    assert q == httpx.QueryParams("a=123&b=456")
+    q = q.merge({"a": "000", "c": "789"})
+    assert q == httpx.QueryParams("a=000&b=456&c=789")
+
+
+def test_queryparams_are_hashable():
+    params = (
+        httpx.QueryParams("a=123"),
+        httpx.QueryParams({"a": 123}),
+        httpx.QueryParams("b=456"),
+        httpx.QueryParams({"b": 456}),
+    )
+
+    assert len(set(params)) == 2