]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Enforce `__future__.annotations` (#2483)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 29 Feb 2024 10:16:42 +0000 (11:16 +0100)
committerGitHub <noreply@github.com>
Thu, 29 Feb 2024 10:16:42 +0000 (10:16 +0000)
28 files changed:
pyproject.toml
starlette/applications.py
starlette/authentication.py
starlette/config.py
starlette/convertors.py
starlette/datastructures.py
starlette/formparsers.py
starlette/middleware/authentication.py
starlette/middleware/base.py
starlette/middleware/cors.py
starlette/middleware/errors.py
starlette/middleware/exceptions.py
starlette/middleware/sessions.py
starlette/middleware/trustedhost.py
starlette/middleware/wsgi.py
starlette/requests.py
starlette/routing.py
starlette/templating.py
starlette/websockets.py
tests/conftest.py
tests/middleware/test_base.py
tests/test_authentication.py
tests/test_formparsers.py
tests/test_requests.py
tests/test_responses.py
tests/test_routing.py
tests/test_templates.py
tests/test_websockets.py

index 3cf7c0d407e2cd534c123e3d97ec1060affd8809..679deaade174427a550f7c8ea49a73208e0d51af 100644 (file)
@@ -9,9 +9,7 @@ description = "The little ASGI library that shines."
 readme = "README.md"
 license = "BSD-3-Clause"
 requires-python = ">=3.8"
-authors = [
-    { name = "Tom Christie", email = "tom@tomchristie.com" },
-]
+authors = [{ name = "Tom Christie", email = "tom@tomchristie.com" }]
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Environment :: Web Environment",
@@ -52,7 +50,7 @@ Source = "https://github.com/encode/starlette"
 path = "starlette/__init__.py"
 
 [tool.ruff.lint]
-select = ["E", "F", "I"]
+select = ["E", "F", "I", "FA", "UP"]
 
 [tool.ruff.lint.isort]
 combine-as-imports = true
@@ -83,10 +81,7 @@ filterwarnings = [
 ]
 
 [tool.coverage.run]
-source_pkgs = [
-    "starlette",
-    "tests",
-]
+source_pkgs = ["starlette", "tests"]
 
 [tool.coverage.report]
 exclude_lines = [
index 1a4e3d264fa3e51ad13508a2957fb6e8e2b29994..913fd4c9dbdfad47a51ad850c087503f554d1868 100644 (file)
@@ -79,7 +79,7 @@ class Starlette:
             {} if exception_handlers is None else dict(exception_handlers)
         )
         self.user_middleware = [] if middleware is None else list(middleware)
-        self.middleware_stack: typing.Optional[ASGIApp] = None
+        self.middleware_stack: ASGIApp | None = None
 
     def build_middleware_stack(self) -> ASGIApp:
         debug = self.debug
@@ -133,7 +133,7 @@ class Starlette:
 
     def add_middleware(
         self,
-        middleware_class: typing.Type[_MiddlewareClass[P]],
+        middleware_class: type[_MiddlewareClass[P]],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> None:
@@ -143,7 +143,7 @@ class Starlette:
 
     def add_exception_handler(
         self,
-        exc_class_or_status_code: int | typing.Type[Exception],
+        exc_class_or_status_code: int | type[Exception],
         handler: ExceptionHandler,
     ) -> None:  # pragma: no cover
         self.exception_handlers[exc_class_or_status_code] = handler
@@ -159,8 +159,8 @@ class Starlette:
         self,
         path: str,
         route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
-        methods: typing.Optional[typing.List[str]] = None,
-        name: typing.Optional[str] = None,
+        methods: list[str] | None = None,
+        name: str | None = None,
         include_in_schema: bool = True,
     ) -> None:  # pragma: no cover
         self.router.add_route(
@@ -176,7 +176,7 @@ class Starlette:
         self.router.add_websocket_route(path, route, name=name)
 
     def exception_handler(
-        self, exc_class_or_status_code: int | typing.Type[Exception]
+        self, exc_class_or_status_code: int | type[Exception]
     ) -> typing.Callable:  # type: ignore[type-arg]
         warnings.warn(
             "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
index e26a8a388110657c3f3461115ddac60933f0b3ba..f2586a042735f9afc011b104f7d6d48c58c0dfce 100644 (file)
@@ -75,10 +75,7 @@ def requires(
                 if not has_required_scope(request, scopes_list):
                     if redirect is not None:
                         orig_request_qparam = urlencode({"next": str(request.url)})
-                        next_url = "{redirect_path}?{orig_request}".format(
-                            redirect_path=request.url_for(redirect),
-                            orig_request=orig_request_qparam,
-                        )
+                        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
                         return RedirectResponse(url=next_url, status_code=303)
                     raise HTTPException(status_code=status_code)
                 return await func(*args, **kwargs)
@@ -95,10 +92,7 @@ def requires(
                 if not has_required_scope(request, scopes_list):
                     if redirect is not None:
                         orig_request_qparam = urlencode({"next": str(request.url)})
-                        next_url = "{redirect_path}?{orig_request}".format(
-                            redirect_path=request.url_for(redirect),
-                            orig_request=orig_request_qparam,
-                        )
+                        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
                         return RedirectResponse(url=next_url, status_code=303)
                     raise HTTPException(status_code=status_code)
                 return func(*args, **kwargs)
index d222a0a6278c1f5598d014f4225061409fdbab88..5b9813beac7397e09b993bc9d0dc959c85c0dc90 100644 (file)
@@ -17,7 +17,7 @@ class EnvironError(Exception):
 class Environ(typing.MutableMapping[str, str]):
     def __init__(self, environ: typing.MutableMapping[str, str] = os.environ):
         self._environ = environ
-        self._has_been_read: typing.Set[str] = set()
+        self._has_been_read: set[str] = set()
 
     def __getitem__(self, key: str) -> str:
         self._has_been_read.add(key)
@@ -60,7 +60,7 @@ class Config:
     ) -> None:
         self.environ = environ
         self.env_prefix = env_prefix
-        self.file_values: typing.Dict[str, str] = {}
+        self.file_values: dict[str, str] = {}
         if env_file is not None:
             if not os.path.isfile(env_file):
                 warnings.warn(f"Config file '{env_file}' not found.")
@@ -118,7 +118,7 @@ class Config:
         raise KeyError(f"Config '{key}' is missing, and has no default.")
 
     def _read_file(self, file_name: str | Path) -> dict[str, str]:
-        file_values: typing.Dict[str, str] = {}
+        file_values: dict[str, str] = {}
         with open(file_name) as input_file:
             for line in input_file.readlines():
                 line = line.strip()
index 3b12ac7a0ced2f5a4f6bffb8e88000dd3eac8e0b..2d8ab53beb637891a435b277414947aacae433f6 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 import typing
 import uuid
@@ -74,7 +76,7 @@ class UUIDConvertor(Convertor[uuid.UUID]):
         return str(value)
 
 
-CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = {
+CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = {
     "str": StringConvertor(),
     "path": PathConvertor(),
     "int": IntegerConvertor(),
index e430d09b6b6f300dded255c3c519bcb1e4f632c3..54b5e54f3bf8a19a36a81ad4e5aee970a512153b 100644 (file)
@@ -150,7 +150,7 @@ class URL:
         query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
         return self.replace(query=query)
 
-    def remove_query_params(self, keys: str | typing.Sequence[str]) -> "URL":
+    def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
         if isinstance(keys, str):
             keys = [keys]
         params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
@@ -178,7 +178,7 @@ class URLPath(str):
     Used by the routing to return `url_path_for` matches.
     """
 
-    def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
+    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
         assert protocol in ("http", "websocket", "")
         return str.__new__(cls, path)
 
@@ -251,13 +251,13 @@ class CommaSeparatedStrings(typing.Sequence[str]):
 
 
 class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
-    _dict: typing.Dict[_KeyType, _CovariantValueType]
+    _dict: dict[_KeyType, _CovariantValueType]
 
     def __init__(
         self,
         *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
         | typing.Mapping[_KeyType, _CovariantValueType]
-        | typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
+        | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
         **kwargs: typing.Any,
     ) -> None:
         assert len(args) < 2, "Too many arguments."
@@ -599,7 +599,7 @@ class MutableHeaders(Headers):
         set_key = key.lower().encode("latin-1")
         set_value = value.encode("latin-1")
 
-        found_indexes: "typing.List[int]" = []
+        found_indexes: list[int] = []
         for idx, (item_key, item_value) in enumerate(self._list):
             if item_key == set_key:
                 found_indexes.append(idx)
@@ -619,7 +619,7 @@ class MutableHeaders(Headers):
         """
         del_key = key.lower().encode("latin-1")
 
-        pop_indexes: "typing.List[int]" = []
+        pop_indexes: list[int] = []
         for idx, (item_key, item_value) in enumerate(self._list):
             if item_key == del_key:
                 pop_indexes.append(idx)
index e2a95e53fe77e5a9d60d1799f861df6ab69149ca..2e12c7faac8fe3837afffc331622e9a4efb0a919 100644 (file)
@@ -91,7 +91,7 @@ class FormParser:
         field_name = b""
         field_value = b""
 
-        items: list[tuple[str, typing.Union[str, UploadFile]]] = []
+        items: list[tuple[str, str | UploadFile]] = []
 
         # Feed the parser with data from the request.
         async for chunk in self.stream:
index 21f097434352e133a5fbc2b345ac76f247804dc4..966c639bb6360d3047f254f3c8a4142dada043cf 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import typing
 
 from starlette.authentication import (
@@ -16,9 +18,8 @@ class AuthenticationMiddleware:
         self,
         app: ASGIApp,
         backend: AuthenticationBackend,
-        on_error: typing.Optional[
-            typing.Callable[[HTTPConnection, AuthenticationError], Response]
-        ] = None,
+        on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response]
+        | None = None,
     ) -> None:
         self.app = app
         self.backend = backend
index ad3ffcfeefc0f20f7b95da0423c0d514ff0dbb86..4e5054d7a294223754b0c1f054d7f56d9248380e 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import typing
 
 import anyio
@@ -92,9 +94,7 @@ class _CachedRequest(Request):
 
 
 class BaseHTTPMiddleware:
-    def __init__(
-        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
-    ) -> None:
+    def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
         self.app = app
         self.dispatch_func = self.dispatch if dispatch is None else dispatch
 
@@ -108,7 +108,7 @@ class BaseHTTPMiddleware:
         response_sent = anyio.Event()
 
         async def call_next(request: Request) -> Response:
-            app_exc: typing.Optional[Exception] = None
+            app_exc: Exception | None = None
             send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
             recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
             send_stream, recv_stream = anyio.create_memory_object_stream()
@@ -203,10 +203,10 @@ class _StreamingResponse(StreamingResponse):
         self,
         content: ContentStream,
         status_code: int = 200,
-        headers: typing.Optional[typing.Mapping[str, str]] = None,
-        media_type: typing.Optional[str] = None,
-        background: typing.Optional[BackgroundTask] = None,
-        info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
+        headers: typing.Mapping[str, str] | None = None,
+        media_type: str | None = None,
+        background: BackgroundTask | None = None,
+        info: typing.Mapping[str, typing.Any] | None = None,
     ) -> None:
         self._info = info
         super().__init__(content, status_code, headers, media_type, background)
index 5c9bfa68401de824d0f1eaa9e976cd7c8c140406..4b8e97bc9dc844e9b36e31f06abd7840d2e58de0 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import functools
 import re
 import typing
@@ -18,7 +20,7 @@ class CORSMiddleware:
         allow_methods: typing.Sequence[str] = ("GET",),
         allow_headers: typing.Sequence[str] = (),
         allow_credentials: bool = False,
-        allow_origin_regex: typing.Optional[str] = None,
+        allow_origin_regex: str | None = None,
         expose_headers: typing.Sequence[str] = (),
         max_age: int = 600,
     ) -> None:
index c6336160ca79be4e919bcf958d71d094b1be7bbd..e9eba62b0bdf9ad0f2941e8eac420f80185da53e 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import html
 import inspect
 import traceback
@@ -137,9 +139,7 @@ class ServerErrorMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        handler: typing.Optional[
-            typing.Callable[[Request, Exception], typing.Any]
-        ] = None,
+        handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
         debug: bool = False,
     ) -> None:
         self.app = app
index 0124f5c8f3b3452b4ef671c83f7dfe258e5aa802..b2bf88dbfe7bc40492fd5bbce4d85de6c431d806 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import typing
 
 from starlette._exception_handler import (
@@ -16,9 +18,10 @@ class ExceptionMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        handlers: typing.Optional[
-            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
-        ] = None,
+        handlers: typing.Mapping[
+            typing.Any, typing.Callable[[Request, Exception], Response]
+        ]
+        | None = None,
         debug: bool = False,
     ) -> None:
         self.app = app
@@ -34,7 +37,7 @@ class ExceptionMiddleware:
 
     def add_exception_handler(
         self,
-        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
+        exc_class_or_status_code: int | type[Exception],
         handler: typing.Callable[[Request, Exception], Response],
     ) -> None:
         if isinstance(exc_class_or_status_code, int):
@@ -53,7 +56,7 @@ class ExceptionMiddleware:
             self._status_handlers,
         )
 
-        conn: typing.Union[Request, WebSocket]
+        conn: Request | WebSocket
         if scope["type"] == "http":
             conn = Request(scope, receive, send)
         else:
index 1093717b4354ace6bdf1a2a61c661dffce30721c..5855912cac339b1f57bfa1be310b2bf57128c508 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import typing
 from base64 import b64decode, b64encode
@@ -14,13 +16,13 @@ class SessionMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        secret_key: typing.Union[str, Secret],
+        secret_key: str | Secret,
         session_cookie: str = "session",
-        max_age: typing.Optional[int] = 14 * 24 * 60 * 60,  # 14 days, in seconds
+        max_age: int | None = 14 * 24 * 60 * 60,  # 14 days, in seconds
         path: str = "/",
         same_site: typing.Literal["lax", "strict", "none"] = "lax",
         https_only: bool = False,
-        domain: typing.Optional[str] = None,
+        domain: str | None = None,
     ) -> None:
         self.app = app
         self.signer = itsdangerous.TimestampSigner(str(secret_key))
index e84e6876a03a53c546c3d54d0799575da2caeac8..59e527363348272320b761950bfd3aadf0076c92 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import typing
 
 from starlette.datastructures import URL, Headers
@@ -11,7 +13,7 @@ class TrustedHostMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        allowed_hosts: typing.Optional[typing.Sequence[str]] = None,
+        allowed_hosts: typing.Sequence[str] | None = None,
         www_redirect: bool = True,
     ) -> None:
         if allowed_hosts is None:
index 2ce83b0740e7c0e9200485a9af294f1638df7cdd..c9a7e132814684916bc013aad65f8b826d0f685c 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import io
 import math
 import sys
@@ -16,7 +18,7 @@ warnings.warn(
 )
 
 
-def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]:
+def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]:
     """
     Builds a scope and request body into a WSGI environ object.
     """
@@ -117,7 +119,7 @@ class WSGIResponder:
     def start_response(
         self,
         status: str,
-        response_headers: typing.List[typing.Tuple[str, str]],
+        response_headers: list[tuple[str, str]],
         exc_info: typing.Any = None,
     ) -> None:
         self.exc_info = exc_info
@@ -140,7 +142,7 @@ class WSGIResponder:
 
     def wsgi(
         self,
-        environ: typing.Dict[str, typing.Any],
+        environ: dict[str, typing.Any],
         start_response: typing.Callable[..., typing.Any],
     ) -> None:
         for chunk in self.app(environ, start_response):
index 4af63bfc1ffbb2020c6288173734e9ffb1b7ea3e..b27e8e1e269f1bdbb43daf6d03ea0ca13e470d4a 100644 (file)
@@ -43,7 +43,7 @@ def cookie_parser(cookie_string: str) -> dict[str, str]:
     Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
     on an outdated spec and will fail on lots of input we want to support
     """
-    cookie_dict: typing.Dict[str, str] = {}
+    cookie_dict: dict[str, str] = {}
     for chunk in cookie_string.split(";"):
         if "=" in chunk:
             key, val = chunk.split("=", 1)
@@ -135,7 +135,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
     @property
     def cookies(self) -> dict[str, str]:
         if not hasattr(self, "_cookies"):
-            cookies: typing.Dict[str, str] = {}
+            cookies: dict[str, str] = {}
             cookie_header = self.headers.get("cookie")
 
             if cookie_header:
@@ -197,7 +197,7 @@ async def empty_send(message: Message) -> typing.NoReturn:
 
 
 class Request(HTTPConnection):
-    _form: typing.Optional[FormData]
+    _form: FormData | None
 
     def __init__(
         self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
@@ -240,7 +240,7 @@ class Request(HTTPConnection):
 
     async def body(self) -> bytes:
         if not hasattr(self, "_body"):
-            chunks: "typing.List[bytes]" = []
+            chunks: list[bytes] = []
             async for chunk in self.stream():
                 chunks.append(chunk)
             self._body = b"".join(chunks)
@@ -309,7 +309,7 @@ class Request(HTTPConnection):
 
     async def send_push_promise(self, path: str) -> None:
         if "http.response.push" in self.scope.get("extensions", {}):
-            raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
+            raw_headers: list[tuple[bytes, bytes]] = []
             for name in SERVER_PUSH_HEADERS_TO_COPY:
                 for value in self.headers.getlist(name):
                     raw_headers.append(
index b5467bb05c6008c39ddeb3a293488489f6b95cc5..92cdf2be8be6527806d7215aa6ec1f28d7ed5702 100644 (file)
@@ -57,9 +57,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
 
 
 def request_response(
-    func: typing.Callable[
-        [Request], typing.Union[typing.Awaitable[Response], Response]
-    ],
+    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
 ) -> ASGIApp:
     """
     Takes a function or coroutine `func(request) -> response`,
@@ -255,7 +253,7 @@ class Route(BaseRoute):
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
 
     def matches(self, scope: Scope) -> tuple[Match, Scope]:
-        path_params: "typing.Dict[str, typing.Any]"
+        path_params: dict[str, typing.Any]
         if scope["type"] == "http":
             route_path = get_route_path(scope)
             match = self.path_regex.match(route_path)
@@ -344,7 +342,7 @@ class WebSocketRoute(BaseRoute):
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
 
     def matches(self, scope: Scope) -> tuple[Match, Scope]:
-        path_params: "typing.Dict[str, typing.Any]"
+        path_params: dict[str, typing.Any]
         if scope["type"] == "websocket":
             route_path = get_route_path(scope)
             match = self.path_regex.match(route_path)
@@ -417,8 +415,8 @@ class Mount(BaseRoute):
     def routes(self) -> list[BaseRoute]:
         return getattr(self._base_app, "routes", [])
 
-    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
-        path_params: "typing.Dict[str, typing.Any]"
+    def matches(self, scope: Scope) -> tuple[Match, Scope]:
+        path_params: dict[str, typing.Any]
         if scope["type"] in ("http", "websocket"):
             root_path = scope.get("root_path", "")
             route_path = get_route_path(scope)
index fe31ab5ee4230cd1ade9a6f5052697dcb366a9a4..2dc3a5930d2a28a372eb7d8b642191038b427750 100644 (file)
@@ -129,7 +129,7 @@ class Jinja2Templates:
     def _setup_env_defaults(self, env: jinja2.Environment) -> None:
         @pass_context
         def url_for(
-            context: typing.Dict[str, typing.Any],
+            context: dict[str, typing.Any],
             name: str,
             /,
             **path_params: typing.Any,
index 955063fa179d5bb859e81f256a15540d35a6d43c..53ab5a70c8f1111cb06c56e927bfc2a1b21003b5 100644 (file)
@@ -17,7 +17,7 @@ class WebSocketState(enum.Enum):
 
 
 class WebSocketDisconnect(Exception):
-    def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
+    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
         self.code = code
         self.reason = reason or ""
 
@@ -95,7 +95,7 @@ class WebSocket(HTTPConnection):
                 self.application_state = WebSocketState.DISCONNECTED
             try:
                 await self._send(message)
-            except IOError:
+            except OSError:
                 self.application_state = WebSocketState.DISCONNECTED
                 raise WebSocketDisconnect(code=1006)
         elif self.application_state == WebSocketState.RESPONSE:
index 1a61664d17c58cbcc75a2992b4eb3f7b74a0c754..724ca65d3cda199de3d1f4dad8cc5d841b46a931 100644 (file)
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import functools
-from typing import Any, Callable, Dict, Literal
+from typing import Any, Callable, Literal
 
 import pytest
 
@@ -11,7 +13,7 @@ TestClientFactory = Callable[..., TestClient]
 @pytest.fixture
 def test_client_factory(
     anyio_backend_name: Literal["asyncio", "trio"],
-    anyio_backend_options: Dict[str, Any],
+    anyio_backend_options: dict[str, Any],
 ) -> TestClientFactory:
     # anyio_backend_name defined by:
     # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
index 6e5e42b94423bbbab8e708e410069fc8fb6ca398..2176404d825b9140c1a55d043737c3c31903755c 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextvars
 from contextlib import AsyncExitStack
 from typing import (
@@ -5,9 +7,6 @@ from typing import (
     AsyncGenerator,
     Callable,
     Generator,
-    List,
-    Type,
-    Union,
 )
 
 import anyio
@@ -241,7 +240,7 @@ class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
 )
 def test_contextvars(
     test_client_factory: TestClientFactory,
-    middleware_cls: Type[_MiddlewareClass[Any]],
+    middleware_cls: type[_MiddlewareClass[Any]],
 ) -> None:
     # this has to be an async endpoint because Starlette calls run_in_threadpool
     # on sync endpoints which has it's own set of peculiarities w.r.t propagating
@@ -318,7 +317,7 @@ async def test_run_background_tasks_even_if_client_disconnects() -> None:
 async def test_do_not_block_on_background_tasks() -> None:
     request_body_sent = False
     response_complete = anyio.Event()
-    events: List[Union[str, Message]] = []
+    events: list[str | Message] = []
 
     async def sleep_and_set() -> None:
         events.append("Background task started")
@@ -766,7 +765,7 @@ async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None
             call_next: RequestResponseEndpoint,
         ) -> Response:
             expected = b"1"
-            response: Union[Response, None] = None
+            response: Response | None = None
             async for chunk in request.stream():
                 assert chunk == expected
                 if expected == b"1":
@@ -783,7 +782,7 @@ async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None
         yield {"type": "http.request", "body": b"3"}
         await anyio.sleep(float("inf"))
 
-    sent: List[Message] = []
+    sent: list[Message] = []
 
     async def send(msg: Message) -> None:
         sent.append(msg)
@@ -1000,7 +999,7 @@ def test_pr_1519_comment_1236166180_example() -> None:
     """
     https://github.com/encode/starlette/pull/1519#issuecomment-1236166180
     """
-    bodies: List[bytes] = []
+    bodies: list[bytes] = []
 
     class LogRequestBodySize(BaseHTTPMiddleware):
         async def dispatch(
index 27b0337620d594a4d7ca3345c93ece12bbad0fc1..ecddda75ed52a7523327b3add7b900a51d20033e 100644 (file)
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import base64
 import binascii
-from typing import Any, Awaitable, Callable, Optional, Tuple
+from typing import Any, Awaitable, Callable
 from urllib.parse import urlencode
 
 import pytest
@@ -31,7 +33,7 @@ class BasicAuth(AuthenticationBackend):
     async def authenticate(
         self,
         request: HTTPConnection,
-    ) -> Optional[Tuple[AuthCredentials, SimpleUser]]:
+    ) -> tuple[AuthCredentials, SimpleUser] | None:
         if "Authorization" not in request.headers:
             return None
 
index 4f0cd430d35b7c625ee01b691029fc2a7a7bc1f1..ed2226878bec5daaa175a941c615c7924986a4b8 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import typing
 from contextlib import nullcontext as does_not_raise
@@ -29,7 +31,7 @@ FORCE_MULTIPART = ForceMultipartDict()
 async def app(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     data = await request.form()
-    output: typing.Dict[str, typing.Any] = {}
+    output: dict[str, typing.Any] = {}
     for key, value in data.items():
         if isinstance(value, UploadFile):
             content = await value.read()
@@ -49,7 +51,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
 async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     data = await request.form()
-    output: typing.Dict[str, typing.List[typing.Any]] = {}
+    output: dict[str, list[typing.Any]] = {}
     for key, value in data.multi_items():
         if key not in output:
             output[key] = []
@@ -73,7 +75,7 @@ async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None:
 async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None:
     request = Request(scope, receive)
     data = await request.form()
-    output: typing.Dict[str, typing.Any] = {}
+    output: dict[str, typing.Any] = {}
     for key, value in data.items():
         if isinstance(value, UploadFile):
             content = await value.read()
@@ -108,7 +110,7 @@ def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
         data = await request.form(max_files=max_files, max_fields=max_fields)
-        output: typing.Dict[str, typing.Any] = {}
+        output: dict[str, typing.Any] = {}
         for key, value in data.items():
             if isinstance(value, UploadFile):
                 content = await value.read()
index b3ce3a04add67609a026d5bf45904c69225e03fb..d8e2e94773868e0a2861f1480d3bc8f865561a07 100644 (file)
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import sys
-from typing import Any, Callable, Dict, Iterator, List, Optional
+from typing import Any, Callable, Iterator
 
 import anyio
 import pytest
@@ -72,7 +74,7 @@ def test_request_headers(test_client_factory: TestClientFactory) -> None:
         ({}, None),
     ],
 )
-def test_request_client(scope: Scope, expected_client: Optional[Address]) -> None:
+def test_request_client(scope: Scope, expected_client: Address | None) -> None:
     scope.update({"type": "http"})  # required by Request's constructor
     client = Request(scope).client
     assert client == expected_client
@@ -239,7 +241,7 @@ def test_request_without_setting_receive(
 
 def test_request_disconnect(
     anyio_backend_name: str,
-    anyio_backend_options: Dict[str, Any],
+    anyio_backend_options: dict[str, Any],
 ) -> None:
     """
     If a client disconnect occurs while reading request body
@@ -391,7 +393,7 @@ def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None:
 )
 def test_cookies_edge_cases(
     set_cookie: str,
-    expected: Dict[str, str],
+    expected: dict[str, str],
     test_client_factory: TestClientFactory,
 ) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
@@ -430,7 +432,7 @@ def test_cookies_edge_cases(
 )
 def test_cookies_invalid(
     set_cookie: str,
-    expected: Dict[str, str],
+    expected: dict[str, str],
     test_client_factory: TestClientFactory,
 ) -> None:
     """
@@ -542,7 +544,7 @@ def test_request_send_push_promise_without_setting_send(
     ],
 )
 @pytest.mark.anyio
-async def test_request_rcv(messages: List[Message]) -> None:
+async def test_request_rcv(messages: list[Message]) -> None:
     messages = messages.copy()
 
     async def rcv() -> Message:
@@ -557,7 +559,7 @@ async def test_request_rcv(messages: List[Message]) -> None:
 
 @pytest.mark.anyio
 async def test_request_stream_called_twice() -> None:
-    messages: List[Message] = [
+    messages: list[Message] = [
         {"type": "http.request", "body": b"1", "more_body": True},
         {"type": "http.request", "body": b"2", "more_body": True},
         {"type": "http.request", "body": b"3"},
index 57a594901815b9bb283affabe8af320b545cee9b..a3cdcadcf287dc7cafac292d0158407c2429e5a5 100644 (file)
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import datetime as dt
 import os
 import time
 from http.cookies import SimpleCookie
 from pathlib import Path
-from typing import AsyncIterator, Callable, Iterator, Union
+from typing import AsyncIterator, Callable, Iterator
 
 import anyio
 import pytest
@@ -160,7 +162,7 @@ def test_streaming_response_custom_iterable(
 ) -> None:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         class CustomAsyncIterable:
-            async def __aiter__(self) -> AsyncIterator[Union[str, bytes]]:
+            async def __aiter__(self) -> AsyncIterator[str | bytes]:
                 for i in range(5):
                     yield str(i + 1)
 
index 8c3f16639a9e3ebd427cd752ad06881361a4acb1..b75fc47f02bdde12380fc0d6a87bcb22427aa161 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import functools
 import json
@@ -762,7 +764,7 @@ def test_lifespan_state_unsupported(
     @contextlib.asynccontextmanager
     async def lifespan(
         app: ASGIApp,
-    ) -> typing.AsyncGenerator[typing.Dict[str, str], None]:
+    ) -> typing.AsyncGenerator[dict[str, str], None]:
         yield {"foo": "bar"}
 
     app = Router(
@@ -787,7 +789,7 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None
 
     class State(typing.TypedDict):
         count: int
-        items: typing.List[int]
+        items: list[int]
 
     async def hello_world(request: Request) -> Response:
         # modifications to the state should not leak across requests
index ab0b38a91000d7a949cd8baa602ce544aec7360e..95e392ed5a7e988c935b9d2dc594ab4f50b1266d 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import typing
 from pathlib import Path
@@ -46,7 +48,7 @@ def test_calls_context_processors(
     async def homepage(request: Request) -> Response:
         return templates.TemplateResponse(request, "index.html")
 
-    def hello_world_processor(request: Request) -> typing.Dict[str, str]:
+    def hello_world_processor(request: Request) -> dict[str, str]:
         return {"username": "World"}
 
     app = Starlette(
index c4b6c16bdbbd012ac7cbaa1f6a77cb09dc638740..854c269143de5266eb00db3ce0bf5ae7b6e611f7 100644 (file)
@@ -273,7 +273,7 @@ async def test_client_disconnect_on_send() -> None:
             return
         # Simulate the exception the server would send to the application when the
         # client disconnects.
-        raise IOError
+        raise OSError
 
     with pytest.raises(WebSocketDisconnect) as ctx:
         await app({"type": "websocket", "path": "/"}, receive, send)