+import contextlib
import email.message
import functools
import inspect
import json
+import types
from collections.abc import (
AsyncIterator,
Awaitable,
Collection,
Coroutine,
+ Generator,
Mapping,
Sequence,
)
-from contextlib import AsyncExitStack, asynccontextmanager
+from contextlib import (
+ AbstractAsyncContextManager,
+ AbstractContextManager,
+ AsyncExitStack,
+ asynccontextmanager,
+)
from enum import Enum, IntEnum
from typing import (
Annotated,
Any,
Callable,
Optional,
+ TypeVar,
Union,
)
return app
+_T = TypeVar("_T")
+
+
+# Vendored from starlette.routing to avoid importing private symbols
+class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
+ """
+ Wraps a synchronous context manager to make it async.
+
+ This is vendored from Starlette to avoid importing private symbols.
+ """
+
+ def __init__(self, cm: AbstractContextManager[_T]) -> None:
+ self._cm = cm
+
+ async def __aenter__(self) -> _T:
+ return self._cm.__enter__()
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> Optional[bool]:
+ return self._cm.__exit__(exc_type, exc_value, traceback)
+
+
+# Vendored from starlette.routing to avoid importing private symbols
+def _wrap_gen_lifespan_context(
+ lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
+) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
+ """
+ Wrap a generator-based lifespan context into an async context manager.
+
+ This is vendored from Starlette to avoid importing private symbols.
+ """
+ cmgr = contextlib.contextmanager(lifespan_context)
+
+ @functools.wraps(cmgr)
+ def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
+ return _AsyncLiftContextManager(cmgr(app))
+
+ return wrapper
+
+
def _merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
return merged_lifespan # type: ignore[return-value]
+class _DefaultLifespan:
+ """
+ Default lifespan context manager that runs on_startup and on_shutdown handlers.
+
+ This is a copy of the Starlette _DefaultLifespan class that was removed
+ in Starlette. FastAPI keeps it to maintain backward compatibility with
+ on_startup and on_shutdown event handlers.
+
+ Ref: https://github.com/Kludex/starlette/pull/3117
+ """
+
+ def __init__(self, router: "APIRouter") -> None:
+ self._router = router
+
+ async def __aenter__(self) -> None:
+ await self._router._startup()
+
+ async def __aexit__(self, *exc_info: object) -> None:
+ await self._router._shutdown()
+
+ def __call__(self: _T, app: object) -> _T:
+ return self
+
+
# Cache for endpoint context to avoid re-extracting on every request
_endpoint_context_cache: dict[int, EndpointContext] = {}
),
] = Default(generate_unique_id),
) -> None:
+ # Handle on_startup/on_shutdown locally since Starlette removed support
+ # Ref: https://github.com/Kludex/starlette/pull/3117
+ # TODO: deprecate this once the lifespan (or alternative) interface is improved
+ self.on_startup: list[Callable[[], Any]] = (
+ [] if on_startup is None else list(on_startup)
+ )
+ self.on_shutdown: list[Callable[[], Any]] = (
+ [] if on_shutdown is None else list(on_shutdown)
+ )
+
+ # Determine the lifespan context to use
+ if lifespan is None:
+ # Use the default lifespan that runs on_startup/on_shutdown handlers
+ lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
+ elif inspect.isasyncgenfunction(lifespan):
+ lifespan_context = asynccontextmanager(lifespan)
+ elif inspect.isgeneratorfunction(lifespan):
+ lifespan_context = _wrap_gen_lifespan_context(lifespan)
+ else:
+ lifespan_context = lifespan
+ self.lifespan_context = lifespan_context
+
super().__init__(
routes=routes,
redirect_slashes=redirect_slashes,
default=default,
- on_startup=on_startup,
- on_shutdown=on_shutdown,
- lifespan=lifespan,
+ lifespan=lifespan_context,
)
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
generate_unique_id_function=generate_unique_id_function,
)
+ # TODO: remove this once the lifespan (or alternative) interface is improved
+ async def _startup(self) -> None:
+ """
+ Run any `.on_startup` event handlers.
+
+ This method is kept for backward compatibility after Starlette removed
+ support for on_startup/on_shutdown handlers.
+
+ Ref: https://github.com/Kludex/starlette/pull/3117
+ """
+ for handler in self.on_startup:
+ if is_async_callable(handler):
+ await handler()
+ else:
+ handler()
+
+ # TODO: remove this once the lifespan (or alternative) interface is improved
+ async def _shutdown(self) -> None:
+ """
+ Run any `.on_shutdown` event handlers.
+
+ This method is kept for backward compatibility after Starlette removed
+ support for on_startup/on_shutdown handlers.
+
+ Ref: https://github.com/Kludex/starlette/pull/3117
+ """
+ for handler in self.on_shutdown:
+ if is_async_callable(handler):
+ await handler()
+ else:
+ handler()
+
+ # TODO: remove this once the lifespan (or alternative) interface is improved
+ def add_event_handler(
+ self,
+ event_type: str,
+ func: Callable[[], Any],
+ ) -> None:
+ """
+ Add an event handler function for startup or shutdown.
+
+ This method is kept for backward compatibility after Starlette removed
+ support for on_startup/on_shutdown handlers.
+
+ Ref: https://github.com/Kludex/starlette/pull/3117
+ """
+ assert event_type in ("startup", "shutdown")
+ if event_type == "startup":
+ self.on_startup.append(func)
+ else:
+ self.on_shutdown.append(func)
+
@deprecated(
"""
on_event is deprecated, use lifespan event handlers instead.
with TestClient(app) as client:
assert client.app_state == {"router": True}
+
+
+@pytest.mark.filterwarnings(
+ r"ignore:\s*on_event is deprecated, use lifespan event handlers instead.*:DeprecationWarning"
+)
+def test_router_async_shutdown_handler(state: State) -> None:
+ """Test that async on_shutdown event handlers are called correctly, for coverage."""
+ app = FastAPI()
+
+ @app.get("/")
+ def main() -> dict[str, str]:
+ return {"message": "Hello World"}
+
+ @app.on_event("shutdown")
+ async def app_shutdown() -> None:
+ state.app_shutdown = True
+
+ assert state.app_shutdown is False
+ with TestClient(app) as client:
+ assert state.app_shutdown is False
+ response = client.get("/")
+ assert response.status_code == 200, response.text
+ assert state.app_shutdown is True
+
+
+def test_router_sync_generator_lifespan(state: State) -> None:
+ """Test that a sync generator lifespan works via _wrap_gen_lifespan_context."""
+ from collections.abc import Generator
+
+ def lifespan(app: FastAPI) -> Generator[None, None, None]:
+ state.app_startup = True
+ yield
+ state.app_shutdown = True
+
+ app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
+
+ @app.get("/")
+ def main() -> dict[str, str]:
+ return {"message": "Hello World"}
+
+ assert state.app_startup is False
+ assert state.app_shutdown is False
+ with TestClient(app) as client:
+ assert state.app_startup is True
+ assert state.app_shutdown is False
+ response = client.get("/")
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "Hello World"}
+ assert state.app_startup is True
+ assert state.app_shutdown is True
+
+
+def test_router_async_generator_lifespan(state: State) -> None:
+ """Test that an async generator lifespan (not wrapped) works."""
+
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+ state.app_startup = True
+ yield
+ state.app_shutdown = True
+
+ app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
+
+ @app.get("/")
+ def main() -> dict[str, str]:
+ return {"message": "Hello World"}
+
+ assert state.app_startup is False
+ assert state.app_shutdown is False
+ with TestClient(app) as client:
+ assert state.app_startup is True
+ assert state.app_shutdown is False
+ response = client.get("/")
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "Hello World"}
+ assert state.app_startup is True
+ assert state.app_shutdown is True