+from typing import Callable
+
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
+from starlette.testclient import TestClient
+from starlette.types import ASGIApp
+
+TestClientFactory = Callable[[ASGIApp], TestClient]
-def test_cors_allow_all(test_client_factory):
- def homepage(request):
+def test_cors_allow_all(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert "access-control-allow-origin" not in response.headers
-def test_cors_allow_all_except_credentials(test_client_factory):
- def homepage(request):
+def test_cors_allow_all_except_credentials(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert "access-control-allow-origin" not in response.headers
-def test_cors_allow_specific_origin(test_client_factory):
- def homepage(request):
+def test_cors_allow_specific_origin(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert "access-control-allow-origin" not in response.headers
-def test_cors_disallowed_preflight(test_client_factory):
- def homepage(request):
+def test_cors_disallowed_preflight(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> None:
pass # pragma: no cover
app = Starlette(
def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
- test_client_factory,
-):
- def homepage(request):
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> None:
return # pragma: no cover
app = Starlette(
assert response.headers["vary"] == "Origin"
-def test_cors_preflight_allow_all_methods(test_client_factory):
- def homepage(request):
+def test_cors_preflight_allow_all_methods(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> None:
pass # pragma: no cover
app = Starlette(
assert method in response.headers["access-control-allow-methods"]
-def test_cors_allow_all_methods(test_client_factory):
- def homepage(request):
+def test_cors_allow_all_methods(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert response.status_code == 200
-def test_cors_allow_origin_regex(test_client_factory):
- def homepage(request):
+def test_cors_allow_origin_regex(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert "access-control-allow-origin" not in response.headers
-def test_cors_allow_origin_regex_fullmatch(test_client_factory):
- def homepage(request):
+def test_cors_allow_origin_regex_fullmatch(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert "access-control-allow-origin" not in response.headers
-def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
- def homepage(request):
+def test_cors_credentialed_requests_return_specific_origin(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert "access-control-allow-credentials" not in response.headers
-def test_cors_vary_header_defaults_to_origin(test_client_factory):
- def homepage(request):
+def test_cors_vary_header_defaults_to_origin(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
assert response.headers["vary"] == "Origin"
-def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory):
- def homepage(request):
+def test_cors_vary_header_is_not_set_for_non_credentialed_request(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
assert response.headers["vary"] == "Accept-Encoding"
-def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory):
- def homepage(request):
+def test_cors_vary_header_is_properly_set_for_credentialed_request(
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
- test_client_factory,
-):
- def homepage(request):
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
- test_client_factory,
-):
- def homepage(request):
+ test_client_factory: TestClientFactory,
+) -> None:
+ def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(