import email.message
import inspect
import json
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from typing import (
Any,
+ AsyncIterator,
Callable,
Coroutine,
Dict,
List,
+ Mapping,
Optional,
Sequence,
Set,
websocket_session,
)
from starlette.routing import Mount as Mount # noqa
-from starlette.types import ASGIApp, Lifespan, Scope
+from starlette.types import AppType, ASGIApp, Lifespan, Scope
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Doc, deprecated
return res
+def _merge_lifespan_context(
+ original_context: Lifespan[Any], nested_context: Lifespan[Any]
+) -> Lifespan[Any]:
+ @asynccontextmanager
+ async def merged_lifespan(
+ app: AppType,
+ ) -> AsyncIterator[Optional[Mapping[str, Any]]]:
+ async with original_context(app) as maybe_original_state:
+ async with nested_context(app) as maybe_nested_state:
+ if maybe_nested_state is None and maybe_original_state is None:
+ yield None # old ASGI compatibility
+ else:
+ yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
+
+ return merged_lifespan # type: ignore[return-value]
+
+
async def serialize_response(
*,
field: Optional[ModelField] = None,
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
+ self.lifespan_context = _merge_lifespan_context(
+ self.lifespan_context,
+ router.lifespan_context,
+ )
def get(
self,
from contextlib import asynccontextmanager
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator, Dict, Union
import pytest
-from fastapi import APIRouter, FastAPI
+from fastapi import APIRouter, FastAPI, Request
from fastapi.testclient import TestClient
from pydantic import BaseModel
assert response.json() == {"message": "Hello World"}
assert state.app_startup is True
assert state.app_shutdown is True
+
+
+def test_router_nested_lifespan_state(state: State) -> None:
+ @asynccontextmanager
+ async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+ state.app_startup = True
+ yield {"app": True}
+ state.app_shutdown = True
+
+ @asynccontextmanager
+ async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+ state.router_startup = True
+ yield {"router": True}
+ state.router_shutdown = True
+
+ @asynccontextmanager
+ async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+ state.sub_router_startup = True
+ yield {"sub_router": True}
+ state.sub_router_shutdown = True
+
+ sub_router = APIRouter(lifespan=subrouter_lifespan)
+
+ router = APIRouter(lifespan=router_lifespan)
+ router.include_router(sub_router)
+
+ app = FastAPI(lifespan=lifespan)
+ app.include_router(router)
+
+ @app.get("/")
+ def main(request: Request) -> Dict[str, str]:
+ assert request.state.app
+ assert request.state.router
+ assert request.state.sub_router
+ return {"message": "Hello World"}
+
+ assert state.app_startup is False
+ assert state.router_startup is False
+ assert state.sub_router_startup is False
+ assert state.app_shutdown is False
+ assert state.router_shutdown is False
+ assert state.sub_router_shutdown is False
+
+ with TestClient(app) as client:
+ assert state.app_startup is True
+ assert state.router_startup is True
+ assert state.sub_router_startup is True
+ assert state.app_shutdown is False
+ assert state.router_shutdown is False
+ assert state.sub_router_shutdown is False
+ response = client.get("/")
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "Hello World"}
+
+ assert state.app_startup is True
+ assert state.router_startup is True
+ assert state.sub_router_startup is True
+ assert state.app_shutdown is True
+ assert state.router_shutdown is True
+ assert state.sub_router_shutdown is True
+
+
+def test_router_nested_lifespan_state_overriding_by_parent() -> None:
+ @asynccontextmanager
+ async def lifespan(
+ app: FastAPI,
+ ) -> AsyncGenerator[Dict[str, Union[str, bool]], None]:
+ yield {
+ "app_specific": True,
+ "overridden": "app",
+ }
+
+ @asynccontextmanager
+ async def router_lifespan(
+ app: FastAPI,
+ ) -> AsyncGenerator[Dict[str, Union[str, bool]], None]:
+ yield {
+ "router_specific": True,
+ "overridden": "router", # should override parent
+ }
+
+ router = APIRouter(lifespan=router_lifespan)
+ app = FastAPI(lifespan=lifespan)
+ app.include_router(router)
+
+ with TestClient(app) as client:
+ assert client.app_state == {
+ "app_specific": True,
+ "router_specific": True,
+ "overridden": "app",
+ }
+
+
+def test_merged_no_return_lifespans_return_none() -> None:
+ @asynccontextmanager
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+ yield
+
+ @asynccontextmanager
+ async def router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+ yield
+
+ router = APIRouter(lifespan=router_lifespan)
+ app = FastAPI(lifespan=lifespan)
+ app.include_router(router)
+
+ with TestClient(app) as client:
+ assert not client.app_state
+
+
+def test_merged_mixed_state_lifespans() -> None:
+ @asynccontextmanager
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+ yield
+
+ @asynccontextmanager
+ async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+ yield {"router": True}
+
+ @asynccontextmanager
+ async def sub_router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+ yield
+
+ sub_router = APIRouter(lifespan=sub_router_lifespan)
+ router = APIRouter(lifespan=router_lifespan)
+ app = FastAPI(lifespan=lifespan)
+ router.include_router(sub_router)
+ app.include_router(router)
+
+ with TestClient(app) as client:
+ assert client.app_state == {"router": True}