From: Marcelo Trylesinski Date: Sat, 12 Jul 2025 06:55:12 +0000 (+0200) Subject: only state allowed X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=396bbdac66fc6340770fd4c47abd7e91023de50e;p=thirdparty%2Fstarlette.git only state allowed --- diff --git a/starlette/requests.py b/starlette/requests.py index 64c08477..0cce43e8 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -69,7 +69,7 @@ class ClientDisconnect(Exception): pass -_LifespanStateT = TypeVar("_LifespanStateT", default=State) +_LifespanStateT = TypeVar("_LifespanStateT", bound=State, default=State) class HTTPConnection(Mapping[str, Any], Generic[_LifespanStateT]): diff --git a/starlette/websockets.py b/starlette/websockets.py index 6f15daeb..04260595 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -26,7 +26,7 @@ class WebSocketDisconnect(Exception): self.reason = reason or "" -_LifespanStateT = TypeVar("_LifespanStateT", default=State) +_LifespanStateT = TypeVar("_LifespanStateT", bound=State, default=State) class WebSocket(HTTPConnection[_LifespanStateT]): diff --git a/tests/test_routing.py b/tests/test_routing.py index 7c4930bc..c2ac7bed 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -5,11 +5,12 @@ import functools import json import uuid from collections.abc import AsyncGenerator, AsyncIterator, Generator -from typing import Callable, TypedDict +from typing import Callable import pytest from starlette.applications import Starlette +from starlette.datastructures import State from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.requests import Request @@ -753,32 +754,32 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None startup_complete = False shutdown_complete = False - class State(TypedDict): + class LifespanState(State): count: int items: list[int] - async def hello_world(request: Request[State]) -> Response: + async def hello_world(request: Request[LifespanState]) -> Response: # from version 0.46.3, the state object is only immutable if defined by the user - assert request.state["count"] in (0, 1) + assert request.state.count in (0, 1) # modify the state, this should not leak to the lifespan or other requests - request.state["count"] += 1 + request.state.count += 1 # since state.items is a mutable object this modification _will_ leak across # requests and to the lifespan - request.state["items"].append(1) + request.state.items.append(1) return PlainTextResponse("hello, world") @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[State]: + async def lifespan(app: Starlette) -> AsyncIterator[LifespanState]: nonlocal startup_complete, shutdown_complete startup_complete = True - state = State(count=0, items=[]) + state = LifespanState({"count": 0, "items": []}) yield state shutdown_complete = True # from version 0.46.3, objects from the state are mutable if the state itself is mutable - assert state["count"] == 2 + assert state.count == 2 # unless of course the request mutates a mutable object that is referenced # via state - assert state["items"] == [1, 1] + assert state.items == [1, 1] app = Router( lifespan=lifespan,