]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Add `iter_route_contexts()` for advanced use cases that used to use `router.routes...
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 18 Jun 2026 06:49:38 +0000 (08:49 +0200)
committerGitHub <noreply@github.com>
Thu, 18 Jun 2026 06:49:38 +0000 (08:49 +0200)
fastapi/openapi/utils.py
fastapi/routing.py
tests/test_router_include_context.py

index ab4543d346a6a4fff304fb870a3814d0531bd1ac..2e0aca118788b22d5556cd2e6f9df7d95423bce1 100644 (file)
@@ -479,26 +479,22 @@ def get_openapi_path(
 
 
 def _get_api_route_for_openapi(
-    route: BaseRoute, route_context: routing._EffectiveRouteContext | None
+    route_context: routing.RouteContext,
 ) -> routing._APIRouteLike | None:
-    if route_context is not None and isinstance(
-        route_context.original_route, routing.APIRoute
-    ):
+    if isinstance(route_context.original_route, routing.APIRoute):
         return cast(routing._APIRouteLike, route_context)
-    if isinstance(route, routing.APIRoute):
-        return cast(routing._APIRouteLike, route)
     return None
 
 
 def get_fields_from_routes(
-    routes: Sequence[BaseRoute],
+    routes: Sequence[BaseRoute | routing.RouteContext],
 ) -> list[ModelField]:
     body_fields_from_routes: list[ModelField] = []
     responses_from_routes: list[ModelField] = []
     request_fields_from_routes: list[ModelField] = []
     callback_flat_models: list[ModelField] = []
-    for route, route_context in routing._iter_routes_with_context(routes):
-        api_route = _get_api_route_for_openapi(route, route_context)
+    for route_context in routing.iter_route_contexts(routes):
+        api_route = _get_api_route_for_openapi(route_context)
         if api_route is None:
             continue
         if api_route.include_in_schema:
@@ -531,8 +527,8 @@ def get_openapi(
     openapi_version: str = "3.1.0",
     summary: str | None = None,
     description: str | None = None,
-    routes: Sequence[BaseRoute],
-    webhooks: Sequence[BaseRoute] | None = None,
+    routes: Sequence[BaseRoute | routing.RouteContext],
+    webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
     tags: list[dict[str, Any]] | None = None,
     servers: list[dict[str, str | Any]] | None = None,
     terms_of_service: str | None = None,
@@ -567,8 +563,8 @@ def get_openapi(
         model_name_map=model_name_map,
         separate_input_output_schemas=separate_input_output_schemas,
     )
-    for route, route_context in routing._iter_routes_with_context(routes):
-        api_route = _get_api_route_for_openapi(route, route_context)
+    for route_context in routing.iter_route_contexts(routes):
+        api_route = _get_api_route_for_openapi(route_context)
         if api_route is not None:
             result = get_openapi_path(
                 route=api_route,
@@ -587,8 +583,8 @@ def get_openapi(
                     )
                 if path_definitions:
                     definitions.update(path_definitions)
-    for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []):
-        api_webhook = _get_api_route_for_openapi(webhook, webhook_context)
+    for webhook_context in routing.iter_route_contexts(webhooks or []):
+        api_webhook = _get_api_route_for_openapi(webhook_context)
         if api_webhook is not None:
             result = get_openapi_path(
                 route=api_webhook,
index 48c0c21535ee561b38399bec321313bf7bebd83c..4a55fda8a8922c73a340f8b919eba19ff0013ed4 100644 (file)
@@ -1454,6 +1454,47 @@ class _EffectiveRouteContext:
         return URLPath(path=path, protocol="http")
 
 
+@dataclass(frozen=True)
+class RouteContext:
+    route: BaseRoute
+    _route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
+
+    @property
+    def original_route(self) -> BaseRoute:
+        if self._route_context is not None:
+            return self._route_context.original_route
+        return self.route
+
+    @property
+    def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
+        if self._route_context is not None:
+            return self._route_context
+        return self.route
+
+    @property
+    def path(self) -> str | None:
+        return getattr(self._effective_route, "path", None)
+
+    @property
+    def path_format(self) -> str | None:
+        return getattr(self._effective_route, "path_format", None)
+
+    @property
+    def name(self) -> str | None:
+        return getattr(self._effective_route, "name", None)
+
+    @property
+    def methods(self) -> set[str] | None:
+        return getattr(self._effective_route, "methods", None)
+
+    @property
+    def endpoint(self) -> Callable[..., Any] | None:
+        return getattr(self._effective_route, "endpoint", None)
+
+    def __getattr__(self, name: str) -> Any:
+        return getattr(self._effective_route, name)
+
+
 @dataclass
 class _IncludedRouter(BaseRoute):
     original_router: "APIRouter"
@@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas
             yield route
 
 
+def iter_route_contexts(
+    routes: Sequence[BaseRoute | RouteContext],
+) -> Iterator[RouteContext]:
+    for route in routes:
+        if isinstance(route, RouteContext):
+            yield route
+            continue
+        for original_route, route_context in _iter_routes_with_context([route]):
+            if route_context is None:
+                yield RouteContext(original_route)
+            else:
+                yield RouteContext(original_route, route_context)
+
+
 def _iter_routes_with_context(
     routes: Sequence[BaseRoute],
 ) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:
index c2679aa1176d77adb9250dc154752066bf77d336..cb8dc81fa913c2632f3571469772306c1141803f 100644 (file)
@@ -1,16 +1,21 @@
 from typing import Annotated, cast
 
 import pytest
-from fastapi import APIRouter, Body, Depends, FastAPI, Request
+from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security
 from fastapi.exceptions import FastAPIError
+from fastapi.openapi.utils import get_openapi
 from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
 from fastapi.routing import (
     APIRoute,
+    RouteContext,
     _IncludedRouter,
     _iter_included_route_candidates,
     _restore_fastapi_scope_key,
+    iter_route_contexts,
 )
+from fastapi.security import HTTPBearer
 from fastapi.testclient import TestClient
+from pydantic import BaseModel
 from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
 
 
@@ -30,6 +35,104 @@ def unique_id_b(route: APIRoute) -> str:
     return f"b_{route.name}"
 
 
+def test_iter_route_contexts_returns_direct_route_context():
+    router = APIRouter()
+
+    @router.get("/items/{item_id}")
+    def read_item(item_id: str):  # pragma: no cover
+        return {"item_id": item_id}
+
+    contexts = list(iter_route_contexts(router.routes))
+
+    assert len(contexts) == 1
+    assert isinstance(contexts[0], RouteContext)
+    assert contexts[0].original_route is router.routes[0]
+    assert contexts[0].path == "/items/{item_id}"
+    assert contexts[0].path_format == "/items/{item_id}"
+    assert contexts[0].methods == {"GET"}
+    assert contexts[0].endpoint is read_item
+
+
+def test_iter_route_contexts_supports_nested_conflict_detection():
+    existing_router = APIRouter()
+    nested_router = APIRouter()
+
+    @nested_router.get("/{username}")
+    def read_user(username: str):  # pragma: no cover
+        return {"username": username}
+
+    existing_router.include_router(nested_router, prefix="/auth/user")
+
+    new_router = APIRouter()
+
+    @new_router.get("/auth/user/{username}")
+    def read_user_again(username: str):  # pragma: no cover
+        return {"username": username}
+
+    existing_paths = {
+        context.path for context in iter_route_contexts(existing_router.routes)
+    }
+    new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
+
+    assert existing_paths & new_paths == {"/auth/user/{username}"}
+
+
+def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
+    router = APIRouter()
+    bearer_scheme = HTTPBearer()
+
+    @router.get("/public", tags=["public"])
+    def read_public(token: Annotated[str, Security(bearer_scheme)]):  # pragma: no cover
+        return {"public": True}
+
+    @router.get("/private", tags=["private"])
+    def read_private():  # pragma: no cover
+        return {"private": True}
+
+    app = FastAPI()
+    app.include_router(router, prefix="/api")
+
+    public_routes = [
+        context
+        for context in iter_route_contexts(app.routes)
+        if "public" in getattr(context, "tags", [])
+    ]
+    schema = get_openapi(
+        title="Public API",
+        version="1.0.0",
+        routes=public_routes,
+    )
+
+    assert set(schema["paths"]) == {"/api/public"}
+    assert "HTTPBearer" in schema["components"]["securitySchemes"]
+
+
+def test_get_openapi_accepts_webhook_route_contexts():
+    app = FastAPI()
+    bearer_scheme = HTTPBearer()
+
+    class Subscription(BaseModel):
+        username: str
+
+    @app.webhooks.post("new-subscription")
+    def new_subscription(
+        body: Subscription, token: Annotated[str, Security(bearer_scheme)]
+    ):  # pragma: no cover
+        return None
+
+    webhook_contexts = list(iter_route_contexts(app.webhooks.routes))
+    schema = get_openapi(
+        title="Webhook API",
+        version="1.0.0",
+        routes=[],
+        webhooks=webhook_contexts,
+    )
+
+    assert set(schema["webhooks"]) == {"new-subscription"}
+    assert "HTTPBearer" in schema["components"]["securitySchemes"]
+    assert "Subscription" in schema["components"]["schemas"]
+
+
 def test_router_include_context_matches_flattened_include_metadata():
     callback_router = APIRouter()