]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Turn `State` into a `Mapping`
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 9 Oct 2025 08:47:19 +0000 (09:47 +0100)
committerMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 9 Oct 2025 08:47:19 +0000 (09:47 +0100)
starlette/datastructures.py
starlette/requests.py

index 38eabec52743cb04dde53b17fdc19fe074a12841..3a12527ab8aa30d212774c90d17e63ef332d0acb 100644 (file)
@@ -690,3 +690,18 @@ class State:
 
     def __delattr__(self, key: Any) -> None:
         del self._state[key]
+
+    def __getitem__(self, key: str) -> Any:
+        return self._state[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self._state[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self._state[key]
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._state)
+
+    def __len__(self) -> int:
+        return len(self._state)
index 99e8b9e193df5f6b7de2d4cb9bbf627ab4c6ecc9..041ef25a21cc63afd0b1475e2c3c79503b7471c6 100644 (file)
@@ -3,9 +3,10 @@ from __future__ import annotations
 import json
 from collections.abc import AsyncGenerator, Iterator, Mapping
 from http import cookies as http_cookies
-from typing import TYPE_CHECKING, Any, NoReturn, cast
+from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast
 
 import anyio
+from typing_extensions import TypeVar
 
 from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
 from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
@@ -68,7 +69,10 @@ class ClientDisconnect(Exception):
     pass
 
 
-class HTTPConnection(Mapping[str, Any]):
+StateT = TypeVar("StateT", bound=Mapping[str, Any] | State, default=State)
+
+
+class HTTPConnection(Mapping[str, Any], Generic[StateT]):
     """
     A base class for incoming HTTP connections, that is used to provide
     any functionality that is common to both `Request` and `WebSocket`.
@@ -172,14 +176,14 @@ class HTTPConnection(Mapping[str, Any]):
         return self.scope["user"]
 
     @property
-    def state(self) -> State:
+    def state(self) -> StateT:
         if not hasattr(self, "_state"):
             # Ensure 'state' has an empty dict if it's not already populated.
             self.scope.setdefault("state", {})
             # Create a state instance with a reference to the dict in which it should
             # store info
             self._state = State(self.scope["state"])
-        return self._state
+        return cast(StateT, self._state)
 
     def url_for(self, name: str, /, **path_params: Any) -> URL:
         url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
@@ -197,7 +201,7 @@ async def empty_send(message: Message) -> NoReturn:
     raise RuntimeError("Send channel has not been made available")
 
 
-class Request(HTTPConnection):
+class Request(HTTPConnection[StateT]):
     _form: FormData | None
 
     def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):