self.compresslevel = compresslevel
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if scope["type"] == "http": # pragma: no branch
- headers = Headers(scope=scope)
- if "gzip" in headers.get("Accept-Encoding", ""):
- responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
- await responder(scope, receive, send)
- return
- await self.app(scope, receive, send)
+ if scope["type"] != "http": # pragma: no cover
+ await self.app(scope, receive, send)
+ return
+ headers = Headers(scope=scope)
+ responder: ASGIApp
+ if "gzip" in headers.get("Accept-Encoding", ""):
+ responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
+ else:
+ responder = IdentityResponder(self.app, self.minimum_size)
-class GZipResponder:
- def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+ await responder(scope, receive, send)
+
+
+class IdentityResponder:
+ content_encoding: str
+
+ def __init__(self, app: ASGIApp, minimum_size: int) -> None:
self.app = app
self.minimum_size = minimum_size
self.send: Send = unattached_send
self.started = False
self.content_encoding_set = False
self.content_type_is_excluded = False
- self.gzip_buffer = io.BytesIO()
- self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
- with self.gzip_buffer, self.gzip_file:
- await self.app(scope, receive, self.send_with_gzip)
+ await self.app(scope, receive, self.send_with_compression)
- async def send_with_gzip(self, message: Message) -> None:
+ async def send_with_compression(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
- # Don't apply GZip to small outgoing responses.
+ # Don't apply compression to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
- # Standard GZip response.
- self.gzip_file.write(body)
- self.gzip_file.close()
- body = self.gzip_buffer.getvalue()
+ # Standard response.
+ body = self.apply_compression(body, more_body=False)
headers = MutableHeaders(raw=self.initial_message["headers"])
- headers["Content-Encoding"] = "gzip"
- headers["Content-Length"] = str(len(body))
headers.add_vary_header("Accept-Encoding")
- message["body"] = body
+ if body != message["body"]:
+ headers["Content-Encoding"] = self.content_encoding
+ headers["Content-Length"] = str(len(body))
+ message["body"] = body
await self.send(self.initial_message)
await self.send(message)
else:
- # Initial body in streaming GZip response.
+ # Initial body in streaming response.
+ body = self.apply_compression(body, more_body=True)
+
headers = MutableHeaders(raw=self.initial_message["headers"])
- headers["Content-Encoding"] = "gzip"
headers.add_vary_header("Accept-Encoding")
- del headers["Content-Length"]
-
- self.gzip_file.write(body)
- message["body"] = self.gzip_buffer.getvalue()
- self.gzip_buffer.seek(0)
- self.gzip_buffer.truncate()
+ if body != message["body"]:
+ headers["Content-Encoding"] = self.content_encoding
+ del headers["Content-Length"]
+ message["body"] = body
await self.send(self.initial_message)
await self.send(message)
-
elif message_type == "http.response.body": # pragma: no branch
- # Remaining body in streaming GZip response.
+ # Remaining body in streaming response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
- self.gzip_file.write(body)
- if not more_body:
- self.gzip_file.close()
-
- message["body"] = self.gzip_buffer.getvalue()
- self.gzip_buffer.seek(0)
- self.gzip_buffer.truncate()
+ message["body"] = self.apply_compression(body, more_body=more_body)
await self.send(message)
+ def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+ """Apply compression on the response body.
+
+ If more_body is False, any compression file should be closed. If it
+ isn't, it won't be closed automatically until all background tasks
+ complete.
+ """
+ return body
+
+
+class GZipResponder(IdentityResponder):
+ content_encoding = "gzip"
+
+ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+ super().__init__(app, minimum_size)
+
+ self.gzip_buffer = io.BytesIO()
+ self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ with self.gzip_buffer, self.gzip_file:
+ await super().__call__(scope, receive, send)
+
+ def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+ self.gzip_file.write(body)
+ if not more_body:
+ self.gzip_file.close()
+
+ body = self.gzip_buffer.getvalue()
+ self.gzip_buffer.seek(0)
+ self.gzip_buffer.truncate()
+
+ return body
+
async def unattached_send(message: Message) -> typing.NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover
+from __future__ import annotations
+
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
+ assert response.headers["Vary"] == "Accept-Encoding"
assert int(response.headers["Content-Length"]) < 4000
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
+ assert response.headers["Vary"] == "Accept-Encoding"
assert int(response.headers["Content-Length"]) == 4000
assert response.status_code == 200
assert response.text == "OK"
assert "Content-Encoding" not in response.headers
+ assert "Vary" not in response.headers
assert int(response.headers["Content-Length"]) == 2
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
+ assert response.headers["Vary"] == "Accept-Encoding"
+ assert "Content-Length" not in response.headers
+
+
+def test_gzip_streaming_response_identity(test_client_factory: TestClientFactory) -> None:
+ def homepage(request: Request) -> StreamingResponse:
+ async def generator(bytes: bytes, count: int) -> ContentStream:
+ for index in range(count):
+ yield bytes
+
+ streaming = generator(bytes=b"x" * 400, count=10)
+ return StreamingResponse(streaming, status_code=200)
+
+ app = Starlette(
+ routes=[Route("/", endpoint=homepage)],
+ middleware=[Middleware(GZipMiddleware)],
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/", headers={"accept-encoding": "identity"})
+ assert response.status_code == 200
+ assert response.text == "x" * 4000
+ assert "Content-Encoding" not in response.headers
+ assert response.headers["Vary"] == "Accept-Encoding"
assert "Content-Length" not in response.headers
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "text"
+ assert "Vary" not in response.headers
assert "Content-Length" not in response.headers