]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Consistent multidict methods (#1089)
authorTom Christie <tom@tomchristie.com>
Fri, 31 Jul 2020 10:46:35 +0000 (11:46 +0100)
committerGitHub <noreply@github.com>
Fri, 31 Jul 2020 10:46:35 +0000 (11:46 +0100)
* Consistent multidict methods

* Consistent multidict methods and behaviour

* Update httpx/_models.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
* Update httpx/_models.py

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
httpx/_models.py
tests/models/test_headers.py
tests/models/test_queryparams.py

index c8cbfbb449114b5e984c3bdd14828d9caf46ba36..4a81e5965dcc88a63cb3152da7e4c1d561d4fce3 100644 (file)
@@ -4,6 +4,7 @@ import email.message
 import json as jsonlib
 import typing
 import urllib.request
+import warnings
 from collections.abc import MutableMapping
 from http.cookiejar import Cookie, CookieJar
 from urllib.parse import parse_qsl, urlencode
@@ -240,9 +241,6 @@ class QueryParams(typing.Mapping[str, str]):
         self._list = [(str(k), str_query_param(v)) for k, v in items]
         self._dict = {str(k): str_query_param(v) for k, v in items}
 
-    def getlist(self, key: typing.Any) -> typing.List[str]:
-        return [item_value for item_key, item_value in self._list if item_key == key]
-
     def keys(self) -> typing.KeysView:
         return self._dict.keys()
 
@@ -250,16 +248,33 @@ class QueryParams(typing.Mapping[str, str]):
         return self._dict.values()
 
     def items(self) -> typing.ItemsView:
+        """
+        Return all items in the query params. If a key occurs more than once
+        only the first item for that key is returned.
+        """
         return self._dict.items()
 
     def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
+        """
+        Return all items in the query params. Allow duplicate keys to occur.
+        """
         return list(self._list)
 
     def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
+        """
+        Get a value from the query param for a given key. If the key occurs
+        more than once, then only the first value is returned.
+        """
         if key in self._dict:
             return self._dict[key]
         return default
 
+    def get_list(self, key: typing.Any) -> typing.List[str]:
+        """
+        Get all values from the query param for a given key.
+        """
+        return [item_value for item_key, item_value in self._list if item_key == key]
+
     def update(self, params: QueryParamTypes = None) -> None:
         if not params:
             return
@@ -315,6 +330,13 @@ class QueryParams(typing.Mapping[str, str]):
         query_string = str(self)
         return f"{class_name}({query_string!r})"
 
+    def getlist(self, key: typing.Any) -> typing.List[str]:
+        message = (
+            "QueryParams.getlist() is pending deprecation. Use QueryParams.get_list()"
+        )
+        warnings.warn(message, PendingDeprecationWarning)
+        return self.get_list(key)
+
 
 class Headers(typing.MutableMapping[str, str]):
     """
@@ -336,6 +358,14 @@ class Headers(typing.MutableMapping[str, str]):
                 (normalize_header_key(k, encoding), normalize_header_value(v, encoding))
                 for k, v in headers
             ]
+
+        self._dict = {}  # type: typing.Dict[bytes, bytes]
+        for key, value in self._list:
+            if key in self._dict:
+                self._dict[key] = self._dict[key] + b", " + value
+            else:
+                self._dict[key] = value
+
         self._encoding = encoding
 
     @property
@@ -376,26 +406,47 @@ class Headers(typing.MutableMapping[str, str]):
         return self._list
 
     def keys(self) -> typing.List[str]:  # type: ignore
-        return [key.decode(self.encoding) for key, value in self._list]
+        return [key.decode(self.encoding) for key in self._dict.keys()]
 
     def values(self) -> typing.List[str]:  # type: ignore
-        return [value.decode(self.encoding) for key, value in self._list]
+        return [value.decode(self.encoding) for value in self._dict.values()]
 
     def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
+        """
+        Return a list of `(key, value)` pairs of headers. Concatenate headers
+        into a single comma seperated value when a key occurs multiple times.
+        """
+        return [
+            (key.decode(self.encoding), value.decode(self.encoding))
+            for key, value in self._dict.items()
+        ]
+
+    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
+        """
+        Return a list of `(key, value)` pairs of headers. Allow multiple
+        occurences of the same key without concatenating into a single
+        comma seperated value.
+        """
         return [
             (key.decode(self.encoding), value.decode(self.encoding))
             for key, value in self._list
         ]
 
     def get(self, key: str, default: typing.Any = None) -> typing.Any:
+        """
+        Return a header value. If multiple occurences of the header occur
+        then concatenate them together with commas.
+        """
         try:
             return self[key]
         except KeyError:
             return default
 
-    def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]:
+    def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:
         """
-        Return multiple header values.
+        Return a list of all header values for a given key.
+        If `split_commas=True` is passed, then any comma seperated header
+        values are split into multiple return strings.
         """
         get_header_key = key.lower().encode(self.encoding)
 
@@ -448,6 +499,8 @@ class Headers(typing.MutableMapping[str, str]):
         set_key = key.lower().encode(self._encoding or "utf-8")
         set_value = value.encode(self._encoding or "utf-8")
 
+        self._dict[set_key] = set_value
+
         found_indexes = []
         for idx, (item_key, _) in enumerate(self._list):
             if item_key == set_key:
@@ -468,22 +521,19 @@ class Headers(typing.MutableMapping[str, str]):
         """
         del_key = key.lower().encode(self.encoding)
 
+        del self._dict[del_key]
+
         pop_indexes = []
         for idx, (item_key, _) in enumerate(self._list):
             if item_key == del_key:
                 pop_indexes.append(idx)
-        if not pop_indexes:
-            raise KeyError(key)
 
         for idx in reversed(pop_indexes):
             del self._list[idx]
 
     def __contains__(self, key: typing.Any) -> bool:
-        get_header_key = key.lower().encode(self.encoding)
-        for header_key, _ in self._list:
-            if header_key == get_header_key:
-                return True
-        return False
+        header_key = key.lower().encode(self.encoding)
+        return header_key in self._dict
 
     def __iter__(self) -> typing.Iterator[typing.Any]:
         return iter(self.keys())
@@ -503,7 +553,7 @@ class Headers(typing.MutableMapping[str, str]):
         if self.encoding != "ascii":
             encoding_str = f", encoding={self.encoding!r}"
 
-        as_list = list(obfuscate_sensitive_headers(self.items()))
+        as_list = list(obfuscate_sensitive_headers(self.multi_items()))
         as_dict = dict(as_list)
 
         no_duplicate_keys = len(as_dict) == len(as_list)
@@ -511,6 +561,11 @@ class Headers(typing.MutableMapping[str, str]):
             return f"{class_name}({as_dict!r}{encoding_str})"
         return f"{class_name}({as_list!r}{encoding_str})"
 
+    def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]:
+        message = "Headers.getlist() is pending deprecation. Use Headers.get_list()"
+        warnings.warn(message, PendingDeprecationWarning)
+        return self.get_list(key, split_commas=split_commas)
+
 
 USER_AGENT = f"python-httpx/{__version__}"
 ACCEPT_ENCODING = ", ".join(
index 088f5c8a1fe51dcbb980ea968a33d27ee239d9c6..ce08816d16299e20f588235cb1166e790ce28f64 100644 (file)
@@ -13,11 +13,12 @@ def test_headers():
     assert h["a"] == "123, 456"
     assert h.get("a") == "123, 456"
     assert h.get("nope", default=None) is None
-    assert h.getlist("a") == ["123", "456"]
-    assert h.keys() == ["a", "a", "b"]
-    assert h.values() == ["123", "456", "789"]
-    assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
-    assert list(h) == ["a", "a", "b"]
+    assert h.get_list("a") == ["123", "456"]
+    assert h.keys() == ["a", "b"]
+    assert h.values() == ["123, 456", "789"]
+    assert h.items() == [("a", "123, 456"), ("b", "789")]
+    assert h.multi_items() == [("a", "123"), ("a", "456"), ("b", "789")]
+    assert list(h) == ["a", "b"]
     assert dict(h) == {"a": "123, 456", "b": "789"}
     assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])"
     assert h == httpx.Headers([("a", "123"), ("b", "789"), ("a", "456")])
@@ -153,13 +154,13 @@ def test_headers_decode_explicit_encoding():
 
 def test_multiple_headers():
     """
-    Most headers should split by commas for `getlist`, except 'Set-Cookie'.
+    `Headers.get_list` should support both split_commas=False and split_commas=True.
     """
     h = httpx.Headers([("set-cookie", "a, b"), ("set-cookie", "c")])
-    h.getlist("Set-Cookie") == ["a, b", "b"]
+    assert h.get_list("Set-Cookie") == ["a, b", "c"]
 
     h = httpx.Headers([("vary", "a, b"), ("vary", "c")])
-    h.getlist("Vary") == ["a", "b", "c"]
+    assert h.get_list("Vary", split_commas=True) == ["a", "b", "c"]
 
 
 @pytest.mark.parametrize("header", ["authorization", "proxy-authorization"])
index 99f193f86d9091830b24425fd2dd45c17ad58c96..39f57e2e7f8c00a70f49e933c1a5f048100d72b9 100644 (file)
@@ -19,7 +19,7 @@ def test_queryparams(source):
     assert q["a"] == "456"
     assert q.get("a") == "456"
     assert q.get("nope", default=None) is None
-    assert q.getlist("a") == ["123", "456"]
+    assert q.get_list("a") == ["123", "456"]
     assert list(q.keys()) == ["a", "b"]
     assert list(q.values()) == ["456", "789"]
     assert list(q.items()) == [("a", "456"), ("b", "789")]