From 5b1348a7cafc3f7ab259c4b0490924e57f43e40a Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 27 Mar 2020 18:14:05 +0100 Subject: [PATCH] Fix safelisted CORS headers implementation --- starlette/middleware/cors.py | 6 +++--- tests/middleware/test_cors.py | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 69249421..338aee86 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -7,7 +7,7 @@ from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send ALL_METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT") -SAFELISTED_HEADERS = {"accept", "accept-language", "content-language", "content-type"} +SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} class CORSMiddleware: @@ -49,7 +49,7 @@ class CORSMiddleware: "Access-Control-Max-Age": str(max_age), } ) - allow_headers = SAFELISTED_HEADERS | set([h.lower for h in allow_headers]) + allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) if allow_headers and "*" not in allow_headers: preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) if allow_credentials: @@ -58,7 +58,7 @@ class CORSMiddleware: self.app = app self.allow_origins = allow_origins self.allow_methods = allow_methods - self.allow_headers = allow_headers + self.allow_headers = [h.lower() for h in allow_headers] self.allow_all_origins = "*" in allow_origins self.allow_all_headers = "*" in allow_headers self.allow_origin_regex = compiled_allow_origin_regex diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index e8bf72fd..a5b6e623 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -74,7 +74,9 @@ def test_cors_allow_specific_origin(): assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://example.org" - assert response.headers["access-control-allow-headers"] == "X-Example, Content-Type" + assert response.headers["access-control-allow-headers"] == ( + "Accept, Accept-Language, Content-Language, Content-Type, X-Example" + ) # Test standard response headers = {"Origin": "https://example.org"} @@ -157,7 +159,9 @@ def test_cors_allow_origin_regex(): assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://another.com" - assert response.headers["access-control-allow-headers"] == "X-Example, Content-Type" + assert response.headers["access-control-allow-headers"] == ( + "Accept, Accept-Language, Content-Language, Content-Type, X-Example" + ) # Test disallowed pre-flight response headers = { -- 2.47.2