]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
♻️ Re-implement `on_event` in FastAPI for compatibility with the next Starlette,...
authorSebastián Ramírez <tiangolo@gmail.com>
Fri, 6 Feb 2026 15:18:30 +0000 (07:18 -0800)
committerGitHub <noreply@github.com>
Fri, 6 Feb 2026 15:18:30 +0000 (16:18 +0100)
fastapi/routing.py
tests/test_router_events.py

index 9ca2f46732c9eeef6430ca387dbacd581bc17d01..c95f624bdfea8d2d6ef3b8aa1ce09e429a1c0f3e 100644 (file)
@@ -1,22 +1,31 @@
+import contextlib
 import email.message
 import functools
 import inspect
 import json
+import types
 from collections.abc import (
     AsyncIterator,
     Awaitable,
     Collection,
     Coroutine,
+    Generator,
     Mapping,
     Sequence,
 )
-from contextlib import AsyncExitStack, asynccontextmanager
+from contextlib import (
+    AbstractAsyncContextManager,
+    AbstractContextManager,
+    AsyncExitStack,
+    asynccontextmanager,
+)
 from enum import Enum, IntEnum
 from typing import (
     Annotated,
     Any,
     Callable,
     Optional,
+    TypeVar,
     Union,
 )
 
@@ -143,6 +152,50 @@ def websocket_session(
     return app
 
 
+_T = TypeVar("_T")
+
+
+# Vendored from starlette.routing to avoid importing private symbols
+class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
+    """
+    Wraps a synchronous context manager to make it async.
+
+    This is vendored from Starlette to avoid importing private symbols.
+    """
+
+    def __init__(self, cm: AbstractContextManager[_T]) -> None:
+        self._cm = cm
+
+    async def __aenter__(self) -> _T:
+        return self._cm.__enter__()
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
+        return self._cm.__exit__(exc_type, exc_value, traceback)
+
+
+# Vendored from starlette.routing to avoid importing private symbols
+def _wrap_gen_lifespan_context(
+    lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
+) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
+    """
+    Wrap a generator-based lifespan context into an async context manager.
+
+    This is vendored from Starlette to avoid importing private symbols.
+    """
+    cmgr = contextlib.contextmanager(lifespan_context)
+
+    @functools.wraps(cmgr)
+    def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
+        return _AsyncLiftContextManager(cmgr(app))
+
+    return wrapper
+
+
 def _merge_lifespan_context(
     original_context: Lifespan[Any], nested_context: Lifespan[Any]
 ) -> Lifespan[Any]:
@@ -160,6 +213,30 @@ def _merge_lifespan_context(
     return merged_lifespan  # type: ignore[return-value]
 
 
+class _DefaultLifespan:
+    """
+    Default lifespan context manager that runs on_startup and on_shutdown handlers.
+
+    This is a copy of the Starlette _DefaultLifespan class that was removed
+    in Starlette. FastAPI keeps it to maintain backward compatibility with
+    on_startup and on_shutdown event handlers.
+
+    Ref: https://github.com/Kludex/starlette/pull/3117
+    """
+
+    def __init__(self, router: "APIRouter") -> None:
+        self._router = router
+
+    async def __aenter__(self) -> None:
+        await self._router._startup()
+
+    async def __aexit__(self, *exc_info: object) -> None:
+        await self._router._shutdown()
+
+    def __call__(self: _T, app: object) -> _T:
+        return self
+
+
 # Cache for endpoint context to avoid re-extracting on every request
 _endpoint_context_cache: dict[int, EndpointContext] = {}
 
@@ -903,13 +980,33 @@ class APIRouter(routing.Router):
             ),
         ] = Default(generate_unique_id),
     ) -> None:
+        # Handle on_startup/on_shutdown locally since Starlette removed support
+        # Ref: https://github.com/Kludex/starlette/pull/3117
+        # TODO: deprecate this once the lifespan (or alternative) interface is improved
+        self.on_startup: list[Callable[[], Any]] = (
+            [] if on_startup is None else list(on_startup)
+        )
+        self.on_shutdown: list[Callable[[], Any]] = (
+            [] if on_shutdown is None else list(on_shutdown)
+        )
+
+        # Determine the lifespan context to use
+        if lifespan is None:
+            # Use the default lifespan that runs on_startup/on_shutdown handlers
+            lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
+        elif inspect.isasyncgenfunction(lifespan):
+            lifespan_context = asynccontextmanager(lifespan)
+        elif inspect.isgeneratorfunction(lifespan):
+            lifespan_context = _wrap_gen_lifespan_context(lifespan)
+        else:
+            lifespan_context = lifespan
+        self.lifespan_context = lifespan_context
+
         super().__init__(
             routes=routes,
             redirect_slashes=redirect_slashes,
             default=default,
-            on_startup=on_startup,
-            on_shutdown=on_shutdown,
-            lifespan=lifespan,
+            lifespan=lifespan_context,
         )
         if prefix:
             assert prefix.startswith("/"), "A path prefix must start with '/'"
@@ -4473,6 +4570,58 @@ class APIRouter(routing.Router):
             generate_unique_id_function=generate_unique_id_function,
         )
 
+    # TODO: remove this once the lifespan (or alternative) interface is improved
+    async def _startup(self) -> None:
+        """
+        Run any `.on_startup` event handlers.
+
+        This method is kept for backward compatibility after Starlette removed
+        support for on_startup/on_shutdown handlers.
+
+        Ref: https://github.com/Kludex/starlette/pull/3117
+        """
+        for handler in self.on_startup:
+            if is_async_callable(handler):
+                await handler()
+            else:
+                handler()
+
+    # TODO: remove this once the lifespan (or alternative) interface is improved
+    async def _shutdown(self) -> None:
+        """
+        Run any `.on_shutdown` event handlers.
+
+        This method is kept for backward compatibility after Starlette removed
+        support for on_startup/on_shutdown handlers.
+
+        Ref: https://github.com/Kludex/starlette/pull/3117
+        """
+        for handler in self.on_shutdown:
+            if is_async_callable(handler):
+                await handler()
+            else:
+                handler()
+
+    # TODO: remove this once the lifespan (or alternative) interface is improved
+    def add_event_handler(
+        self,
+        event_type: str,
+        func: Callable[[], Any],
+    ) -> None:
+        """
+        Add an event handler function for startup or shutdown.
+
+        This method is kept for backward compatibility after Starlette removed
+        support for on_startup/on_shutdown handlers.
+
+        Ref: https://github.com/Kludex/starlette/pull/3117
+        """
+        assert event_type in ("startup", "shutdown")
+        if event_type == "startup":
+            self.on_startup.append(func)
+        else:
+            self.on_shutdown.append(func)
+
     @deprecated(
         """
         on_event is deprecated, use lifespan event handlers instead.
index 9df299cdaa6098d5211cd4f86a5d3f62825a6741..65f2f521c186e6fccfc5365265cdba8cbd62b9d1 100644 (file)
@@ -241,3 +241,79 @@ def test_merged_mixed_state_lifespans() -> None:
 
     with TestClient(app) as client:
         assert client.app_state == {"router": True}
+
+
+@pytest.mark.filterwarnings(
+    r"ignore:\s*on_event is deprecated, use lifespan event handlers instead.*:DeprecationWarning"
+)
+def test_router_async_shutdown_handler(state: State) -> None:
+    """Test that async on_shutdown event handlers are called correctly, for coverage."""
+    app = FastAPI()
+
+    @app.get("/")
+    def main() -> dict[str, str]:
+        return {"message": "Hello World"}
+
+    @app.on_event("shutdown")
+    async def app_shutdown() -> None:
+        state.app_shutdown = True
+
+    assert state.app_shutdown is False
+    with TestClient(app) as client:
+        assert state.app_shutdown is False
+        response = client.get("/")
+        assert response.status_code == 200, response.text
+    assert state.app_shutdown is True
+
+
+def test_router_sync_generator_lifespan(state: State) -> None:
+    """Test that a sync generator lifespan works via _wrap_gen_lifespan_context."""
+    from collections.abc import Generator
+
+    def lifespan(app: FastAPI) -> Generator[None, None, None]:
+        state.app_startup = True
+        yield
+        state.app_shutdown = True
+
+    app = FastAPI(lifespan=lifespan)  # type: ignore[arg-type]
+
+    @app.get("/")
+    def main() -> dict[str, str]:
+        return {"message": "Hello World"}
+
+    assert state.app_startup is False
+    assert state.app_shutdown is False
+    with TestClient(app) as client:
+        assert state.app_startup is True
+        assert state.app_shutdown is False
+        response = client.get("/")
+        assert response.status_code == 200, response.text
+        assert response.json() == {"message": "Hello World"}
+    assert state.app_startup is True
+    assert state.app_shutdown is True
+
+
+def test_router_async_generator_lifespan(state: State) -> None:
+    """Test that an async generator lifespan (not wrapped) works."""
+
+    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+        state.app_startup = True
+        yield
+        state.app_shutdown = True
+
+    app = FastAPI(lifespan=lifespan)  # type: ignore[arg-type]
+
+    @app.get("/")
+    def main() -> dict[str, str]:
+        return {"message": "Hello World"}
+
+    assert state.app_startup is False
+    assert state.app_shutdown is False
+    with TestClient(app) as client:
+        assert state.app_startup is True
+        assert state.app_shutdown is False
+        response = client.get("/")
+        assert response.status_code == 200, response.text
+        assert response.json() == {"message": "Hello World"}
+    assert state.app_startup is True
+    assert state.app_shutdown is True