]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
only state allowed lifespan-state-user-friendly 2944/head
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sat, 12 Jul 2025 06:55:12 +0000 (08:55 +0200)
committerMarcelo Trylesinski <marcelotryle@gmail.com>
Sat, 12 Jul 2025 06:55:12 +0000 (08:55 +0200)
starlette/requests.py
starlette/websockets.py
tests/test_routing.py

index 64c08477aba5e7c2e8776245f16714e869ce91b7..0cce43e8621128431322049b13fe4f6fb527b302 100644 (file)
@@ -69,7 +69,7 @@ class ClientDisconnect(Exception):
     pass
 
 
-_LifespanStateT = TypeVar("_LifespanStateT", default=State)
+_LifespanStateT = TypeVar("_LifespanStateT", bound=State, default=State)
 
 
 class HTTPConnection(Mapping[str, Any], Generic[_LifespanStateT]):
index 6f15daeb25094f5dd7254923d40a33bd3c9ffc74..04260595ec97b26e1f5f926682e03e5aa7438ff3 100644 (file)
@@ -26,7 +26,7 @@ class WebSocketDisconnect(Exception):
         self.reason = reason or ""
 
 
-_LifespanStateT = TypeVar("_LifespanStateT", default=State)
+_LifespanStateT = TypeVar("_LifespanStateT", bound=State, default=State)
 
 
 class WebSocket(HTTPConnection[_LifespanStateT]):
index 7c4930bc3054fc28fc449f2d1a622669d79ab15b..c2ac7bede6b47416405d745277dace163f7eefa9 100644 (file)
@@ -5,11 +5,12 @@ import functools
 import json
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator, Generator
-from typing import Callable, TypedDict
+from typing import Callable
 
 import pytest
 
 from starlette.applications import Starlette
+from starlette.datastructures import State
 from starlette.exceptions import HTTPException
 from starlette.middleware import Middleware
 from starlette.requests import Request
@@ -753,32 +754,32 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None
     startup_complete = False
     shutdown_complete = False
 
-    class State(TypedDict):
+    class LifespanState(State):
         count: int
         items: list[int]
 
-    async def hello_world(request: Request[State]) -> Response:
+    async def hello_world(request: Request[LifespanState]) -> Response:
         # from version 0.46.3, the state object is only immutable if defined by the user
-        assert request.state["count"] in (0, 1)
+        assert request.state.count in (0, 1)
         # modify the state, this should not leak to the lifespan or other requests
-        request.state["count"] += 1
+        request.state.count += 1
         # since state.items is a mutable object this modification _will_ leak across
         # requests and to the lifespan
-        request.state["items"].append(1)
+        request.state.items.append(1)
         return PlainTextResponse("hello, world")
 
     @contextlib.asynccontextmanager
-    async def lifespan(app: Starlette) -> AsyncIterator[State]:
+    async def lifespan(app: Starlette) -> AsyncIterator[LifespanState]:
         nonlocal startup_complete, shutdown_complete
         startup_complete = True
-        state = State(count=0, items=[])
+        state = LifespanState({"count": 0, "items": []})
         yield state
         shutdown_complete = True
         # from version 0.46.3, objects from the state are mutable if the state itself is mutable
-        assert state["count"] == 2
+        assert state.count == 2
         # unless of course the request mutates a mutable object that is referenced
         # via state
-        assert state["items"] == [1, 1]
+        assert state.items == [1, 1]
 
     app = Router(
         lifespan=lifespan,