]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Allow to pass any object as lifespan state
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Tue, 27 May 2025 13:05:43 +0000 (15:05 +0200)
committerMarcelo Trylesinski <marcelotryle@gmail.com>
Tue, 27 May 2025 13:05:43 +0000 (15:05 +0200)
docs/lifespan.md
starlette/requests.py
starlette/routing.py
starlette/testclient.py
starlette/types.py
starlette/websockets.py
tests/test_routing.py

index a5a766251bdf6c31bd4dfc94a5b4f3099d528d17..3c094c485a357a0223cb117bd844c59928a36047 100644 (file)
@@ -39,40 +39,46 @@ can be used to share the objects between the lifespan, and the requests.
 
 ```python
 import contextlib
-from typing import AsyncIterator, TypedDict
+from collections.abc import AsyncIterator
+from dataclasses import dataclass
 
 import httpx
+
 from starlette.applications import Starlette
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse
 from starlette.routing import Route
 
 
-class State(TypedDict):
-    http_client: httpx.AsyncClient
+@dataclass(frozen=True)
+class LifespanState:
+    client: httpx.AsyncClient
 
 
 @contextlib.asynccontextmanager
-async def lifespan(app: Starlette) -> AsyncIterator[State]:
+async def lifespan(app: Starlette) -> AsyncIterator[LifespanState]:
     async with httpx.AsyncClient() as client:
-        yield {"http_client": client}
+        yield LifespanState(client=client)
 
 
-async def homepage(request: Request) -> PlainTextResponse:
-    client = request.state.http_client
-    response = await client.get("https://www.example.com")
+async def homepage(request: Request[LifespanState]) -> PlainTextResponse:
+    client = request.state.client
+    response = await client.get("http://localhost:8001")
     return PlainTextResponse(response.text)
 
 
-app = Starlette(
-    lifespan=lifespan,
-    routes=[Route("/", homepage)]
-)
+app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)])
 ```
 
 The `state` received on the requests is a **shallow** copy of the state received on the
 lifespan handler.
 
+!!! warning
+    From version 0.46.3, the state object is not immutable by default.
+
+    As a user you should make sure the lifespan object is immutable if you don't want changes to
+    the lifespan state to be spread through requests.
+
 ## Running lifespan in tests
 
 You should use `TestClient` as a context manager, to ensure that the lifespan is called.
index 628358d15feb8cd83900f846b1656ea7998ff6a6..64c08477aba5e7c2e8776245f16714e869ce91b7 100644 (file)
@@ -3,9 +3,10 @@ from __future__ import annotations
 import json
 from collections.abc import AsyncGenerator, Iterator, Mapping
 from http import cookies as http_cookies
-from typing import TYPE_CHECKING, Any, NoReturn, cast
+from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast
 
 import anyio
+from typing_extensions import TypeVar
 
 from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
 from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
@@ -68,12 +69,17 @@ class ClientDisconnect(Exception):
     pass
 
 
-class HTTPConnection(Mapping[str, Any]):
+_LifespanStateT = TypeVar("_LifespanStateT", default=State)
+
+
+class HTTPConnection(Mapping[str, Any], Generic[_LifespanStateT]):
     """
     A base class for incoming HTTP connections, that is used to provide
     any functionality that is common to both `Request` and `WebSocket`.
     """
 
+    _state: _LifespanStateT
+
     def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
         assert scope["type"] in ("http", "websocket")
         self.scope = scope
@@ -171,13 +177,14 @@ class HTTPConnection(Mapping[str, Any]):
         return self.scope["user"]
 
     @property
-    def state(self) -> State:
+    def state(self) -> _LifespanStateT:
         if not hasattr(self, "_state"):
             # Ensure 'state' has an empty dict if it's not already populated.
             self.scope.setdefault("state", {})
-            # Create a state instance with a reference to the dict in which it should
-            # store info
-            self._state = State(self.scope["state"])
+            # for backwards compatibility, if the user didn't define a state, then it should create a default State().
+            self.scope["state"].setdefault("starlette.lifespan_state", State())
+            # Create a state instance with a reference to the dict in which it should store info
+            self._state = self.scope["state"]["starlette.lifespan_state"]
         return self._state
 
     def url_for(self, name: str, /, **path_params: Any) -> URL:
@@ -196,7 +203,7 @@ async def empty_send(message: Message) -> NoReturn:
     raise RuntimeError("Send channel has not been made available")
 
 
-class Request(HTTPConnection):
+class Request(HTTPConnection[_LifespanStateT]):
     _form: FormData | None
 
     def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
index 6eb57f487bf456b6563480f8ff9bc71f4d2ab6b6..2ce93c364a7175e8f46520df26ab8b5f73081195 100644 (file)
@@ -620,17 +620,13 @@ class Router:
                 "use an @contextlib.asynccontextmanager function instead",
                 DeprecationWarning,
             )
-            self.lifespan_context = asynccontextmanager(
-                lifespan,
-            )
+            self.lifespan_context = asynccontextmanager(lifespan)
         elif inspect.isgeneratorfunction(lifespan):
             warnings.warn(
                 "generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
                 DeprecationWarning,
             )
-            self.lifespan_context = _wrap_gen_lifespan_context(
-                lifespan,
-            )
+            self.lifespan_context = _wrap_gen_lifespan_context(lifespan)
         else:
             self.lifespan_context = lifespan
 
@@ -695,7 +691,7 @@ class Router:
                 if maybe_state is not None:
                     if "state" not in scope:
                         raise RuntimeError('The server does not support "state" in the lifespan scope.')
-                    scope["state"].update(maybe_state)
+                    scope["state"].update({"starlette.lifespan_state": maybe_state})
                 await send({"type": "lifespan.startup.complete"})
                 started = True
                 await receive()
index df8e113807bf155ef54fc559bb83f7222e4e5fff..2ce89fcd3f8967b23bdf51fe6519d5b94ad11a45 100644 (file)
@@ -11,14 +11,7 @@ from collections.abc import Awaitable, Generator, Iterable, Mapping, MutableMapp
 from concurrent.futures import Future
 from contextlib import AbstractContextManager
 from types import GeneratorType
-from typing import (
-    Any,
-    Callable,
-    Literal,
-    TypedDict,
-    Union,
-    cast,
-)
+from typing import Any, Callable, Literal, TypedDict, Union, cast
 from urllib.parse import unquote, urljoin
 
 import anyio
index e1f478d78741df2af89e0d839a40bb1412a51bd4..be16269a7c6bbfb7abdd8b017b3550942753f4d4 100644 (file)
@@ -1,4 +1,4 @@
-from collections.abc import Awaitable, Mapping, MutableMapping
+from collections.abc import Awaitable, MutableMapping
 from contextlib import AbstractAsyncContextManager
 from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
 
@@ -18,7 +18,7 @@ Send = Callable[[Message], Awaitable[None]]
 ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
 
 StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]]
-StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]]
+StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Any]]
 Lifespan = Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
 
 HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"]
index fb76361c8a2210dc6c494549b0c8c3b9c2e01a18..6f15daeb25094f5dd7254923d40a33bd3c9ffc74 100644 (file)
@@ -5,6 +5,9 @@ import json
 from collections.abc import AsyncIterator, Iterable
 from typing import Any, cast
 
+from typing_extensions import TypeVar
+
+from starlette.datastructures import State
 from starlette.requests import HTTPConnection
 from starlette.responses import Response
 from starlette.types import Message, Receive, Scope, Send
@@ -23,7 +26,10 @@ class WebSocketDisconnect(Exception):
         self.reason = reason or ""
 
 
-class WebSocket(HTTPConnection):
+_LifespanStateT = TypeVar("_LifespanStateT", default=State)
+
+
+class WebSocket(HTTPConnection[_LifespanStateT]):
     def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
         super().__init__(scope)
         assert scope["type"] == "websocket"
index 041aab103e389d749a619e735f4c8b919b89bbde..7c4930bc3054fc28fc449f2d1a622669d79ab15b 100644 (file)
@@ -757,14 +757,14 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None
         count: int
         items: list[int]
 
-    async def hello_world(request: Request) -> Response:
-        # modifications to the state should not leak across requests
-        assert request.state.count == 0
+    async def hello_world(request: Request[State]) -> Response:
+        # from version 0.46.3, the state object is only immutable if defined by the user
+        assert request.state["count"] in (0, 1)
         # modify the state, this should not leak to the lifespan or other requests
-        request.state.count += 1
+        request.state["count"] += 1
         # since state.items is a mutable object this modification _will_ leak across
         # requests and to the lifespan
-        request.state.items.append(1)
+        request.state["items"].append(1)
         return PlainTextResponse("hello, world")
 
     @contextlib.asynccontextmanager
@@ -774,8 +774,8 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None
         state = State(count=0, items=[])
         yield state
         shutdown_complete = True
-        # modifications made to the state from a request do not leak to the lifespan
-        assert state["count"] == 0
+        # from version 0.46.3, objects from the state are mutable if the state itself is mutable
+        assert state["count"] == 2
         # unless of course the request mutates a mutable object that is referenced
         # via state
         assert state["items"] == [1, 1]