```python
import contextlib
-from typing import AsyncIterator, TypedDict
+from collections.abc import AsyncIterator
+from dataclasses import dataclass
import httpx
+
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
-class State(TypedDict):
- http_client: httpx.AsyncClient
+@dataclass(frozen=True)
+class LifespanState:
+ client: httpx.AsyncClient
@contextlib.asynccontextmanager
-async def lifespan(app: Starlette) -> AsyncIterator[State]:
+async def lifespan(app: Starlette) -> AsyncIterator[LifespanState]:
async with httpx.AsyncClient() as client:
- yield {"http_client": client}
+ yield LifespanState(client=client)
-async def homepage(request: Request) -> PlainTextResponse:
- client = request.state.http_client
- response = await client.get("https://www.example.com")
+async def homepage(request: Request[LifespanState]) -> PlainTextResponse:
+ client = request.state.client
+ response = await client.get("http://localhost:8001")
return PlainTextResponse(response.text)
-app = Starlette(
- lifespan=lifespan,
- routes=[Route("/", homepage)]
-)
+app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)])
```
The `state` received on the requests is a **shallow** copy of the state received on the
lifespan handler.
+!!! warning
+ From version 0.46.3, the state object is not immutable by default.
+
+ As a user you should make sure the lifespan object is immutable if you don't want changes to
+ the lifespan state to be spread through requests.
+
## Running lifespan in tests
You should use `TestClient` as a context manager, to ensure that the lifespan is called.
import json
from collections.abc import AsyncGenerator, Iterator, Mapping
from http import cookies as http_cookies
-from typing import TYPE_CHECKING, Any, NoReturn, cast
+from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast
import anyio
+from typing_extensions import TypeVar
from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
pass
-class HTTPConnection(Mapping[str, Any]):
+_LifespanStateT = TypeVar("_LifespanStateT", default=State)
+
+
+class HTTPConnection(Mapping[str, Any], Generic[_LifespanStateT]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
"""
+ _state: _LifespanStateT
+
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
assert scope["type"] in ("http", "websocket")
self.scope = scope
return self.scope["user"]
@property
- def state(self) -> State:
+ def state(self) -> _LifespanStateT:
if not hasattr(self, "_state"):
# Ensure 'state' has an empty dict if it's not already populated.
self.scope.setdefault("state", {})
- # Create a state instance with a reference to the dict in which it should
- # store info
- self._state = State(self.scope["state"])
+ # for backwards compatibility, if the user didn't define a state, then it should create a default State().
+ self.scope["state"].setdefault("starlette.lifespan_state", State())
+ # Create a state instance with a reference to the dict in which it should store info
+ self._state = self.scope["state"]["starlette.lifespan_state"]
return self._state
def url_for(self, name: str, /, **path_params: Any) -> URL:
raise RuntimeError("Send channel has not been made available")
-class Request(HTTPConnection):
+class Request(HTTPConnection[_LifespanStateT]):
_form: FormData | None
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
- self.lifespan_context = asynccontextmanager(
- lifespan,
- )
+ self.lifespan_context = asynccontextmanager(lifespan)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
- self.lifespan_context = _wrap_gen_lifespan_context(
- lifespan,
- )
+ self.lifespan_context = _wrap_gen_lifespan_context(lifespan)
else:
self.lifespan_context = lifespan
if maybe_state is not None:
if "state" not in scope:
raise RuntimeError('The server does not support "state" in the lifespan scope.')
- scope["state"].update(maybe_state)
+ scope["state"].update({"starlette.lifespan_state": maybe_state})
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
from concurrent.futures import Future
from contextlib import AbstractContextManager
from types import GeneratorType
-from typing import (
- Any,
- Callable,
- Literal,
- TypedDict,
- Union,
- cast,
-)
+from typing import Any, Callable, Literal, TypedDict, Union, cast
from urllib.parse import unquote, urljoin
import anyio
-from collections.abc import Awaitable, Mapping, MutableMapping
+from collections.abc import Awaitable, MutableMapping
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]]
-StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]]
+StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Any]]
Lifespan = Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"]
from collections.abc import AsyncIterator, Iterable
from typing import Any, cast
+from typing_extensions import TypeVar
+
+from starlette.datastructures import State
from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
self.reason = reason or ""
-class WebSocket(HTTPConnection):
+_LifespanStateT = TypeVar("_LifespanStateT", default=State)
+
+
+class WebSocket(HTTPConnection[_LifespanStateT]):
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
super().__init__(scope)
assert scope["type"] == "websocket"
count: int
items: list[int]
- async def hello_world(request: Request) -> Response:
- # modifications to the state should not leak across requests
- assert request.state.count == 0
+ async def hello_world(request: Request[State]) -> Response:
+ # from version 0.46.3, the state object is only immutable if defined by the user
+ assert request.state["count"] in (0, 1)
# modify the state, this should not leak to the lifespan or other requests
- request.state.count += 1
+ request.state["count"] += 1
# since state.items is a mutable object this modification _will_ leak across
# requests and to the lifespan
- request.state.items.append(1)
+ request.state["items"].append(1)
return PlainTextResponse("hello, world")
@contextlib.asynccontextmanager
state = State(count=0, items=[])
yield state
shutdown_complete = True
- # modifications made to the state from a request do not leak to the lifespan
- assert state["count"] == 0
+ # from version 0.46.3, objects from the state are mutable if the state itself is mutable
+ assert state["count"] == 2
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["items"] == [1, 1]