From: Tom Christie Date: Mon, 26 Apr 2021 13:06:12 +0000 (+0100) Subject: Tweak QueryParams implementation (#1598) X-Git-Tag: 0.18.0~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8fe32c52debc0add303fed81b9220689ddc502e4;p=thirdparty%2Fhttpx.git Tweak QueryParams implementation (#1598) --- diff --git a/httpx/_models.py b/httpx/_models.py index dc888888..dc2958c7 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -512,27 +512,35 @@ class QueryParams(typing.Mapping[str, str]): else: items = flatten_queryparams(value) - self._list = [(str(k), primitive_value_to_str(v)) for k, v in items] - self._dict = {str(k): primitive_value_to_str(v) for k, v in items} + self._dict: typing.Dict[str, typing.List[str]] = {} + for item in items: + k, v = item + if str(k) not in self._dict: + self._dict[str(k)] = [primitive_value_to_str(v)] + else: + self._dict[str(k)].append(primitive_value_to_str(v)) def keys(self) -> typing.KeysView: return self._dict.keys() def values(self) -> typing.ValuesView: - return self._dict.values() + return {k: v[0] for k, v in self._dict.items()}.values() def items(self) -> typing.ItemsView: """ Return all items in the query params. If a key occurs more than once only the first item for that key is returned. """ - return self._dict.items() + return {k: v[0] for k, v in self._dict.items()}.items() def multi_items(self) -> typing.List[typing.Tuple[str, str]]: """ Return all items in the query params. Allow duplicate keys to occur. """ - return list(self._list) + multi_items: typing.List[typing.Tuple[str, str]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: """ @@ -540,47 +548,28 @@ class QueryParams(typing.Mapping[str, str]): more than once, then only the first value is returned. """ if key in self._dict: - return self._dict[key] + return self._dict[key][0] return default def get_list(self, key: typing.Any) -> typing.List[str]: """ Get all values from the query param for a given key. """ - return [item_value for item_key, item_value in self._list if item_key == key] + return list(self._dict.get(key, [])) def update(self, params: QueryParamTypes = None) -> None: if not params: return params = QueryParams(params) - for param in params: - item, *extras = params.get_list(param) - self[param] = item - if extras: - self._list.extend((param, e) for e in extras) - # ensure getter matches merged QueryParams getter - self._dict[param] = params[param] + for k in params.keys(): + self._dict[k] = params.get_list(k) def __getitem__(self, key: typing.Any) -> str: - return self._dict[key] + return self._dict[key][0] def __setitem__(self, key: str, value: str) -> None: - self._dict[key] = value - - found_indexes = [] - for idx, (item_key, _) in enumerate(self._list): - if item_key == key: - found_indexes.append(idx) - - for idx in reversed(found_indexes[1:]): - del self._list[idx] - - if found_indexes: - idx = found_indexes[0] - self._list[idx] = (key, value) - else: - self._list.append((key, value)) + self._dict[key] = [value] def __contains__(self, key: typing.Any) -> bool: return key in self._dict @@ -594,10 +583,10 @@ class QueryParams(typing.Mapping[str, str]): def __eq__(self, other: typing.Any) -> bool: if not isinstance(other, self.__class__): return False - return sorted(self._list) == sorted(other._list) + return sorted(self.multi_items()) == sorted(other.multi_items()) def __str__(self) -> str: - return urlencode(self._list) + return urlencode(self.multi_items()) def __repr__(self) -> str: class_name = self.__class__.__name__ diff --git a/tests/models/test_queryparams.py b/tests/models/test_queryparams.py index 7031a65c..d7f7c9d9 100644 --- a/tests/models/test_queryparams.py +++ b/tests/models/test_queryparams.py @@ -18,17 +18,17 @@ def test_queryparams(source): assert "a" in q assert "A" not in q assert "c" not in q - assert q["a"] == "456" - assert q.get("a") == "456" + assert q["a"] == "123" + assert q.get("a") == "123" assert q.get("nope", default=None) is None assert q.get_list("a") == ["123", "456"] assert list(q.keys()) == ["a", "b"] - assert list(q.values()) == ["456", "789"] - assert list(q.items()) == [("a", "456"), ("b", "789")] + assert list(q.values()) == ["123", "789"] + assert list(q.items()) == [("a", "123"), ("b", "789")] assert len(q) == 2 assert list(q) == ["a", "b"] - assert dict(q) == {"a": "456", "b": "789"} + assert dict(q) == {"a": "123", "b": "789"} assert str(q) == "a=123&a=456&b=789" assert repr(q) == "QueryParams('a=123&a=456&b=789')" assert httpx.QueryParams({"a": "123", "b": "456"}) == httpx.QueryParams(