]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Remove routing decorators in test_cors.py (#1498)
authorAmin Alaee <mohammadamin.alaee@gmail.com>
Thu, 10 Feb 2022 09:27:30 +0000 (10:27 +0100)
committerGitHub <noreply@github.com>
Thu, 10 Feb 2022 09:27:30 +0000 (10:27 +0100)
tests/middleware/test_cors.py

index 2f0ca3d34af645b1a4be4a2b5a73080dc198d6cd..910afd9f84f906d5544559b06c8060f4df6d2dd3 100644 (file)
@@ -1,24 +1,28 @@
 from starlette.applications import Starlette
+from starlette.middleware import Middleware
 from starlette.middleware.cors import CORSMiddleware
 from starlette.responses import PlainTextResponse
+from starlette.routing import Route
 
 
 def test_cors_allow_all(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["*"],
-        allow_headers=["*"],
-        allow_methods=["*"],
-        expose_headers=["X-Status"],
-        allow_credentials=True,
-    )
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_origins=["*"],
+                allow_headers=["*"],
+                allow_methods=["*"],
+                expose_headers=["X-Status"],
+                allow_credentials=True,
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test pre-flight response
@@ -61,20 +65,22 @@ def test_cors_allow_all(test_client_factory):
 
 
 def test_cors_allow_all_except_credentials(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["*"],
-        allow_headers=["*"],
-        allow_methods=["*"],
-        expose_headers=["X-Status"],
-    )
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_origins=["*"],
+                allow_headers=["*"],
+                allow_methods=["*"],
+                expose_headers=["X-Status"],
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test pre-flight response
@@ -108,18 +114,20 @@ def test_cors_allow_all_except_credentials(test_client_factory):
 
 
 def test_cors_allow_specific_origin(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["https://example.org"],
-        allow_headers=["X-Example", "Content-Type"],
-    )
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_origins=["https://example.org"],
+                allow_headers=["X-Example", "Content-Type"],
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test pre-flight response
@@ -153,18 +161,20 @@ def test_cors_allow_specific_origin(test_client_factory):
 
 
 def test_cors_disallowed_preflight(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["https://example.org"],
-        allow_headers=["X-Example"],
-    )
-
-    @app.route("/")
     def homepage(request):
         pass  # pragma: no cover
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_origins=["https://example.org"],
+                allow_headers=["X-Example"],
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test pre-flight response
@@ -192,19 +202,21 @@ def test_cors_disallowed_preflight(test_client_factory):
 def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
     test_client_factory,
 ):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["*"],
-        allow_methods=["POST"],
-        allow_credentials=True,
-    )
-
-    @app.route("/")
     def homepage(request):
         return  # pragma: no cover
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_origins=["*"],
+                allow_methods=["POST"],
+                allow_credentials=True,
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test pre-flight response
@@ -223,18 +235,16 @@ def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_all
 
 
 def test_cors_preflight_allow_all_methods(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["*"],
-        allow_methods=["*"],
-    )
-
-    @app.route("/")
     def homepage(request):
         pass  # pragma: no cover
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
+        ],
+    )
+
     client = test_client_factory(app)
 
     headers = {
@@ -249,20 +259,22 @@ def test_cors_preflight_allow_all_methods(test_client_factory):
 
 
 def test_cors_allow_all_methods(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=["*"],
-        allow_methods=["*"],
-    )
-
-    @app.route(
-        "/", methods=["delete", "get", "head", "options", "patch", "post", "put"]
-    )
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[
+            Route(
+                "/",
+                endpoint=homepage,
+                methods=["delete", "get", "head", "options", "patch", "post", "put"],
+            )
+        ],
+        middleware=[
+            Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
+        ],
+    )
+
     client = test_client_factory(app)
 
     headers = {"Origin": "https://example.org"}
@@ -273,19 +285,21 @@ def test_cors_allow_all_methods(test_client_factory):
 
 
 def test_cors_allow_origin_regex(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_headers=["X-Example", "Content-Type"],
-        allow_origin_regex="https://.*",
-        allow_credentials=True,
-    )
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_headers=["X-Example", "Content-Type"],
+                allow_origin_regex="https://.*",
+                allow_credentials=True,
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test standard response
@@ -341,18 +355,20 @@ def test_cors_allow_origin_regex(test_client_factory):
 
 
 def test_cors_allow_origin_regex_fullmatch(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_headers=["X-Example", "Content-Type"],
-        allow_origin_regex=r"https://.*\.example.org",
-    )
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_headers=["X-Example", "Content-Type"],
+                allow_origin_regex=r"https://.*\.example.org",
+            )
+        ],
+    )
+
     client = test_client_factory(app)
 
     # Test standard response
@@ -375,14 +391,13 @@ def test_cors_allow_origin_regex_fullmatch(test_client_factory):
 
 
 def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(CORSMiddleware, allow_origins=["*"])
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
+    )
     client = test_client_factory(app)
 
     # Test credentialed request
@@ -395,16 +410,16 @@ def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
 
 
 def test_cors_vary_header_defaults_to_origin(test_client_factory):
-    app = Starlette()
+    def homepage(request):
+        return PlainTextResponse("Homepage", status_code=200)
 
-    app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])],
+    )
 
     headers = {"Origin": "https://example.org"}
 
-    @app.route("/")
-    def homepage(request):
-        return PlainTextResponse("Homepage", status_code=200)
-
     client = test_client_factory(app)
 
     response = client.get("/", headers=headers)
@@ -413,16 +428,15 @@ def test_cors_vary_header_defaults_to_origin(test_client_factory):
 
 
 def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(CORSMiddleware, allow_origins=["*"])
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse(
             "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
         )
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
+    )
     client = test_client_factory(app)
 
     response = client.get("/", headers={"Origin": "https://someplace.org"})
@@ -431,16 +445,15 @@ def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_fa
 
 
 def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory):
-    app = Starlette()
-
-    app.add_middleware(CORSMiddleware, allow_origins=["*"])
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse(
             "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
         )
 
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
+    )
     client = test_client_factory(app)
 
     response = client.get(
@@ -453,16 +466,17 @@ def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_f
 def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
     test_client_factory,
 ):
-    app = Starlette()
-
-    app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse(
             "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
         )
 
+    app = Starlette(
+        routes=[
+            Route("/", endpoint=homepage),
+        ],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])],
+    )
     client = test_client_factory(app)
 
     response = client.get("/", headers={"Origin": "https://example.org"})
@@ -473,16 +487,23 @@ def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
 def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
     test_client_factory,
 ):
-    app = Starlette()
-
-    app.add_middleware(
-        CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"]
-    )
-
-    @app.route("/")
     def homepage(request):
         return PlainTextResponse("Homepage", status_code=200)
 
+    app = Starlette(
+        routes=[
+            Route("/", endpoint=homepage),
+        ],
+        middleware=[
+            Middleware(
+                CORSMiddleware,
+                allow_origins=["*"],
+                allow_headers=["*"],
+                allow_methods=["*"],
+            )
+        ],
+    )
+
     client = test_client_factory(app)
     response = client.get("/", headers={"Origin": "https://someplace.org"})
     assert response.headers["access-control-allow-origin"] == "*"