]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add type hints to `test_cors.py` (#2458)
authorScirlat Danut <danut.scirlat@gmail.com>
Sat, 3 Feb 2024 20:51:19 +0000 (22:51 +0200)
committerGitHub <noreply@github.com>
Sat, 3 Feb 2024 20:51:19 +0000 (13:51 -0700)
Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
tests/middleware/test_cors.py

index ca3d4f47b0dee237f160d88e3021bdc1878c59cc..09ec9513f3d7f0a17336af522d64f71539b37482 100644 (file)
@@ -1,12 +1,21 @@
+from typing import Callable
+
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
 from starlette.responses import PlainTextResponse
 from starlette.routing import Route
+from starlette.testclient import TestClient
+from starlette.types import ASGIApp
+
+TestClientFactory = Callable[[ASGIApp], TestClient]
 
 
-def test_cors_allow_all(test_client_factory):
-    def homepage(request):
+def test_cors_allow_all(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -64,8 +73,10 @@ def test_cors_allow_all(test_client_factory):
     assert "access-control-allow-origin" not in response.headers
 
 
-def test_cors_allow_all_except_credentials(test_client_factory):
-    def homepage(request):
+def test_cors_allow_all_except_credentials(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -113,8 +124,10 @@ def test_cors_allow_all_except_credentials(test_client_factory):
     assert "access-control-allow-origin" not in response.headers
 
 
-def test_cors_allow_specific_origin(test_client_factory):
-    def homepage(request):
+def test_cors_allow_specific_origin(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -160,8 +173,10 @@ def test_cors_allow_specific_origin(test_client_factory):
     assert "access-control-allow-origin" not in response.headers
 
 
-def test_cors_disallowed_preflight(test_client_factory):
-    def homepage(request):
+def test_cors_disallowed_preflight(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> None:
         pass  # pragma: no cover
 
     app = Starlette(
@@ -200,9 +215,9 @@ def test_cors_disallowed_preflight(test_client_factory):
 
 
 def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
-    test_client_factory,
-):
-    def homepage(request):
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> None:
         return  # pragma: no cover
 
     app = Starlette(
@@ -234,8 +249,10 @@ def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_all
     assert response.headers["vary"] == "Origin"
 
 
-def test_cors_preflight_allow_all_methods(test_client_factory):
-    def homepage(request):
+def test_cors_preflight_allow_all_methods(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> None:
         pass  # pragma: no cover
 
     app = Starlette(
@@ -258,8 +275,10 @@ def test_cors_preflight_allow_all_methods(test_client_factory):
         assert method in response.headers["access-control-allow-methods"]
 
 
-def test_cors_allow_all_methods(test_client_factory):
-    def homepage(request):
+def test_cors_allow_all_methods(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -287,8 +306,10 @@ def test_cors_allow_all_methods(test_client_factory):
         assert response.status_code == 200
 
 
-def test_cors_allow_origin_regex(test_client_factory):
-    def homepage(request):
+def test_cors_allow_origin_regex(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -357,8 +378,10 @@ def test_cors_allow_origin_regex(test_client_factory):
     assert "access-control-allow-origin" not in response.headers
 
 
-def test_cors_allow_origin_regex_fullmatch(test_client_factory):
-    def homepage(request):
+def test_cors_allow_origin_regex_fullmatch(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -393,8 +416,10 @@ def test_cors_allow_origin_regex_fullmatch(test_client_factory):
     assert "access-control-allow-origin" not in response.headers
 
 
-def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
-    def homepage(request):
+def test_cors_credentialed_requests_return_specific_origin(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -412,8 +437,10 @@ def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
     assert "access-control-allow-credentials" not in response.headers
 
 
-def test_cors_vary_header_defaults_to_origin(test_client_factory):
-    def homepage(request):
+def test_cors_vary_header_defaults_to_origin(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(
@@ -430,8 +457,10 @@ def test_cors_vary_header_defaults_to_origin(test_client_factory):
     assert response.headers["vary"] == "Origin"
 
 
-def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory):
-    def homepage(request):
+def test_cors_vary_header_is_not_set_for_non_credentialed_request(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse(
             "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
         )
@@ -447,8 +476,10 @@ def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_fa
     assert response.headers["vary"] == "Accept-Encoding"
 
 
-def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory):
-    def homepage(request):
+def test_cors_vary_header_is_properly_set_for_credentialed_request(
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse(
             "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
         )
@@ -467,9 +498,9 @@ def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_f
 
 
 def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
-    test_client_factory,
-):
-    def homepage(request):
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse(
             "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
         )
@@ -488,9 +519,9 @@ def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
 
 
 def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
-    test_client_factory,
-):
-    def homepage(request):
+    test_client_factory: TestClientFactory,
+) -> None:
+    def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage", status_code=200)
 
     app = Starlette(