]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Tweak QueryParams implementation (#1598)
authorTom Christie <tom@tomchristie.com>
Mon, 26 Apr 2021 13:06:12 +0000 (14:06 +0100)
committerGitHub <noreply@github.com>
Mon, 26 Apr 2021 13:06:12 +0000 (14:06 +0100)
httpx/_models.py
tests/models/test_queryparams.py

index dc8888882130abf9bfd684ee163cb918129dad9e..dc2958c7d63559f909666c3e08f23046611b485e 100644 (file)
@@ -512,27 +512,35 @@ class QueryParams(typing.Mapping[str, str]):
         else:
             items = flatten_queryparams(value)
 
-        self._list = [(str(k), primitive_value_to_str(v)) for k, v in items]
-        self._dict = {str(k): primitive_value_to_str(v) for k, v in items}
+        self._dict: typing.Dict[str, typing.List[str]] = {}
+        for item in items:
+            k, v = item
+            if str(k) not in self._dict:
+                self._dict[str(k)] = [primitive_value_to_str(v)]
+            else:
+                self._dict[str(k)].append(primitive_value_to_str(v))
 
     def keys(self) -> typing.KeysView:
         return self._dict.keys()
 
     def values(self) -> typing.ValuesView:
-        return self._dict.values()
+        return {k: v[0] for k, v in self._dict.items()}.values()
 
     def items(self) -> typing.ItemsView:
         """
         Return all items in the query params. If a key occurs more than once
         only the first item for that key is returned.
         """
-        return self._dict.items()
+        return {k: v[0] for k, v in self._dict.items()}.items()
 
     def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
         """
         Return all items in the query params. Allow duplicate keys to occur.
         """
-        return list(self._list)
+        multi_items: typing.List[typing.Tuple[str, str]] = []
+        for k, v in self._dict.items():
+            multi_items.extend([(k, i) for i in v])
+        return multi_items
 
     def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
         """
@@ -540,47 +548,28 @@ class QueryParams(typing.Mapping[str, str]):
         more than once, then only the first value is returned.
         """
         if key in self._dict:
-            return self._dict[key]
+            return self._dict[key][0]
         return default
 
     def get_list(self, key: typing.Any) -> typing.List[str]:
         """
         Get all values from the query param for a given key.
         """
-        return [item_value for item_key, item_value in self._list if item_key == key]
+        return list(self._dict.get(key, []))
 
     def update(self, params: QueryParamTypes = None) -> None:
         if not params:
             return
 
         params = QueryParams(params)
-        for param in params:
-            item, *extras = params.get_list(param)
-            self[param] = item
-            if extras:
-                self._list.extend((param, e) for e in extras)
-                # ensure getter matches merged QueryParams getter
-                self._dict[param] = params[param]
+        for k in params.keys():
+            self._dict[k] = params.get_list(k)
 
     def __getitem__(self, key: typing.Any) -> str:
-        return self._dict[key]
+        return self._dict[key][0]
 
     def __setitem__(self, key: str, value: str) -> None:
-        self._dict[key] = value
-
-        found_indexes = []
-        for idx, (item_key, _) in enumerate(self._list):
-            if item_key == key:
-                found_indexes.append(idx)
-
-        for idx in reversed(found_indexes[1:]):
-            del self._list[idx]
-
-        if found_indexes:
-            idx = found_indexes[0]
-            self._list[idx] = (key, value)
-        else:
-            self._list.append((key, value))
+        self._dict[key] = [value]
 
     def __contains__(self, key: typing.Any) -> bool:
         return key in self._dict
@@ -594,10 +583,10 @@ class QueryParams(typing.Mapping[str, str]):
     def __eq__(self, other: typing.Any) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        return sorted(self._list) == sorted(other._list)
+        return sorted(self.multi_items()) == sorted(other.multi_items())
 
     def __str__(self) -> str:
-        return urlencode(self._list)
+        return urlencode(self.multi_items())
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
index 7031a65cb936ebdeb9b980ad6d14f0bf9fd44e4d..d7f7c9d9b080bcc405678bf06bf4992576af1f6d 100644 (file)
@@ -18,17 +18,17 @@ def test_queryparams(source):
     assert "a" in q
     assert "A" not in q
     assert "c" not in q
-    assert q["a"] == "456"
-    assert q.get("a") == "456"
+    assert q["a"] == "123"
+    assert q.get("a") == "123"
     assert q.get("nope", default=None) is None
     assert q.get_list("a") == ["123", "456"]
 
     assert list(q.keys()) == ["a", "b"]
-    assert list(q.values()) == ["456", "789"]
-    assert list(q.items()) == [("a", "456"), ("b", "789")]
+    assert list(q.values()) == ["123", "789"]
+    assert list(q.items()) == [("a", "123"), ("b", "789")]
     assert len(q) == 2
     assert list(q) == ["a", "b"]
-    assert dict(q) == {"a": "456", "b": "789"}
+    assert dict(q) == {"a": "123", "b": "789"}
     assert str(q) == "a=123&a=456&b=789"
     assert repr(q) == "QueryParams('a=123&a=456&b=789')"
     assert httpx.QueryParams({"a": "123", "b": "456"}) == httpx.QueryParams(