git.ipfire.org Git - thirdparty/starlette.git/commitdiff
fix(gzip): Make sure Vary header is always added if a response can be compressed...
author Matthew Messinger <mattmess1221@gmail.com>
Sat, 22 Feb 2025 13:16:21 +0000 (08:16 -0500)
committer GitHub <noreply@github.com>
Sat, 22 Feb 2025 13:16:21 +0000 (10:16 -0300)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/middleware/gzip.py
tests/middleware/test_gzip.py

index fc63e91b660abb55e459dd7634f7d84d0fc15ad4..c7fd5b77e6dcd00aa1dee733647fa9223c878ad1 100644 (file)
@@ -15,17 +15,24 @@ class GZipMiddleware:
         self.compresslevel = compresslevel
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        if scope["type"] == "http":  # pragma: no branch
-            headers = Headers(scope=scope)
-            if "gzip" in headers.get("Accept-Encoding", ""):
-                responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
-                await responder(scope, receive, send)
-                return
-        await self.app(scope, receive, send)
+        if scope["type"] != "http":  # pragma: no cover
+            await self.app(scope, receive, send)
+            return
 
+        headers = Headers(scope=scope)
+        responder: ASGIApp
+        if "gzip" in headers.get("Accept-Encoding", ""):
+            responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
+        else:
+            responder = IdentityResponder(self.app, self.minimum_size)
 
-class GZipResponder:
-    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+        await responder(scope, receive, send)
+
+
+class IdentityResponder:
+    content_encoding: str
+
+    def __init__(self, app: ASGIApp, minimum_size: int) -> None:
         self.app = app
         self.minimum_size = minimum_size
         self.send: Send = unattached_send
@@ -33,15 +40,12 @@ class GZipResponder:
         self.started = False
         self.content_encoding_set = False
         self.content_type_is_excluded = False
-        self.gzip_buffer = io.BytesIO()
-        self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         self.send = send
-        with self.gzip_buffer, self.gzip_file:
-            await self.app(scope, receive, self.send_with_gzip)
+        await self.app(scope, receive, self.send_with_compression)
 
-    async def send_with_gzip(self, message: Message) -> None:
+    async def send_with_compression(self, message: Message) -> None:
         message_type = message["type"]
         if message_type == "http.response.start":
             # Don't send the initial message until we've determined how to
@@ -60,53 +64,78 @@ class GZipResponder:
             body = message.get("body", b"")
             more_body = message.get("more_body", False)
             if len(body) < self.minimum_size and not more_body:
-                # Don't apply GZip to small outgoing responses.
+                # Don't apply compression to small outgoing responses.
                 await self.send(self.initial_message)
                 await self.send(message)
             elif not more_body:
-                # Standard GZip response.
-                self.gzip_file.write(body)
-                self.gzip_file.close()
-                body = self.gzip_buffer.getvalue()
+                # Standard response.
+                body = self.apply_compression(body, more_body=False)
 
                 headers = MutableHeaders(raw=self.initial_message["headers"])
-                headers["Content-Encoding"] = "gzip"
-                headers["Content-Length"] = str(len(body))
                 headers.add_vary_header("Accept-Encoding")
-                message["body"] = body
+                if body != message["body"]:
+                    headers["Content-Encoding"] = self.content_encoding
+                    headers["Content-Length"] = str(len(body))
+                    message["body"] = body
 
                 await self.send(self.initial_message)
                 await self.send(message)
             else:
-                # Initial body in streaming GZip response.
+                # Initial body in streaming response.
+                body = self.apply_compression(body, more_body=True)
+
                 headers = MutableHeaders(raw=self.initial_message["headers"])
-                headers["Content-Encoding"] = "gzip"
                 headers.add_vary_header("Accept-Encoding")
-                del headers["Content-Length"]
-
-                self.gzip_file.write(body)
-                message["body"] = self.gzip_buffer.getvalue()
-                self.gzip_buffer.seek(0)
-                self.gzip_buffer.truncate()
+                if body != message["body"]:
+                    headers["Content-Encoding"] = self.content_encoding
+                    del headers["Content-Length"]
+                    message["body"] = body
 
                 await self.send(self.initial_message)
                 await self.send(message)
-
         elif message_type == "http.response.body":  # pragma: no branch
-            # Remaining body in streaming GZip response.
+            # Remaining body in streaming response.
             body = message.get("body", b"")
             more_body = message.get("more_body", False)
 
-            self.gzip_file.write(body)
-            if not more_body:
-                self.gzip_file.close()
-
-            message["body"] = self.gzip_buffer.getvalue()
-            self.gzip_buffer.seek(0)
-            self.gzip_buffer.truncate()
+            message["body"] = self.apply_compression(body, more_body=more_body)
 
             await self.send(message)
 
+    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+        """Apply compression on the response body.
+
+        If more_body is False, any compression file should be closed. If it
+        isn't, it won't be closed automatically until all background tasks
+        complete.
+        """
+        return body
+
+
+class GZipResponder(IdentityResponder):
+    content_encoding = "gzip"
+
+    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+        super().__init__(app, minimum_size)
+
+        self.gzip_buffer = io.BytesIO()
+        self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        with self.gzip_buffer, self.gzip_file:
+            await super().__call__(scope, receive, send)
+
+    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+        self.gzip_file.write(body)
+        if not more_body:
+            self.gzip_file.close()
+
+        body = self.gzip_buffer.getvalue()
+        self.gzip_buffer.seek(0)
+        self.gzip_buffer.truncate()
+
+        return body
+
 
 async def unattached_send(message: Message) -> typing.NoReturn:
     raise RuntimeError("send awaitable not set")  # pragma: no cover
index 38a4e1e35b7fcb982ae6523b165daf67cd2252aa..48ded6ae8a126d9784af922bc483e2f07fc3cef4 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.gzip import GZipMiddleware
@@ -21,6 +23,7 @@ def test_gzip_responses(test_client_factory: TestClientFactory) -> None:
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert response.headers["Content-Encoding"] == "gzip"
+    assert response.headers["Vary"] == "Accept-Encoding"
     assert int(response.headers["Content-Length"]) < 4000
 
 
@@ -38,6 +41,7 @@ def test_gzip_not_in_accept_encoding(test_client_factory: TestClientFactory) ->
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert "Content-Encoding" not in response.headers
+    assert response.headers["Vary"] == "Accept-Encoding"
     assert int(response.headers["Content-Length"]) == 4000
 
 
@@ -57,6 +61,7 @@ def test_gzip_ignored_for_small_responses(
     assert response.status_code == 200
     assert response.text == "OK"
     assert "Content-Encoding" not in response.headers
+    assert "Vary" not in response.headers
     assert int(response.headers["Content-Length"]) == 2
 
 
@@ -79,6 +84,30 @@ def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert response.headers["Content-Encoding"] == "gzip"
+    assert response.headers["Vary"] == "Accept-Encoding"
+    assert "Content-Length" not in response.headers
+
+
+def test_gzip_streaming_response_identity(test_client_factory: TestClientFactory) -> None:
+    def homepage(request: Request) -> StreamingResponse:
+        async def generator(bytes: bytes, count: int) -> ContentStream:
+            for index in range(count):
+                yield bytes
+
+        streaming = generator(bytes=b"x" * 400, count=10)
+        return StreamingResponse(streaming, status_code=200)
+
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[Middleware(GZipMiddleware)],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/", headers={"accept-encoding": "identity"})
+    assert response.status_code == 200
+    assert response.text == "x" * 4000
+    assert "Content-Encoding" not in response.headers
+    assert response.headers["Vary"] == "Accept-Encoding"
     assert "Content-Length" not in response.headers
 
 
@@ -103,6 +132,7 @@ def test_gzip_ignored_for_responses_with_encoding_set(
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert response.headers["Content-Encoding"] == "text"
+    assert "Vary" not in response.headers
     assert "Content-Length" not in response.headers