]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Use getlist, instead of get_list
authorTom Christie <tom@tomchristie.com>
Wed, 11 Jul 2018 15:30:40 +0000 (16:30 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 11 Jul 2018 15:30:40 +0000 (16:30 +0100)
starlette/datastructures.py
tests/test_datastructures.py
tests/test_response.py

index 71c52c04d43682e902e53f30946185eaee0ebf4c..247157cab64b52daced5a4d4a0cb4fd17ce85b60 100644 (file)
@@ -75,7 +75,7 @@ class QueryParams(typing.Mapping[str, str]):
         self._dict = {k: v for k, v in reversed(items)}
         self._list = items
 
-    def get_list(self, key: str) -> typing.List[str]:
+    def getlist(self, key: str) -> typing.List[str]:
         return [item_value for item_key, item_value in self._list if item_key == key]
 
     def keys(self):
index c510d00341f69bd9a214f3f3bceb0ef2be28fa8e..0a5a8441c7f13608c0e544306194fc32b09a9d6e 100644 (file)
@@ -1,4 +1,4 @@
-from starlette.datastructures import Headers, QueryParams, URL
+from starlette.datastructures import Headers, MutableHeaders, QueryParams, URL
 
 
 def test_url():
@@ -25,7 +25,7 @@ def test_headers():
     assert h["a"] == "123"
     assert h.get("a") == "123"
     assert h.get("nope", default=None) is None
-    assert h.get_list("a") == ["123", "456"]
+    assert h.getlist("a") == ["123", "456"]
     assert h.keys() == ["a", "a", "b"]
     assert h.values() == ["123", "456", "789"]
     assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
@@ -34,8 +34,21 @@ def test_headers():
     assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])"
     assert h == Headers([(b"a", b"123"), (b"b", b"789"), (b"a", b"456")])
     assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")]
-    h = Headers()
-    assert not h.items()
+
+
+def test_mutable_headers():
+    h = MutableHeaders()
+    assert dict(h) == {}
+    h["a"] = "1"
+    assert dict(h) == {"a": "1"}
+    h["a"] = "2"
+    assert dict(h) == {"a": "2"}
+    h.setdefault("a", "3")
+    assert dict(h) == {"a": "2"}
+    h.setdefault("b", "4")
+    assert dict(h) == {"a": "2", "b": "4"}
+    del h["a"]
+    assert dict(h) == {"b": "4"}
 
 
 def test_queryparams():
@@ -46,7 +59,7 @@ def test_queryparams():
     assert q["a"] == "123"
     assert q.get("a") == "123"
     assert q.get("nope", default=None) is None
-    assert q.get_list("a") == ["123", "456"]
+    assert q.getlist("a") == ["123", "456"]
     assert q.keys() == ["a", "a", "b"]
     assert q.values() == ["123", "456", "789"]
     assert q.items() == [("a", "123"), ("a", "456"), ("b", "789")]
index edcabb84f4348784668b893a1adcdd76252207c2..d4850fc8b3b636f82378e3ec9987db71ec90a857 100644 (file)
@@ -1,4 +1,4 @@
-from starlette import Response, StreamingResponse, TestClient
+from starlette import FileResponse, Response, StreamingResponse, TestClient
 import asyncio
 
 
@@ -67,22 +67,18 @@ def test_response_headers():
     assert response.headers["x-header-2"] == "789"
 
 
-def test_streaming_response_headers():
-    def app(scope):
-        async def asgi(receive, send):
-            async def stream(msg):
-                yield "hello, world"
+def test_file_response(tmpdir):
+    with open("xyz", "wb") as file:
+        file.write(b"<file content>")
 
-            headers = {"x-header-1": "123", "x-header-2": "456"}
-            response = StreamingResponse(
-                stream("hello, world"), media_type="text/plain", headers=headers
-            )
-            response.headers["x-header-2"] = "789"
-            await response(receive, send)
-
-        return asgi
+    def app(scope):
+        return FileResponse(path="xyz", filename="example.png")
 
     client = TestClient(app)
     response = client.get("/")
-    assert response.headers["x-header-1"] == "123"
-    assert response.headers["x-header-2"] == "789"
+    assert response.status_code == 200
+    assert response.content == b"<file content>"
+    assert response.headers["content-type"] == "image/png"
+    assert (
+        response.headers["content-disposition"] == 'attachment; filename="example.png"'
+    )