]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Don't omit `Content-Length` header for `Content-Length: 0` cases (#1395)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Fri, 7 Jan 2022 11:48:21 +0000 (12:48 +0100)
committerGitHub <noreply@github.com>
Fri, 7 Jan 2022 11:48:21 +0000 (12:48 +0100)
* Add content-length header by default

* Add test for #1099

* Revert changes and add tests

* Check if is StreamingResponse or FileResponse before adding content-length headers

* Change conditional logic to check if body is present

starlette/responses.py
starlette/staticfiles.py
tests/test_responses.py

index ffde4b97def6b0a9affe2d65cc99268a0692c45d..da765cfa9a78df923fa1225f5b2a663512302340 100644 (file)
@@ -70,8 +70,8 @@ class Response:
             populate_content_length = b"content-length" not in keys
             populate_content_type = b"content-type" not in keys
 
-        body = getattr(self, "body", b"")
-        if body and populate_content_length:
+        body = getattr(self, "body", None)
+        if body is not None and populate_content_length:
             content_length = str(len(body))
             raw_headers.append((b"content-length", content_length.encode("latin-1")))
 
index 76e435310be484143059a6bb83b3ee061363c27a..bd4d8bced1274f677a6276243a13ae3665d13674 100644 (file)
@@ -100,7 +100,7 @@ class StaticFiles:
     def get_path(self, scope: Scope) -> str:
         """
         Given the ASGI scope, return the `path` string to serve up,
-        with OS specific path seperators, and any '..', '.' components removed.
+        with OS specific path separators, and any '..', '.' components removed.
         """
         return os.path.normpath(os.path.join(*scope["path"].split("/")))
 
index baba549baf9a5adc1521e78471937112aad8ac5a..150fe47952101e10eff7670e7668bd3a3365c15e 100644 (file)
@@ -13,6 +13,7 @@ from starlette.responses import (
     Response,
     StreamingResponse,
 )
+from starlette.testclient import TestClient
 
 
 def test_text_response(test_client_factory):
@@ -73,6 +74,20 @@ def test_quoting_redirect_response(test_client_factory):
     assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
 
 
+def test_redirect_response_content_length_header(test_client_factory):
+    async def app(scope, receive, send):
+        if scope["path"] == "/":
+            response = Response("hello", media_type="text/plain")  # pragma: nocover
+        else:
+            response = RedirectResponse("/")
+        await response(scope, receive, send)
+
+    client: TestClient = test_client_factory(app)
+    response = client.request("GET", "/redirect", allow_redirects=False)
+    assert response.url == "http://testserver/redirect"
+    assert response.headers["content-length"] == "0"
+
+
 def test_streaming_response(test_client_factory):
     filled_by_bg_task = ""
 
@@ -309,3 +324,45 @@ def test_head_method(test_client_factory):
     client = test_client_factory(app)
     response = client.head("/")
     assert response.text == ""
+
+
+def test_empty_response(test_client_factory):
+    app = Response()
+    client: TestClient = test_client_factory(app)
+    response = client.get("/")
+    assert response.headers["content-length"] == "0"
+
+
+def test_non_empty_response(test_client_factory):
+    app = Response(content="hi")
+    client: TestClient = test_client_factory(app)
+    response = client.get("/")
+    assert response.headers["content-length"] == "2"
+
+
+def test_file_response_known_size(tmpdir, test_client_factory):
+    path = os.path.join(tmpdir, "xyz")
+    content = b"<file content>" * 1000
+    with open(path, "wb") as file:
+        file.write(content)
+
+    app = FileResponse(path=path, filename="example.png")
+    client: TestClient = test_client_factory(app)
+    response = client.get("/")
+    assert response.headers["content-length"] == str(len(content))
+
+
+def test_streaming_response_unknown_size(test_client_factory):
+    app = StreamingResponse(content=iter(["hello", "world"]))
+    client: TestClient = test_client_factory(app)
+    response = client.get("/")
+    assert "content-length" not in response.headers
+
+
+def test_streaming_response_known_size(test_client_factory):
+    app = StreamingResponse(
+        content=iter(["hello", "world"]), headers={"content-length": "10"}
+    )
+    client: TestClient = test_client_factory(app)
+    response = client.get("/")
+    assert response.headers["content-length"] == "10"