]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
♻️ Refactor include_router to mount sub-routers
authorSebastián Ramírez <tiangolo@gmail.com>
Wed, 13 Apr 2022 11:01:42 +0000 (13:01 +0200)
committerSebastián Ramírez <tiangolo@gmail.com>
Wed, 13 Apr 2022 11:01:42 +0000 (13:01 +0200)
fastapi/applications.py
fastapi/openapi/utils.py
fastapi/routing.py
fastapi/utils.py
tests/test_custom_route_class.py

index 132a94c9a1c7cfdab758309d9f7edb57cce0fb9f..1f60d02855f21425c3298020623945eccb7fa8a2 100644 (file)
@@ -137,6 +137,10 @@ class FastAPI(Starlette):
         self.middleware_stack: ASGIApp = self.build_middleware_stack()
         self.setup()
 
+    @property
+    def routes(self) -> List[BaseRoute]:
+        return list(self.router.iter_all_routes())
+
     def build_middleware_stack(self) -> ASGIApp:
         # Duplicate/override from Starlette to add AsyncExitStackMiddleware
         # inside of ExceptionMiddleware, inside of custom user middlewares
index 58a748d04919facccfeed8e5f74a460d1633ba1b..bba0c4bcb4c06962dd37705df73524d240ce7664 100644 (file)
@@ -152,7 +152,7 @@ def generate_operation_id(
     )
     if route.operation_id:
         return route.operation_id
-    path: str = route.path_format
+    path: str = route._route_full_path_format
     return generate_operation_id_for_path(name=route.name, path=path, method=method)
 
 
@@ -243,7 +243,7 @@ def get_openapi_path(
                             model_name_map=model_name_map,
                             operation_ids=operation_ids,
                         )
-                        callbacks[callback.name] = {callback.path: cb_path}
+                        callbacks[callback.name] = {callback._route_full_path: cb_path}
                 operation["callbacks"] = callbacks
             if route.status_code is not None:
                 status_code = str(route.status_code)
@@ -422,7 +422,7 @@ def get_openapi(
             if result:
                 path, security_schemes, path_definitions = result
                 if path:
-                    paths.setdefault(route.path_format, {}).update(path)
+                    paths.setdefault(route._route_full_path_format, {}).update(path)
                 if security_schemes:
                     components.setdefault("securitySchemes", {}).update(
                         security_schemes
index 0f416ac42e1df52e716c4a7be862bd0271d2b23f..8f27656d8b8cbd2d6d7aec43bfbdcbab4027810e 100644 (file)
@@ -9,12 +9,14 @@ from typing import (
     Callable,
     Coroutine,
     Dict,
+    Iterator,
     List,
     Optional,
     Sequence,
     Set,
     Tuple,
     Type,
+    TypeVar,
     Union,
 )
 
@@ -57,6 +59,10 @@ from starlette.status import WS_1008_POLICY_VIOLATION
 from starlette.types import ASGIApp, Scope
 from starlette.websockets import WebSocket
 
+APIRouteType = TypeVar("APIRouteType", bound="APIRoute")
+APIRouterType = TypeVar("APIRouterType", bound="APIRouter")
+APIMountType = TypeVar("APIMountType", bound="APIMount")
+
 
 def _prepare_response_content(
     res: Any,
@@ -338,13 +344,13 @@ class APIRoute(routing.Route):
         generate_unique_id_function: Union[
             Callable[["APIRoute"], str], DefaultPlaceholder
         ] = Default(generate_unique_id),
+        router: Optional["APIRouter"] = None,
     ) -> None:
         self.path = path
         self.endpoint = endpoint
         self.response_model = response_model
         self.summary = summary
         self.response_description = response_description
-        self.deprecated = deprecated
         self.operation_id = operation_id
         self.response_model_include = response_model_include
         self.response_model_exclude = response_model_exclude
@@ -352,34 +358,128 @@ class APIRoute(routing.Route):
         self.response_model_exclude_unset = response_model_exclude_unset
         self.response_model_exclude_defaults = response_model_exclude_defaults
         self.response_model_exclude_none = response_model_exclude_none
-        self.include_in_schema = include_in_schema
-        self.response_class = response_class
         self.dependency_overrides_provider = dependency_overrides_provider
-        self.callbacks = callbacks
         self.openapi_extra = openapi_extra
-        self.generate_unique_id_function = generate_unique_id_function
-        self.tags = tags or []
-        self.responses = responses or {}
+        self.router = router
+
         self.name = get_name(endpoint) if name is None else name
-        self.path_regex, self.path_format, self.param_convertors = compile_path(path)
+        # normalize enums e.g. http.HTTPStatus
+        if isinstance(status_code, IntEnum):
+            status_code = int(status_code)
+        self.status_code = status_code
         if methods is None:
             methods = ["GET"]
         self.methods: Set[str] = set([method.upper() for method in methods])
-        if isinstance(generate_unique_id_function, DefaultPlaceholder):
-            current_generate_unique_id: Callable[
+
+        self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
+        # if a "form feed" character (page break) is found in the description text,
+        # truncate description text to the content preceding the first "form feed"
+        self.description = self.description.split("\f")[0]
+
+        assert callable(endpoint), "An endpoint must be a callable"
+
+        self.path_regex, self.path_format, self.param_convertors = compile_path(
+            self.path
+        )
+
+        # Attributes set in route used to compute resolved attributes
+        self._route_deprecated = deprecated
+        self._route_include_in_schema = include_in_schema
+        self._route_response_class = response_class
+        self._route_callbacks = callbacks
+        self._route_generate_unique_id_function = generate_unique_id_function
+        self._route_tags = tags or []
+        self._route_responses = responses or {}
+        if dependencies:
+            self._route_dependencies = dependencies
+        else:
+            self._route_dependencies = []
+
+        self.setup()
+
+    def setup(self) -> None:
+        # setup full path
+        self._route_full_path = self.path
+        if self.router:
+            self._route_full_path = self.router._router_full_path + self.path
+
+        # setup dependencies
+        self.dependencies: List[params.Depends] = []
+        if self.router:
+            self.dependencies.extend(self.router.dependencies)
+        self.dependencies.extend(self._route_dependencies)
+
+        # setup generate_unique_id
+        generate_unique_id_functions: List[
+            Union[Callable[[APIRoute], str], DefaultPlaceholder]
+        ] = [self._route_generate_unique_id_function]
+        if self.router:
+            generate_unique_id_functions.append(self.router.generate_unique_id_function)
+        current_generate_unique_id_function = get_value_or_default(
+            *generate_unique_id_functions
+        )
+        self.generate_unique_id_function: Union[
+            Callable[[APIRoute], str], DefaultPlaceholder
+        ] = current_generate_unique_id_function
+
+        # setup responses
+        responses: Dict[Union[int, str], Dict[str, Any]] = {}
+        if self.router:
+            responses.update(self.router.responses)
+        responses.update(self._route_responses)
+        self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
+
+        # setup default_response_class
+        default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
+            self._route_response_class
+        ]
+        if self.router:
+            default_response_classes.append(self.router.default_response_class)
+        current_default_response_class = get_value_or_default(*default_response_classes)
+        self.response_class: Union[
+            Type[Response], DefaultPlaceholder
+        ] = current_default_response_class
+
+        # setup tags
+        self.tags: List[Union[str, Enum]] = []
+        if self.router:
+            self.tags.extend(self.router.tags)
+        self.tags.extend(self._route_tags)
+
+        # setup callbacks
+        callbacks: List[BaseRoute] = []
+        if self.router:
+            callbacks.extend(self.router.callbacks)
+        if self._route_callbacks:
+            callbacks.extend(self._route_callbacks)
+        self.callbacks = callbacks
+
+        # setup deprecated
+        self.deprecated = self._route_deprecated
+        if self.router:
+            self.deprecated = self._route_deprecated or self.router.deprecated
+
+        # setup include_in_schema
+        self.include_in_schema = self._route_include_in_schema
+        if self.router:
+            self.include_in_schema = (
+                self._route_include_in_schema and self.router.include_in_schema
+            )
+
+        _, self._route_full_path_format, _ = compile_path(self._route_full_path)
+
+        if isinstance(self.generate_unique_id_function, DefaultPlaceholder):
+            resolved_generate_unique_id: Callable[
                 ["APIRoute"], str
-            ] = generate_unique_id_function.value
+            ] = self.generate_unique_id_function.value
         else:
-            current_generate_unique_id = generate_unique_id_function
-        self.unique_id = self.operation_id or current_generate_unique_id(self)
-        # normalize enums e.g. http.HTTPStatus
-        if isinstance(status_code, IntEnum):
-            status_code = int(status_code)
-        self.status_code = status_code
+            resolved_generate_unique_id = self.generate_unique_id_function
+        self.unique_id = self.operation_id or resolved_generate_unique_id(self)
+
         if self.response_model:
             assert (
-                status_code not in STATUS_CODES_WITH_NO_BODY
-            ), f"Status code {status_code} must not have a response body"
+                self.status_code not in STATUS_CODES_WITH_NO_BODY
+            ), f"Status code {self.status_code} must not have a response body"
             response_name = "Response_" + self.unique_id
             self.response_field = create_response_field(
                 name=response_name, type_=self.response_model
@@ -397,14 +497,7 @@ class APIRoute(routing.Route):
         else:
             self.response_field = None  # type: ignore
             self.secure_cloned_response_field = None
-        if dependencies:
-            self.dependencies = list(dependencies)
-        else:
-            self.dependencies = []
-        self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
-        # if a "form feed" character (page break) is found in the description text,
-        # truncate description text to the content preceding the first "form feed"
-        self.description = self.description.split("\f")[0]
+
         response_fields = {}
         for additional_status_code, response in self.responses.items():
             assert isinstance(response, dict), "An additional response must be a dict"
@@ -421,16 +514,50 @@ class APIRoute(routing.Route):
         else:
             self.response_fields = {}
 
-        assert callable(endpoint), "An endpoint must be a callable"
-        self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+        self.dependant = get_dependant(
+            path=self._route_full_path_format, call=self.endpoint
+        )
         for depends in self.dependencies[::-1]:
             self.dependant.dependencies.insert(
                 0,
-                get_parameterless_sub_dependant(depends=depends, path=self.path_format),
+                get_parameterless_sub_dependant(
+                    depends=depends, path=self._route_full_path_format
+                ),
             )
         self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
         self.app = request_response(self.get_route_handler())
 
+    def copy(self: APIRouteType) -> APIRouteType:
+        return type(self)(
+            path=self.path,
+            endpoint=self.endpoint,
+            response_model=self.response_model,
+            status_code=self.status_code,
+            tags=self._route_tags,
+            dependencies=self._route_dependencies,
+            summary=self.summary,
+            description=self.description,
+            response_description=self.response_description,
+            responses=self._route_responses,
+            deprecated=self._route_deprecated,
+            name=self.name,
+            methods=self.methods,
+            operation_id=self.operation_id,
+            response_model_include=self.response_model_include,
+            response_model_exclude=self.response_model_exclude,
+            response_model_by_alias=self.response_model_by_alias,
+            response_model_exclude_unset=self.response_model_exclude_unset,
+            response_model_exclude_defaults=self.response_model_exclude_defaults,
+            response_model_exclude_none=self.response_model_exclude_none,
+            include_in_schema=self._route_include_in_schema,
+            response_class=self._route_response_class,
+            dependency_overrides_provider=self.dependency_overrides_provider,
+            callbacks=self._route_callbacks,
+            openapi_extra=self.openapi_extra,
+            generate_unique_id_function=self._route_generate_unique_id_function,
+            router=self.router,
+        )
+
     def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
         return get_request_handler(
             dependant=self.dependant,
@@ -476,6 +603,7 @@ class APIRouter(routing.Router):
         generate_unique_id_function: Callable[[APIRoute], str] = Default(
             generate_unique_id
         ),
+        parent_router: Optional["APIRouter"] = None,
     ) -> None:
         super().__init__(
             routes=routes,  # type: ignore # in Starlette
@@ -490,16 +618,151 @@ class APIRouter(routing.Router):
                 "/"
             ), "A path prefix must not end with '/', as the routes will start with '/'"
         self.prefix = prefix
-        self.tags: List[Union[str, Enum]] = tags or []
-        self.dependencies = list(dependencies or []) or []
-        self.deprecated = deprecated
-        self.include_in_schema = include_in_schema
-        self.responses = responses or {}
-        self.callbacks = callbacks or []
         self.dependency_overrides_provider = dependency_overrides_provider
         self.route_class = route_class
-        self.default_response_class = default_response_class
-        self.generate_unique_id_function = generate_unique_id_function
+
+        self.parent_router = parent_router
+
+        # Attributes set in router used to compute resolved attributes
+        self._router_dependencies = list(dependencies or []) or []
+        self._router_generate_unique_id_function = generate_unique_id_function
+        self._router_responses = responses or {}
+        self._router_default_response_class = default_response_class
+        self._router_tags: List[Union[str, Enum]] = tags or []
+        self._router_callbacks = callbacks or []
+        self._router_deprecated = deprecated
+        self._router_include_in_schema = include_in_schema
+        self._router_has_empty_route = False
+        self._router_has_root_route = False
+        self.setup()
+
+    def setup(self) -> None:
+        # setup full path
+        self._router_full_path = self.prefix
+        if self.parent_router:
+            self._router_full_path = self.parent_router._router_full_path + self.prefix
+        # setup dependencies
+        self.dependencies: List[params.Depends] = []
+        if self.parent_router:
+            self.dependencies.extend(self.parent_router.dependencies)
+        self.dependencies.extend(self._router_dependencies)
+
+        # setup generate_unique_id
+        generate_unique_id_functions: List[
+            Union[Callable[[APIRoute], str], DefaultPlaceholder]
+        ] = [self._router_generate_unique_id_function]
+        if self.parent_router:
+            generate_unique_id_functions.append(
+                self.parent_router.generate_unique_id_function
+            )
+        current_generate_unique_id_function = get_value_or_default(
+            *generate_unique_id_functions
+        )
+        self.generate_unique_id_function: Union[
+            Callable[[APIRoute], str], DefaultPlaceholder
+        ] = current_generate_unique_id_function
+
+        # setup responses
+        responses: Dict[Union[int, str], Dict[str, Any]] = {}
+        if self.parent_router:
+            responses.update(self.parent_router.responses)
+        responses.update(self._router_responses)
+        self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
+
+        # setup default_response_class
+        default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
+            self._router_default_response_class
+        ]
+        if self.parent_router:
+            default_response_classes.append(self.parent_router.default_response_class)
+        current_default_response_class = get_value_or_default(*default_response_classes)
+        self.default_response_class: Union[
+            Type[Response], DefaultPlaceholder
+        ] = current_default_response_class
+
+        # setup tags
+        self.tags: List[Union[str, Enum]] = []
+        if self.parent_router:
+            self.tags.extend(self.parent_router.tags)
+        self.tags.extend(self._router_tags)
+
+        # setup callbacks
+        self.callbacks: List[BaseRoute] = []
+        if self.parent_router:
+            self.callbacks.extend(self.parent_router.callbacks)
+        self.callbacks.extend(self._router_callbacks)
+
+        # setup deprecated
+        self.deprecated = self._router_deprecated
+        if self.parent_router:
+            self.deprecated = self._router_deprecated or self.parent_router.deprecated
+
+        # setup include_in_schema
+        self.include_in_schema = self._router_include_in_schema
+        if self.parent_router:
+            self.include_in_schema = (
+                self._router_include_in_schema and self.parent_router.include_in_schema
+            )
+
+        # setup routes
+        for route in self.routes:
+            if isinstance(route, APIRoute):
+                route.router = self
+                route.setup()
+            elif isinstance(route, APIMount):
+                route.parent_router = self
+                route.setup()
+
+    def copy(self: APIRouterType) -> APIRouterType:
+        routes: List[routing.BaseRoute] = []
+        for route in self.routes:
+            if isinstance(route, APIRoute):
+                routes.append(route.copy())
+            elif isinstance(route, APIMount):
+                routes.append(route.copy())
+            else:
+                routes.append(route)
+        copied_router = type(self)(
+            prefix=self.prefix,
+            tags=self._router_tags,
+            dependencies=self._router_dependencies,
+            default_response_class=self._router_default_response_class,
+            responses=self._router_responses,
+            callbacks=self._router_callbacks,
+            routes=routes,
+            redirect_slashes=self.redirect_slashes,
+            default=self.default,
+            dependency_overrides_provider=self.dependency_overrides_provider,
+            route_class=self.route_class,
+            on_startup=self.on_startup,
+            on_shutdown=self.on_shutdown,
+            deprecated=self._router_deprecated,
+            include_in_schema=self._router_include_in_schema,
+            generate_unique_id_function=self._router_generate_unique_id_function,
+            parent_router=self.parent_router,
+        )
+        copied_router._router_has_empty_route = self._router_has_empty_route
+        copied_router._router_has_root_route = self._router_has_root_route
+        for route in copied_router.routes:
+            if isinstance(route, APIRoute):
+                route.router = copied_router
+                route.setup()
+            elif isinstance(route, Mount):
+                if isinstance(route.app, APIRouter):
+                    route.app.setup()
+        return copied_router
+
+    def iter_all_routes(self) -> Iterator[routing.BaseRoute]:
+        for route in self.routes:
+            if isinstance(route, Mount):
+                if isinstance(route.app, APIRouter):
+                    yield from route.app.iter_all_routes()
+            else:
+                yield route
+
+    def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None:
+        route = APIMount(router=router, name=name, parent_router=self)
+        self.routes.append(route)
 
     def add_api_route(
         self,
@@ -537,34 +800,18 @@ class APIRouter(routing.Router):
     ) -> None:
         route_class = route_class_override or self.route_class
         responses = responses or {}
-        combined_responses = {**self.responses, **responses}
-        current_response_class = get_value_or_default(
-            response_class, self.default_response_class
-        )
-        current_tags = self.tags.copy()
-        if tags:
-            current_tags.extend(tags)
-        current_dependencies = self.dependencies.copy()
-        if dependencies:
-            current_dependencies.extend(dependencies)
-        current_callbacks = self.callbacks.copy()
-        if callbacks:
-            current_callbacks.extend(callbacks)
-        current_generate_unique_id = get_value_or_default(
-            generate_unique_id_function, self.generate_unique_id_function
-        )
         route = route_class(
-            self.prefix + path,
+            path,
             endpoint=endpoint,
             response_model=response_model,
             status_code=status_code,
-            tags=current_tags,
-            dependencies=current_dependencies,
+            tags=tags,
+            dependencies=dependencies,
             summary=summary,
             description=description,
             response_description=response_description,
-            responses=combined_responses,
-            deprecated=deprecated or self.deprecated,
+            responses=responses,
+            deprecated=deprecated,
             methods=methods,
             operation_id=operation_id,
             response_model_include=response_model_include,
@@ -573,15 +820,20 @@ class APIRouter(routing.Router):
             response_model_exclude_unset=response_model_exclude_unset,
             response_model_exclude_defaults=response_model_exclude_defaults,
             response_model_exclude_none=response_model_exclude_none,
-            include_in_schema=include_in_schema and self.include_in_schema,
-            response_class=current_response_class,
+            include_in_schema=include_in_schema,
+            response_class=response_class,
             name=name,
             dependency_overrides_provider=self.dependency_overrides_provider,
-            callbacks=current_callbacks,
+            callbacks=callbacks,
             openapi_extra=openapi_extra,
-            generate_unique_id_function=current_generate_unique_id,
+            generate_unique_id_function=generate_unique_id_function,
+            router=self,
         )
         self.routes.append(route)
+        if not path:
+            self._router_has_empty_route = True
+        if path == "/":
+            self._router_has_root_route = True
 
     def api_route(
         self,
@@ -680,103 +932,197 @@ class APIRouter(routing.Router):
         generate_unique_id_function: Callable[[APIRoute], str] = Default(
             generate_unique_id
         ),
+        copy_flat_routes: Optional[bool] = None,
     ) -> None:
         if prefix:
             assert prefix.startswith("/"), "A path prefix must start with '/'"
             assert not prefix.endswith(
                 "/"
             ), "A path prefix must not end with '/', as the routes will start with '/'"
+        resolved_copy_flat_routes = copy_flat_routes
+        if resolved_copy_flat_routes is None:
+            resolved_copy_flat_routes = not (prefix or router.prefix)
+        if not resolved_copy_flat_routes:
+            included_router = router.copy()
+            if (
+                prefix
+                or tags
+                or dependencies
+                or not isinstance(default_response_class, DefaultPlaceholder)
+                or responses
+                or callbacks
+                or deprecated is not None
+                or include_in_schema is not True
+                or not isinstance(generate_unique_id_function, DefaultPlaceholder)
+            ):
+                current_router = type(self)(
+                    prefix=prefix,
+                    tags=tags,
+                    dependencies=dependencies,
+                    default_response_class=default_response_class,
+                    responses=responses,
+                    callbacks=callbacks,
+                    deprecated=deprecated,
+                    include_in_schema=include_in_schema,
+                    generate_unique_id_function=generate_unique_id_function,
+                    parent_router=self,
+                )
+                # current_router.api_mount(included_router)
+                current_router.include_router(included_router)
+                if included_router._router_has_empty_route and not self.prefix:
+                    current_router._router_has_empty_route = True
+                    current_router._router_has_root_route = (
+                        included_router._router_has_root_route
+                    )
+                self.api_mount(current_router)
+                included_router.parent_router = current_router
+            else:
+                self.api_mount(included_router)
+                included_router.parent_router = self
+
+            included_router.setup()
         else:
-            for r in router.routes:
-                path = getattr(r, "path")
-                name = getattr(r, "name", "unknown")
-                if path is not None and not path:
-                    raise Exception(
-                        f"Prefix and path cannot be both empty (path operation: {name})"
+            # TODO: remove this and its test, as a subrouter can mount another
+            # subrouter (done automatically of other things are overwritten) and both
+            # can omit a prefix, this would error out
+            # for r in router.routes:
+            #     path = getattr(r, "path")
+            #     name = getattr(r, "name", "unknown")
+            #     if path is not None and not path:
+            #         raise Exception(
+            #             f"Prefix and path cannot be both empty (path operation: {name})"
+            #         )
+            if responses is None:
+                responses = {}
+            for route in router.routes:
+                if isinstance(route, APIRoute):
+                    combined_responses = {}
+                    if route.router:
+                        combined_responses.update(route.router.responses)
+                    combined_responses.update(responses)
+                    combined_responses.update(route.responses)
+
+                    response_classes: List[
+                        Union[Type[Response], DefaultPlaceholder]
+                    ] = []
+                    if route.router:
+                        response_classes.append(route.router.default_response_class)
+                    response_classes.extend(
+                        [
+                            route.response_class,
+                            router.default_response_class,
+                            default_response_class,
+                            self.default_response_class,
+                        ]
                     )
-        if responses is None:
-            responses = {}
-        for route in router.routes:
-            if isinstance(route, APIRoute):
-                combined_responses = {**responses, **route.responses}
-                use_response_class = get_value_or_default(
-                    route.response_class,
-                    router.default_response_class,
-                    default_response_class,
-                    self.default_response_class,
-                )
-                current_tags = []
-                if tags:
-                    current_tags.extend(tags)
-                if route.tags:
-                    current_tags.extend(route.tags)
-                current_dependencies: List[params.Depends] = []
-                if dependencies:
-                    current_dependencies.extend(dependencies)
-                if route.dependencies:
-                    current_dependencies.extend(route.dependencies)
-                current_callbacks = []
-                if callbacks:
-                    current_callbacks.extend(callbacks)
-                if route.callbacks:
-                    current_callbacks.extend(route.callbacks)
-                current_generate_unique_id = get_value_or_default(
-                    route.generate_unique_id_function,
-                    router.generate_unique_id_function,
-                    generate_unique_id_function,
-                    self.generate_unique_id_function,
-                )
-                self.add_api_route(
-                    prefix + route.path,
-                    route.endpoint,
-                    response_model=route.response_model,
-                    status_code=route.status_code,
-                    tags=current_tags,
-                    dependencies=current_dependencies,
-                    summary=route.summary,
-                    description=route.description,
-                    response_description=route.response_description,
-                    responses=combined_responses,
-                    deprecated=route.deprecated or deprecated or self.deprecated,
-                    methods=route.methods,
-                    operation_id=route.operation_id,
-                    response_model_include=route.response_model_include,
-                    response_model_exclude=route.response_model_exclude,
-                    response_model_by_alias=route.response_model_by_alias,
-                    response_model_exclude_unset=route.response_model_exclude_unset,
-                    response_model_exclude_defaults=route.response_model_exclude_defaults,
-                    response_model_exclude_none=route.response_model_exclude_none,
-                    include_in_schema=route.include_in_schema
-                    and self.include_in_schema
-                    and include_in_schema,
-                    response_class=use_response_class,
-                    name=route.name,
-                    route_class_override=type(route),
-                    callbacks=current_callbacks,
-                    openapi_extra=route.openapi_extra,
-                    generate_unique_id_function=current_generate_unique_id,
-                )
-            elif isinstance(route, routing.Route):
-                methods = list(route.methods or [])  # type: ignore # in Starlette
-                self.add_route(
-                    prefix + route.path,
-                    route.endpoint,
-                    methods=methods,
-                    include_in_schema=route.include_in_schema,
-                    name=route.name,
-                )
-            elif isinstance(route, APIWebSocketRoute):
-                self.add_api_websocket_route(
-                    prefix + route.path, route.endpoint, name=route.name
-                )
-            elif isinstance(route, routing.WebSocketRoute):
-                self.add_websocket_route(
-                    prefix + route.path, route.endpoint, name=route.name
-                )
-        for handler in router.on_startup:
-            self.add_event_handler("startup", handler)
-        for handler in router.on_shutdown:
-            self.add_event_handler("shutdown", handler)
+                    use_response_class = get_value_or_default(*response_classes)
+                    current_tags = []
+                    if route.router:
+                        current_tags.extend(route.router.tags)
+                    if tags:
+                        current_tags.extend(tags)
+                    if route.tags:
+                        current_tags.extend(route.tags)
+                    current_dependencies: List[params.Depends] = []
+                    if route.router:
+                        current_dependencies.extend(route.router.dependencies)
+                    if dependencies:
+                        current_dependencies.extend(dependencies)
+                    if route.dependencies:
+                        current_dependencies.extend(route.dependencies)
+                    current_callbacks = []
+                    if route.router:
+                        current_callbacks.extend(route.router.callbacks)
+                    if callbacks:
+                        current_callbacks.extend(callbacks)
+                    if route.callbacks:
+                        current_callbacks.extend(route.callbacks)
+
+                    generate_unique_id_functions: List[
+                        Union[Callable[[APIRoute], str], DefaultPlaceholder]
+                    ] = []
+                    if route.router:
+                        generate_unique_id_functions.append(
+                            route.router.generate_unique_id_function
+                        )
+                    generate_unique_id_functions.extend(
+                        [
+                            route.generate_unique_id_function,
+                            router.generate_unique_id_function,
+                            generate_unique_id_function,
+                            self.generate_unique_id_function,
+                        ]
+                    )
+                    current_generate_unique_id_function = get_value_or_default(
+                        *generate_unique_id_functions
+                    )
+                    path = prefix + route.path
+                    if route.router:
+                        path = prefix + route.router.prefix + path
+                    self.add_api_route(
+                        path,
+                        route.endpoint,
+                        response_model=route.response_model,
+                        status_code=route.status_code,
+                        tags=current_tags,
+                        dependencies=current_dependencies,
+                        summary=route.summary,
+                        description=route.description,
+                        response_description=route.response_description,
+                        responses=combined_responses,
+                        deprecated=route.deprecated or deprecated or self.deprecated,
+                        methods=route.methods,
+                        operation_id=route.operation_id,
+                        response_model_include=route.response_model_include,
+                        response_model_exclude=route.response_model_exclude,
+                        response_model_by_alias=route.response_model_by_alias,
+                        response_model_exclude_unset=route.response_model_exclude_unset,
+                        response_model_exclude_defaults=route.response_model_exclude_defaults,
+                        response_model_exclude_none=route.response_model_exclude_none,
+                        include_in_schema=route.include_in_schema
+                        and self.include_in_schema
+                        and include_in_schema,
+                        response_class=use_response_class,
+                        name=route.name,
+                        route_class_override=type(route),
+                        callbacks=current_callbacks,
+                        openapi_extra=route.openapi_extra,
+                        generate_unique_id_function=current_generate_unique_id_function,
+                    )
+                elif isinstance(route, APIMount):
+                    self.include_router(
+                        route.app,
+                        prefix=prefix,
+                        tags=tags,
+                        dependencies=dependencies,
+                        default_response_class=default_response_class,
+                        responses=responses,
+                        callbacks=callbacks,
+                        deprecated=deprecated,
+                        include_in_schema=include_in_schema,
+                        generate_unique_id_function=generate_unique_id_function,
+                    )
+                elif isinstance(route, routing.Route):
+                    methods = list(route.methods or [])  # type: ignore # in Starlette
+                    self.add_route(
+                        prefix + route.path,
+                        route.endpoint,
+                        methods=methods,
+                        include_in_schema=route.include_in_schema,
+                        name=route.name,
+                    )
+                elif isinstance(route, APIWebSocketRoute):
+                    self.add_api_websocket_route(
+                        prefix + route.path, route.endpoint, name=route.name
+                    )
+                elif isinstance(route, routing.WebSocketRoute):
+                    self.add_websocket_route(
+                        prefix + route.path, route.endpoint, name=route.name
+                    )
+            for handler in router.on_startup:
+                self.add_event_handler("startup", handler)
+            for handler in router.on_shutdown:
+                self.add_event_handler("shutdown", handler)
 
     def get(
         self,
@@ -1226,3 +1572,100 @@ class APIRouter(routing.Router):
             openapi_extra=openapi_extra,
             generate_unique_id_function=generate_unique_id_function,
         )
+
+
+class APIMount(routing.Mount):
+    def __init__(
+        self,
+        router: APIRouter,
+        *,
+        name: Optional[str] = None,
+        parent_router: Optional[APIRouter] = None,
+    ) -> None:
+        self.name = name  # type: ignore # in Starlette
+        self.parent_router = parent_router
+        self.router = router
+
+        self.setup()
+
+    def setup(self) -> None:
+        self.app: APIRouter = self.router.copy()
+        if self.parent_router:
+            self.app.parent_router = self.parent_router
+            self.app.setup()
+        self.path = self.app.prefix
+        self.path_regex, self.path_format, self.param_convertors = compile_path(
+            self.path + "/{path:path}"
+        )
+
+        # Add custom additional root without trailing slash for compatibility with
+        # include_router and possibly app migrations
+        # Ref: https://github.com/tiangolo/fastapi/issues/414
+        (
+            self._root_path_regex,
+            self._root_path_format,
+            self._root_param_convertors,
+        ) = compile_path(self.path)
+        (
+            self._root_path_regex_trailing,
+            self._root_path_format_trailing,
+            self._root_param_convertors_trailing,
+        ) = compile_path(self.path + "/")
+
+    def copy(self: APIMountType) -> APIMountType:
+        return type(self)(
+            router=self.router.copy(),
+            name=self.name,
+            parent_router=self.parent_router,
+        )
+
+    def matches(self, scope: Scope) -> Tuple[Match, Scope]:
+        if scope["type"] in ("http", "websocket"):
+            path = scope["path"]
+            if self.app._router_has_empty_route:
+                # Custom logic to support paths without trailing slash
+                # Ref: https://github.com/tiangolo/fastapi/issues/414
+                # This mixes the code in
+                # starlette.routing.Route.matches() and starlette.routing.Mount.matches()
+                match = self._root_path_regex.match(path)
+                if match:
+                    matched_params = match.groupdict()
+                    for key, value in matched_params.items():
+                        matched_params[key] = self.param_convertors[key].convert(value)
+                    path_params = dict(scope.get("path_params", {}))
+                    path_params.update(matched_params)
+                    root_path = scope.get("root_path", "")
+                    child_scope = {
+                        "path_params": path_params,
+                        "app_root_path": scope.get("app_root_path", root_path),
+                        "root_path": root_path,
+                        "path": "",
+                        "endpoint": self.app,
+                    }
+                    return Match.FULL, child_scope
+                if not self.app._router_has_root_route:
+                    match = self._root_path_regex_trailing.match(path)
+                    if match:
+                        return Match.NONE, {}
+                # End of custom logic
+            # Duplicated code from Starlette
+            match = self.path_regex.match(path)
+            if match:
+                matched_params = match.groupdict()
+                for key, value in matched_params.items():
+                    matched_params[key] = self.param_convertors[key].convert(value)
+                remaining_path = "/" + matched_params.pop("path")
+                matched_path = path[: -len(remaining_path)]
+                path_params = dict(scope.get("path_params", {}))
+                path_params.update(matched_params)
+                root_path = scope.get("root_path", "")
+                child_scope = {
+                    "path_params": path_params,
+                    "app_root_path": scope.get("app_root_path", root_path),
+                    "root_path": root_path + matched_path,
+                    "path": remaining_path,
+                    "endpoint": self.app,
+                }
+                return Match.FULL, child_scope
+        return Match.NONE, {}
+        # End of duplicated code from Starlette
index b9301499a27a70f612fc62bb6b6b9fd102b9cd36..9f832a6d4072835cae4cb79817ada7f37a2f0500 100644 (file)
@@ -139,7 +139,7 @@ def generate_operation_id_for_path(
 
 
 def generate_unique_id(route: "APIRoute") -> str:
-    operation_id = route.name + route.path_format
+    operation_id = route.name + route._route_full_path_format
     operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
     assert route.methods
     operation_id = operation_id + "_" + list(route.methods)[0].lower()
index 1a9ea7199ad5431c44937900001e2691d4234aef..7d8b5f141d17ff7201b86d29d42d1fcfdb6365ee 100644 (file)
@@ -107,9 +107,9 @@ def test_get_path(path, expected_status, expected_response):
 
 def test_route_classes():
     routes = {}
-    for r in app.router.routes:
-        assert isinstance(r, Route)
-        routes[r.path] = r
+    for r in app.router.iter_all_routes():
+        if isinstance(r, APIRoute):
+            routes[r._route_full_path_format] = r
     assert getattr(routes["/a/"], "x_type") == "A"
     assert getattr(routes["/a/b/"], "x_type") == "B"
     assert getattr(routes["/a/b/c/"], "x_type") == "C"