]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add union operators to MutableHeaders (#1240)
authormanlix <manlix@yandex.ru>
Wed, 16 Feb 2022 10:58:56 +0000 (13:58 +0300)
committerGitHub <noreply@github.com>
Wed, 16 Feb 2022 10:58:56 +0000 (10:58 +0000)
* Add union operators to MutableHeaders (#1239)

* Apply suggestions from code review

* Use `TypeError`, not `NotImplemented`.
* Add `# type: ignore` to deliberate incorrect usage of types in tests.

* Apply suggestions from code review

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Tom Christie <tom@tomchristie.com>
starlette/datastructures.py
tests/test_datastructures.py

index 1a8b965e8a2c169ae1bdec4246cdbe7de44ec167..59863282ae7afa140d22f8da0ac4c3ce77846ad1 100644 (file)
@@ -618,6 +618,19 @@ class MutableHeaders(Headers):
         for idx in reversed(pop_indexes):
             del self._list[idx]
 
+    def __ior__(self, other: typing.Mapping) -> "MutableHeaders":
+        if not isinstance(other, typing.Mapping):
+            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
+        self.update(other)
+        return self
+
+    def __or__(self, other: typing.Mapping) -> "MutableHeaders":
+        if not isinstance(other, typing.Mapping):
+            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
+        new = self.mutablecopy()
+        new.update(other)
+        return new
+
     @property
     def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
         return self._list
@@ -636,7 +649,7 @@ class MutableHeaders(Headers):
         self._list.append((set_key, set_value))
         return value
 
-    def update(self, other: dict) -> None:
+    def update(self, other: typing.Mapping) -> None:
         for key, val in other.items():
             self[key] = val
 
index b110aa8bd01ffa299035662347d09c717f1cfe82..22e377c99a6505311964d301182dc4d1dd80859f 100644 (file)
@@ -162,6 +162,50 @@ def test_mutable_headers():
     assert h.raw == [(b"b", b"4")]
 
 
+def test_mutable_headers_merge():
+    h = MutableHeaders()
+    h = h | MutableHeaders({"a": "1"})
+    assert isinstance(h, MutableHeaders)
+    assert dict(h) == {"a": "1"}
+    assert h.items() == [("a", "1")]
+    assert h.raw == [(b"a", b"1")]
+
+
+def test_mutable_headers_merge_dict():
+    h = MutableHeaders()
+    h = h | {"a": "1"}
+    assert isinstance(h, MutableHeaders)
+    assert dict(h) == {"a": "1"}
+    assert h.items() == [("a", "1")]
+    assert h.raw == [(b"a", b"1")]
+
+
+def test_mutable_headers_update():
+    h = MutableHeaders()
+    h |= MutableHeaders({"a": "1"})
+    assert isinstance(h, MutableHeaders)
+    assert dict(h) == {"a": "1"}
+    assert h.items() == [("a", "1")]
+    assert h.raw == [(b"a", b"1")]
+
+
+def test_mutable_headers_update_dict():
+    h = MutableHeaders()
+    h |= {"a": "1"}
+    assert isinstance(h, MutableHeaders)
+    assert dict(h) == {"a": "1"}
+    assert h.items() == [("a", "1")]
+    assert h.raw == [(b"a", b"1")]
+
+
+def test_mutable_headers_merge_not_mapping():
+    h = MutableHeaders()
+    with pytest.raises(TypeError):
+        h |= {"not_mapping"}  # type: ignore
+    with pytest.raises(TypeError):
+        h | {"not_mapping"}  # type: ignore
+
+
 def test_headers_mutablecopy():
     h = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")])
     c = h.mutablecopy()