From: Marcelo Trylesinski Date: Tue, 27 May 2025 13:05:43 +0000 (+0200) Subject: Allow to pass any object as lifespan state X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=544af5606fd138e23bed09d86eb7c5fc2f375529;p=thirdparty%2Fstarlette.git Allow to pass any object as lifespan state --- diff --git a/docs/lifespan.md b/docs/lifespan.md index a5a76625..3c094c48 100644 --- a/docs/lifespan.md +++ b/docs/lifespan.md @@ -39,40 +39,46 @@ can be used to share the objects between the lifespan, and the requests. ```python import contextlib -from typing import AsyncIterator, TypedDict +from collections.abc import AsyncIterator +from dataclasses import dataclass import httpx + from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -class State(TypedDict): - http_client: httpx.AsyncClient +@dataclass(frozen=True) +class LifespanState: + client: httpx.AsyncClient @contextlib.asynccontextmanager -async def lifespan(app: Starlette) -> AsyncIterator[State]: +async def lifespan(app: Starlette) -> AsyncIterator[LifespanState]: async with httpx.AsyncClient() as client: - yield {"http_client": client} + yield LifespanState(client=client) -async def homepage(request: Request) -> PlainTextResponse: - client = request.state.http_client - response = await client.get("https://www.example.com") +async def homepage(request: Request[LifespanState]) -> PlainTextResponse: + client = request.state.client + response = await client.get("http://localhost:8001") return PlainTextResponse(response.text) -app = Starlette( - lifespan=lifespan, - routes=[Route("/", homepage)] -) +app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)]) ``` The `state` received on the requests is a **shallow** copy of the state received on the lifespan handler. +!!! warning + From version 0.46.3, the state object is not immutable by default. + + As a user you should make sure the lifespan object is immutable if you don't want changes to + the lifespan state to be spread through requests. + ## Running lifespan in tests You should use `TestClient` as a context manager, to ensure that the lifespan is called. diff --git a/starlette/requests.py b/starlette/requests.py index 628358d1..64c08477 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -3,9 +3,10 @@ from __future__ import annotations import json from collections.abc import AsyncGenerator, Iterator, Mapping from http import cookies as http_cookies -from typing import TYPE_CHECKING, Any, NoReturn, cast +from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast import anyio +from typing_extensions import TypeVar from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State @@ -68,12 +69,17 @@ class ClientDisconnect(Exception): pass -class HTTPConnection(Mapping[str, Any]): +_LifespanStateT = TypeVar("_LifespanStateT", default=State) + + +class HTTPConnection(Mapping[str, Any], Generic[_LifespanStateT]): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. """ + _state: _LifespanStateT + def __init__(self, scope: Scope, receive: Receive | None = None) -> None: assert scope["type"] in ("http", "websocket") self.scope = scope @@ -171,13 +177,14 @@ class HTTPConnection(Mapping[str, Any]): return self.scope["user"] @property - def state(self) -> State: + def state(self) -> _LifespanStateT: if not hasattr(self, "_state"): # Ensure 'state' has an empty dict if it's not already populated. self.scope.setdefault("state", {}) - # Create a state instance with a reference to the dict in which it should - # store info - self._state = State(self.scope["state"]) + # for backwards compatibility, if the user didn't define a state, then it should create a default State(). + self.scope["state"].setdefault("starlette.lifespan_state", State()) + # Create a state instance with a reference to the dict in which it should store info + self._state = self.scope["state"]["starlette.lifespan_state"] return self._state def url_for(self, name: str, /, **path_params: Any) -> URL: @@ -196,7 +203,7 @@ async def empty_send(message: Message) -> NoReturn: raise RuntimeError("Send channel has not been made available") -class Request(HTTPConnection): +class Request(HTTPConnection[_LifespanStateT]): _form: FormData | None def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send): diff --git a/starlette/routing.py b/starlette/routing.py index 6eb57f48..2ce93c36 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -620,17 +620,13 @@ class Router: "use an @contextlib.asynccontextmanager function instead", DeprecationWarning, ) - self.lifespan_context = asynccontextmanager( - lifespan, - ) + self.lifespan_context = asynccontextmanager(lifespan) elif inspect.isgeneratorfunction(lifespan): warnings.warn( "generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead", DeprecationWarning, ) - self.lifespan_context = _wrap_gen_lifespan_context( - lifespan, - ) + self.lifespan_context = _wrap_gen_lifespan_context(lifespan) else: self.lifespan_context = lifespan @@ -695,7 +691,7 @@ class Router: if maybe_state is not None: if "state" not in scope: raise RuntimeError('The server does not support "state" in the lifespan scope.') - scope["state"].update(maybe_state) + scope["state"].update({"starlette.lifespan_state": maybe_state}) await send({"type": "lifespan.startup.complete"}) started = True await receive() diff --git a/starlette/testclient.py b/starlette/testclient.py index df8e1138..2ce89fcd 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -11,14 +11,7 @@ from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapp from concurrent.futures import Future from contextlib import AbstractContextManager from types import GeneratorType -from typing import ( - Any, - Callable, - Literal, - TypedDict, - Union, - cast, -) +from typing import Any, Callable, Literal, TypedDict, Union, cast from urllib.parse import unquote, urljoin import anyio diff --git a/starlette/types.py b/starlette/types.py index e1f478d7..be16269a 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,4 +1,4 @@ -from collections.abc import Awaitable, Mapping, MutableMapping +from collections.abc import Awaitable, MutableMapping from contextlib import AbstractAsyncContextManager from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union @@ -18,7 +18,7 @@ Send = Callable[[Message], Awaitable[None]] ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]] -StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]] +StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Any]] Lifespan = Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"] diff --git a/starlette/websockets.py b/starlette/websockets.py index fb76361c..6f15daeb 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -5,6 +5,9 @@ import json from collections.abc import AsyncIterator, Iterable from typing import Any, cast +from typing_extensions import TypeVar + +from starlette.datastructures import State from starlette.requests import HTTPConnection from starlette.responses import Response from starlette.types import Message, Receive, Scope, Send @@ -23,7 +26,10 @@ class WebSocketDisconnect(Exception): self.reason = reason or "" -class WebSocket(HTTPConnection): +_LifespanStateT = TypeVar("_LifespanStateT", default=State) + + +class WebSocket(HTTPConnection[_LifespanStateT]): def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: super().__init__(scope) assert scope["type"] == "websocket" diff --git a/tests/test_routing.py b/tests/test_routing.py index 041aab10..7c4930bc 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -757,14 +757,14 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None count: int items: list[int] - async def hello_world(request: Request) -> Response: - # modifications to the state should not leak across requests - assert request.state.count == 0 + async def hello_world(request: Request[State]) -> Response: + # from version 0.46.3, the state object is only immutable if defined by the user + assert request.state["count"] in (0, 1) # modify the state, this should not leak to the lifespan or other requests - request.state.count += 1 + request.state["count"] += 1 # since state.items is a mutable object this modification _will_ leak across # requests and to the lifespan - request.state.items.append(1) + request.state["items"].append(1) return PlainTextResponse("hello, world") @contextlib.asynccontextmanager @@ -774,8 +774,8 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None state = State(count=0, items=[]) yield state shutdown_complete = True - # modifications made to the state from a request do not leak to the lifespan - assert state["count"] == 0 + # from version 0.46.3, objects from the state are mutable if the state itself is mutable + assert state["count"] == 2 # unless of course the request mutates a mutable object that is referenced # via state assert state["items"] == [1, 1]