def _get_api_route_for_openapi(
- route: BaseRoute, route_context: routing._EffectiveRouteContext | None
+ route_context: routing.RouteContext,
) -> routing._APIRouteLike | None:
- if route_context is not None and isinstance(
- route_context.original_route, routing.APIRoute
- ):
+ if isinstance(route_context.original_route, routing.APIRoute):
return cast(routing._APIRouteLike, route_context)
- if isinstance(route, routing.APIRoute):
- return cast(routing._APIRouteLike, route)
return None
def get_fields_from_routes(
- routes: Sequence[BaseRoute],
+ routes: Sequence[BaseRoute | routing.RouteContext],
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
- for route, route_context in routing._iter_routes_with_context(routes):
- api_route = _get_api_route_for_openapi(route, route_context)
+ for route_context in routing.iter_route_contexts(routes):
+ api_route = _get_api_route_for_openapi(route_context)
if api_route is None:
continue
if api_route.include_in_schema:
openapi_version: str = "3.1.0",
summary: str | None = None,
description: str | None = None,
- routes: Sequence[BaseRoute],
- webhooks: Sequence[BaseRoute] | None = None,
+ routes: Sequence[BaseRoute | routing.RouteContext],
+ webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
tags: list[dict[str, Any]] | None = None,
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
- for route, route_context in routing._iter_routes_with_context(routes):
- api_route = _get_api_route_for_openapi(route, route_context)
+ for route_context in routing.iter_route_contexts(routes):
+ api_route = _get_api_route_for_openapi(route_context)
if api_route is not None:
result = get_openapi_path(
route=api_route,
)
if path_definitions:
definitions.update(path_definitions)
- for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []):
- api_webhook = _get_api_route_for_openapi(webhook, webhook_context)
+ for webhook_context in routing.iter_route_contexts(webhooks or []):
+ api_webhook = _get_api_route_for_openapi(webhook_context)
if api_webhook is not None:
result = get_openapi_path(
route=api_webhook,
return URLPath(path=path, protocol="http")
+@dataclass(frozen=True)
+class RouteContext:
+ route: BaseRoute
+ _route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
+
+ @property
+ def original_route(self) -> BaseRoute:
+ if self._route_context is not None:
+ return self._route_context.original_route
+ return self.route
+
+ @property
+ def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
+ if self._route_context is not None:
+ return self._route_context
+ return self.route
+
+ @property
+ def path(self) -> str | None:
+ return getattr(self._effective_route, "path", None)
+
+ @property
+ def path_format(self) -> str | None:
+ return getattr(self._effective_route, "path_format", None)
+
+ @property
+ def name(self) -> str | None:
+ return getattr(self._effective_route, "name", None)
+
+ @property
+ def methods(self) -> set[str] | None:
+ return getattr(self._effective_route, "methods", None)
+
+ @property
+ def endpoint(self) -> Callable[..., Any] | None:
+ return getattr(self._effective_route, "endpoint", None)
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._effective_route, name)
+
+
@dataclass
class _IncludedRouter(BaseRoute):
original_router: "APIRouter"
yield route
+def iter_route_contexts(
+ routes: Sequence[BaseRoute | RouteContext],
+) -> Iterator[RouteContext]:
+ for route in routes:
+ if isinstance(route, RouteContext):
+ yield route
+ continue
+ for original_route, route_context in _iter_routes_with_context([route]):
+ if route_context is None:
+ yield RouteContext(original_route)
+ else:
+ yield RouteContext(original_route, route_context)
+
+
def _iter_routes_with_context(
routes: Sequence[BaseRoute],
) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:
from typing import Annotated, cast
import pytest
-from fastapi import APIRouter, Body, Depends, FastAPI, Request
+from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security
from fastapi.exceptions import FastAPIError
+from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.routing import (
APIRoute,
+ RouteContext,
_IncludedRouter,
_iter_included_route_candidates,
_restore_fastapi_scope_key,
+ iter_route_contexts,
)
+from fastapi.security import HTTPBearer
from fastapi.testclient import TestClient
+from pydantic import BaseModel
from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
return f"b_{route.name}"
+def test_iter_route_contexts_returns_direct_route_context():
+ router = APIRouter()
+
+ @router.get("/items/{item_id}")
+ def read_item(item_id: str): # pragma: no cover
+ return {"item_id": item_id}
+
+ contexts = list(iter_route_contexts(router.routes))
+
+ assert len(contexts) == 1
+ assert isinstance(contexts[0], RouteContext)
+ assert contexts[0].original_route is router.routes[0]
+ assert contexts[0].path == "/items/{item_id}"
+ assert contexts[0].path_format == "/items/{item_id}"
+ assert contexts[0].methods == {"GET"}
+ assert contexts[0].endpoint is read_item
+
+
+def test_iter_route_contexts_supports_nested_conflict_detection():
+ existing_router = APIRouter()
+ nested_router = APIRouter()
+
+ @nested_router.get("/{username}")
+ def read_user(username: str): # pragma: no cover
+ return {"username": username}
+
+ existing_router.include_router(nested_router, prefix="/auth/user")
+
+ new_router = APIRouter()
+
+ @new_router.get("/auth/user/{username}")
+ def read_user_again(username: str): # pragma: no cover
+ return {"username": username}
+
+ existing_paths = {
+ context.path for context in iter_route_contexts(existing_router.routes)
+ }
+ new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
+
+ assert existing_paths & new_paths == {"/auth/user/{username}"}
+
+
+def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
+ router = APIRouter()
+ bearer_scheme = HTTPBearer()
+
+ @router.get("/public", tags=["public"])
+ def read_public(token: Annotated[str, Security(bearer_scheme)]): # pragma: no cover
+ return {"public": True}
+
+ @router.get("/private", tags=["private"])
+ def read_private(): # pragma: no cover
+ return {"private": True}
+
+ app = FastAPI()
+ app.include_router(router, prefix="/api")
+
+ public_routes = [
+ context
+ for context in iter_route_contexts(app.routes)
+ if "public" in getattr(context, "tags", [])
+ ]
+ schema = get_openapi(
+ title="Public API",
+ version="1.0.0",
+ routes=public_routes,
+ )
+
+ assert set(schema["paths"]) == {"/api/public"}
+ assert "HTTPBearer" in schema["components"]["securitySchemes"]
+
+
+def test_get_openapi_accepts_webhook_route_contexts():
+ app = FastAPI()
+ bearer_scheme = HTTPBearer()
+
+ class Subscription(BaseModel):
+ username: str
+
+ @app.webhooks.post("new-subscription")
+ def new_subscription(
+ body: Subscription, token: Annotated[str, Security(bearer_scheme)]
+ ): # pragma: no cover
+ return None
+
+ webhook_contexts = list(iter_route_contexts(app.webhooks.routes))
+ schema = get_openapi(
+ title="Webhook API",
+ version="1.0.0",
+ routes=[],
+ webhooks=webhook_contexts,
+ )
+
+ assert set(schema["webhooks"]) == {"new-subscription"}
+ assert "HTTPBearer" in schema["components"]["securitySchemes"]
+ assert "Subscription" in schema["components"]["schemas"]
+
+
def test_router_include_context_matches_flattened_include_metadata():
callback_router = APIRouter()