]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Preserve header casing (#1338)
authorTom Christie <tom@tomchristie.com>
Tue, 6 Oct 2020 13:57:10 +0000 (14:57 +0100)
committerGitHub <noreply@github.com>
Tue, 6 Oct 2020 13:57:10 +0000 (14:57 +0100)
httpx/_models.py
httpx/_utils.py
tests/client/test_client.py

index a1140e1f51a26d9b4d53d9fcfbc03892dce63766..7bad1c9eb93fcb4dd4af3b81e712ce2a3d39d7ad 100644 (file)
@@ -525,27 +525,28 @@ class Headers(typing.MutableMapping[str, str]):
 
     def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
         if headers is None:
-            self._list = []  # type: typing.List[typing.Tuple[bytes, bytes]]
+            self._list = []  # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
         elif isinstance(headers, Headers):
-            self._list = list(headers.raw)
+            self._list = list(headers._list)
         elif isinstance(headers, dict):
             self._list = [
-                (normalize_header_key(k, encoding), normalize_header_value(v, encoding))
+                (
+                    normalize_header_key(k, lower=False, encoding=encoding),
+                    normalize_header_key(k, lower=True, encoding=encoding),
+                    normalize_header_value(v, encoding),
+                )
                 for k, v in headers.items()
             ]
         else:
             self._list = [
-                (normalize_header_key(k, encoding), normalize_header_value(v, encoding))
+                (
+                    normalize_header_key(k, lower=False, encoding=encoding),
+                    normalize_header_key(k, lower=True, encoding=encoding),
+                    normalize_header_value(v, encoding),
+                )
                 for k, v in headers
             ]
 
-        self._dict = {}  # type: typing.Dict[bytes, bytes]
-        for key, value in self._list:
-            if key in self._dict:
-                self._dict[key] = self._dict[key] + b", " + value
-            else:
-                self._dict[key] = value
-
         self._encoding = encoding
 
     @property
@@ -582,25 +583,36 @@ class Headers(typing.MutableMapping[str, str]):
         """
         Returns a list of the raw header items, as byte pairs.
         """
-        return list(self._list)
+        return [(raw_key, value) for raw_key, _, value in self._list]
 
     def keys(self) -> typing.KeysView[str]:
-        return {key.decode(self.encoding): None for key in self._dict.keys()}.keys()
+        return {key.decode(self.encoding): None for _, key, value in self._list}.keys()
 
     def values(self) -> typing.ValuesView[str]:
-        return {
-            key: value.decode(self.encoding) for key, value in self._dict.items()
-        }.values()
+        values_dict: typing.Dict[str, str] = {}
+        for _, key, value in self._list:
+            str_key = key.decode(self.encoding)
+            str_value = value.decode(self.encoding)
+            if str_key in values_dict:
+                values_dict[str_key] += f", {str_value}"
+            else:
+                values_dict[str_key] = str_value
+        return values_dict.values()
 
     def items(self) -> typing.ItemsView[str, str]:
         """
         Return `(key, value)` items of headers. Concatenate headers
         into a single comma seperated value when a key occurs multiple times.
         """
-        return {
-            key.decode(self.encoding): value.decode(self.encoding)
-            for key, value in self._dict.items()
-        }.items()
+        values_dict: typing.Dict[str, str] = {}
+        for _, key, value in self._list:
+            str_key = key.decode(self.encoding)
+            str_value = value.decode(self.encoding)
+            if str_key in values_dict:
+                values_dict[str_key] += f", {str_value}"
+            else:
+                values_dict[str_key] = str_value
+        return values_dict.items()
 
     def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
         """
@@ -610,7 +622,7 @@ class Headers(typing.MutableMapping[str, str]):
         """
         return [
             (key.decode(self.encoding), value.decode(self.encoding))
-            for key, value in self._list
+            for _, key, value in self._list
         ]
 
     def get(self, key: str, default: typing.Any = None) -> typing.Any:
@@ -633,8 +645,8 @@ class Headers(typing.MutableMapping[str, str]):
 
         values = [
             item_value.decode(self.encoding)
-            for item_key, item_value in self._list
-            if item_key == get_header_key
+            for _, item_key, item_value in self._list
+            if item_key.lower() == get_header_key
         ]
 
         if not split_commas:
@@ -647,11 +659,11 @@ class Headers(typing.MutableMapping[str, str]):
 
     def update(self, headers: HeaderTypes = None) -> None:  # type: ignore
         headers = Headers(headers)
-        for header in headers:
-            self[header] = headers[header]
+        for key, value in headers.raw:
+            self[key.decode(headers.encoding)] = value.decode(headers.encoding)
 
     def copy(self) -> "Headers":
-        return Headers(dict(self.items()), encoding=self.encoding)
+        return Headers(self, encoding=self.encoding)
 
     def __getitem__(self, key: str) -> str:
         """
@@ -663,7 +675,7 @@ class Headers(typing.MutableMapping[str, str]):
         normalized_key = key.lower().encode(self.encoding)
 
         items = []
-        for header_key, header_value in self._list:
+        for _, header_key, header_value in self._list:
             if header_key == normalized_key:
                 items.append(header_value.decode(self.encoding))
 
@@ -677,14 +689,13 @@ class Headers(typing.MutableMapping[str, str]):
         Set the header `key` to `value`, removing any duplicate entries.
         Retains insertion order.
         """
-        set_key = key.lower().encode(self._encoding or "utf-8")
+        set_key = key.encode(self._encoding or "utf-8")
         set_value = value.encode(self._encoding or "utf-8")
-
-        self._dict[set_key] = set_value
+        lookup_key = set_key.lower()
 
         found_indexes = []
-        for idx, (item_key, _) in enumerate(self._list):
-            if item_key == set_key:
+        for idx, (_, item_key, _) in enumerate(self._list):
+            if item_key == lookup_key:
                 found_indexes.append(idx)
 
         for idx in reversed(found_indexes[1:]):
@@ -692,9 +703,9 @@ class Headers(typing.MutableMapping[str, str]):
 
         if found_indexes:
             idx = found_indexes[0]
-            self._list[idx] = (set_key, set_value)
+            self._list[idx] = (set_key, lookup_key, set_value)
         else:
-            self._list.append((set_key, set_value))
+            self._list.append((set_key, lookup_key, set_value))
 
     def __delitem__(self, key: str) -> None:
         """
@@ -702,19 +713,20 @@ class Headers(typing.MutableMapping[str, str]):
         """
         del_key = key.lower().encode(self.encoding)
 
-        del self._dict[del_key]
-
         pop_indexes = []
-        for idx, (item_key, _) in enumerate(self._list):
-            if item_key == del_key:
+        for idx, (_, item_key, _) in enumerate(self._list):
+            if item_key.lower() == del_key:
                 pop_indexes.append(idx)
 
+        if not pop_indexes:
+            raise KeyError(key)
+
         for idx in reversed(pop_indexes):
             del self._list[idx]
 
     def __contains__(self, key: typing.Any) -> bool:
         header_key = key.lower().encode(self.encoding)
-        return header_key in self._dict
+        return header_key in [key for _, key, _ in self._list]
 
     def __iter__(self) -> typing.Iterator[typing.Any]:
         return iter(self.keys())
@@ -727,7 +739,10 @@ class Headers(typing.MutableMapping[str, str]):
             other_headers = Headers(other)
         except ValueError:
             return False
-        return sorted(self._list) == sorted(other_headers._list)
+
+        self_list = [(key, value) for _, key, value in self._list]
+        other_list = [(key, value) for _, key, value in other_headers._list]
+        return sorted(self_list) == sorted(other_list)
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
@@ -793,15 +808,15 @@ class Request:
     def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
         for key, value in default_headers.items():
             # Ignore Transfer-Encoding if the Content-Length has been set explicitly.
-            if key.lower() == "transfer-encoding" and "content-length" in self.headers:
+            if key.lower() == "transfer-encoding" and "Content-Length" in self.headers:
                 continue
             self.headers.setdefault(key, value)
 
         auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
 
-        has_host = "host" in self.headers
+        has_host = "Host" in self.headers
         has_content_length = (
-            "content-length" in self.headers or "transfer-encoding" in self.headers
+            "Content-Length" in self.headers or "Transfer-Encoding" in self.headers
         )
 
         if not has_host and self.url.host:
@@ -810,9 +825,9 @@ class Request:
                 host_header = self.url.host.encode("ascii")
             else:
                 host_header = self.url.netloc.encode("ascii")
-            auto_headers.append((b"host", host_header))
+            auto_headers.append((b"Host", host_header))
         if not has_content_length and self.method in ("POST", "PUT", "PATCH"):
-            auto_headers.append((b"content-length", b"0"))
+            auto_headers.append((b"Content-Length", b"0"))
 
         self.headers = Headers(auto_headers + self.headers.raw)
 
index 75c92fd827f1858db95719f7afbef78da4460f32..072db3f1e8a345db11e790a1326b659168e18af1 100644 (file)
@@ -30,14 +30,19 @@ _HTML5_FORM_ENCODING_RE = re.compile(
 
 
 def normalize_header_key(
-    value: typing.Union[str, bytes], encoding: str = None
+    value: typing.Union[str, bytes],
+    lower: bool,
+    encoding: str = None,
 ) -> bytes:
     """
     Coerce str/bytes into a strictly byte-wise HTTP header key.
     """
     if isinstance(value, bytes):
-        return value.lower()
-    return value.encode(encoding or "ascii").lower()
+        bytes_value = value
+    else:
+        bytes_value = value.encode(encoding or "ascii")
+
+    return bytes_value.lower() if lower else bytes_value
 
 
 def normalize_header_value(
index f1e0ec424f8348d4e4126646cc7ba8bc3ac221af..a41f4232fbb7203c67f1c39396df65efb4a4d335 100644 (file)
@@ -271,3 +271,32 @@ def test_client_closed_state_using_with_block():
     assert client.is_closed
     with pytest.raises(RuntimeError):
         client.get("http://example.com")
+
+
+def echo_raw_headers(request: httpx.Request) -> httpx.Response:
+    data = [
+        (name.decode("ascii"), value.decode("ascii"))
+        for name, value in request.headers.raw
+    ]
+    return httpx.Response(200, json=data)
+
+
+def test_raw_client_header():
+    """
+    Set a header in the Client.
+    """
+    url = "http://example.org/echo_headers"
+    headers = {"Example-Header": "example-value"}
+
+    client = httpx.Client(transport=MockTransport(echo_raw_headers), headers=headers)
+    response = client.get(url)
+
+    assert response.status_code == 200
+    assert response.json() == [
+        ["Host", "example.org"],
+        ["Accept", "*/*"],
+        ["Accept-Encoding", "gzip, deflate, br"],
+        ["Connection", "keep-alive"],
+        ["User-Agent", f"python-httpx/{httpx.__version__}"],
+        ["Example-Header", "example-value"],
+    ]