]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Ensure that `app.include_router` merges nested lifespans (#9630)
authorPastukhov Nikita <diementros@yandex.ru>
Sat, 24 Aug 2024 19:09:52 +0000 (22:09 +0300)
committerGitHub <noreply@github.com>
Sat, 24 Aug 2024 19:09:52 +0000 (14:09 -0500)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/routing.py
tests/test_router_events.py

index 2e7959f3dfcc02f31493878e76b89017b575a4aa..49f1b60138f3de9a36567a155a6d2799e0249f67 100644 (file)
@@ -3,14 +3,16 @@ import dataclasses
 import email.message
 import inspect
 import json
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, asynccontextmanager
 from enum import Enum, IntEnum
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Coroutine,
     Dict,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -67,7 +69,7 @@ from starlette.routing import (
     websocket_session,
 )
 from starlette.routing import Mount as Mount  # noqa
-from starlette.types import ASGIApp, Lifespan, Scope
+from starlette.types import AppType, ASGIApp, Lifespan, Scope
 from starlette.websockets import WebSocket
 from typing_extensions import Annotated, Doc, deprecated
 
@@ -119,6 +121,23 @@ def _prepare_response_content(
     return res
 
 
+def _merge_lifespan_context(
+    original_context: Lifespan[Any], nested_context: Lifespan[Any]
+) -> Lifespan[Any]:
+    @asynccontextmanager
+    async def merged_lifespan(
+        app: AppType,
+    ) -> AsyncIterator[Optional[Mapping[str, Any]]]:
+        async with original_context(app) as maybe_original_state:
+            async with nested_context(app) as maybe_nested_state:
+                if maybe_nested_state is None and maybe_original_state is None:
+                    yield None  # old ASGI compatibility
+                else:
+                    yield {**(maybe_nested_state or {}), **(maybe_original_state or {})}
+
+    return merged_lifespan  # type: ignore[return-value]
+
+
 async def serialize_response(
     *,
     field: Optional[ModelField] = None,
@@ -1308,6 +1327,10 @@ class APIRouter(routing.Router):
             self.add_event_handler("startup", handler)
         for handler in router.on_shutdown:
             self.add_event_handler("shutdown", handler)
+        self.lifespan_context = _merge_lifespan_context(
+            self.lifespan_context,
+            router.lifespan_context,
+        )
 
     def get(
         self,
index 1b9de18aea26c7d333bee17ed6a1be493187f7bc..dd7ff3314b8777d70062036f16b42a4c6c9796d4 100644 (file)
@@ -1,8 +1,8 @@
 from contextlib import asynccontextmanager
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator, Dict, Union
 
 import pytest
-from fastapi import APIRouter, FastAPI
+from fastapi import APIRouter, FastAPI, Request
 from fastapi.testclient import TestClient
 from pydantic import BaseModel
 
@@ -109,3 +109,134 @@ def test_app_lifespan_state(state: State) -> None:
         assert response.json() == {"message": "Hello World"}
     assert state.app_startup is True
     assert state.app_shutdown is True
+
+
+def test_router_nested_lifespan_state(state: State) -> None:
+    @asynccontextmanager
+    async def lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+        state.app_startup = True
+        yield {"app": True}
+        state.app_shutdown = True
+
+    @asynccontextmanager
+    async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+        state.router_startup = True
+        yield {"router": True}
+        state.router_shutdown = True
+
+    @asynccontextmanager
+    async def subrouter_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+        state.sub_router_startup = True
+        yield {"sub_router": True}
+        state.sub_router_shutdown = True
+
+    sub_router = APIRouter(lifespan=subrouter_lifespan)
+
+    router = APIRouter(lifespan=router_lifespan)
+    router.include_router(sub_router)
+
+    app = FastAPI(lifespan=lifespan)
+    app.include_router(router)
+
+    @app.get("/")
+    def main(request: Request) -> Dict[str, str]:
+        assert request.state.app
+        assert request.state.router
+        assert request.state.sub_router
+        return {"message": "Hello World"}
+
+    assert state.app_startup is False
+    assert state.router_startup is False
+    assert state.sub_router_startup is False
+    assert state.app_shutdown is False
+    assert state.router_shutdown is False
+    assert state.sub_router_shutdown is False
+
+    with TestClient(app) as client:
+        assert state.app_startup is True
+        assert state.router_startup is True
+        assert state.sub_router_startup is True
+        assert state.app_shutdown is False
+        assert state.router_shutdown is False
+        assert state.sub_router_shutdown is False
+        response = client.get("/")
+        assert response.status_code == 200, response.text
+        assert response.json() == {"message": "Hello World"}
+
+    assert state.app_startup is True
+    assert state.router_startup is True
+    assert state.sub_router_startup is True
+    assert state.app_shutdown is True
+    assert state.router_shutdown is True
+    assert state.sub_router_shutdown is True
+
+
+def test_router_nested_lifespan_state_overriding_by_parent() -> None:
+    @asynccontextmanager
+    async def lifespan(
+        app: FastAPI,
+    ) -> AsyncGenerator[Dict[str, Union[str, bool]], None]:
+        yield {
+            "app_specific": True,
+            "overridden": "app",
+        }
+
+    @asynccontextmanager
+    async def router_lifespan(
+        app: FastAPI,
+    ) -> AsyncGenerator[Dict[str, Union[str, bool]], None]:
+        yield {
+            "router_specific": True,
+            "overridden": "router",  # should override parent
+        }
+
+    router = APIRouter(lifespan=router_lifespan)
+    app = FastAPI(lifespan=lifespan)
+    app.include_router(router)
+
+    with TestClient(app) as client:
+        assert client.app_state == {
+            "app_specific": True,
+            "router_specific": True,
+            "overridden": "app",
+        }
+
+
+def test_merged_no_return_lifespans_return_none() -> None:
+    @asynccontextmanager
+    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+        yield
+
+    @asynccontextmanager
+    async def router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+        yield
+
+    router = APIRouter(lifespan=router_lifespan)
+    app = FastAPI(lifespan=lifespan)
+    app.include_router(router)
+
+    with TestClient(app) as client:
+        assert not client.app_state
+
+
+def test_merged_mixed_state_lifespans() -> None:
+    @asynccontextmanager
+    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+        yield
+
+    @asynccontextmanager
+    async def router_lifespan(app: FastAPI) -> AsyncGenerator[Dict[str, bool], None]:
+        yield {"router": True}
+
+    @asynccontextmanager
+    async def sub_router_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+        yield
+
+    sub_router = APIRouter(lifespan=sub_router_lifespan)
+    router = APIRouter(lifespan=router_lifespan)
+    app = FastAPI(lifespan=lifespan)
+    router.include_router(sub_router)
+    app.include_router(router)
+
+    with TestClient(app) as client:
+        assert client.app_state == {"router": True}