from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
+DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
+
class GZipMiddleware:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.initial_message: Message = {}
self.started = False
self.content_encoding_set = False
+ self.content_type_is_excluded = False
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
- elif message_type == "http.response.body" and self.content_encoding_set:
+ self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
+ elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
if not self.started:
self.started = True
await self.send(self.initial_message)
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "text"
assert "Content-Length" not in response.headers
+
+
+def test_gzip_ignored_on_server_sent_events(test_client_factory: TestClientFactory) -> None:
+ def homepage(request: Request) -> StreamingResponse:
+ async def generator(bytes: bytes, count: int) -> ContentStream:
+ for _ in range(count):
+ yield bytes
+
+ streaming = generator(bytes=b"x" * 400, count=10)
+ return StreamingResponse(streaming, status_code=200, media_type="text/event-stream")
+
+ app = Starlette(
+ routes=[Route("/", endpoint=homepage)],
+ middleware=[Middleware(GZipMiddleware)],
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/", headers={"accept-encoding": "gzip"})
+ assert response.status_code == 200
+ assert response.text == "x" * 4000
+ assert "Content-Encoding" not in response.headers
+ assert "Content-Length" not in response.headers