Callable,
Coroutine,
Dict,
+ Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
+ TypeVar,
Union,
)
from starlette.types import ASGIApp, Scope
from starlette.websockets import WebSocket
+APIRouteType = TypeVar("APIRouteType", bound="APIRoute")
+APIRouterType = TypeVar("APIRouterType", bound="APIRouter")
+APIMountType = TypeVar("APIMountType", bound="APIMount")
+
def _prepare_response_content(
res: Any,
generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id),
+ router: Optional["APIRouter"] = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.response_model = response_model
self.summary = summary
self.response_description = response_description
- self.deprecated = deprecated
self.operation_id = operation_id
self.response_model_include = response_model_include
self.response_model_exclude = response_model_exclude
self.response_model_exclude_unset = response_model_exclude_unset
self.response_model_exclude_defaults = response_model_exclude_defaults
self.response_model_exclude_none = response_model_exclude_none
- self.include_in_schema = include_in_schema
- self.response_class = response_class
self.dependency_overrides_provider = dependency_overrides_provider
- self.callbacks = callbacks
self.openapi_extra = openapi_extra
- self.generate_unique_id_function = generate_unique_id_function
- self.tags = tags or []
- self.responses = responses or {}
+ self.router = router
+
self.name = get_name(endpoint) if name is None else name
- self.path_regex, self.path_format, self.param_convertors = compile_path(path)
+ # normalize enums e.g. http.HTTPStatus
+ if isinstance(status_code, IntEnum):
+ status_code = int(status_code)
+ self.status_code = status_code
if methods is None:
methods = ["GET"]
self.methods: Set[str] = set([method.upper() for method in methods])
- if isinstance(generate_unique_id_function, DefaultPlaceholder):
- current_generate_unique_id: Callable[
+
+ self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
+ # if a "form feed" character (page break) is found in the description text,
+ # truncate description text to the content preceding the first "form feed"
+ self.description = self.description.split("\f")[0]
+
+ assert callable(endpoint), "An endpoint must be a callable"
+
+ self.path_regex, self.path_format, self.param_convertors = compile_path(
+ self.path
+ )
+
+ # Attributes set in route used to compute resolved attributes
+ self._route_deprecated = deprecated
+ self._route_include_in_schema = include_in_schema
+ self._route_response_class = response_class
+ self._route_callbacks = callbacks
+ self._route_generate_unique_id_function = generate_unique_id_function
+ self._route_tags = tags or []
+ self._route_responses = responses or {}
+ if dependencies:
+ self._route_dependencies = dependencies
+ else:
+ self._route_dependencies = []
+
+ self.setup()
+
+ def setup(self) -> None:
+ # setup full path
+ self._route_full_path = self.path
+ if self.router:
+ self._route_full_path = self.router._router_full_path + self.path
+
+ # setup dependencies
+ self.dependencies: List[params.Depends] = []
+ if self.router:
+ self.dependencies.extend(self.router.dependencies)
+ self.dependencies.extend(self._route_dependencies)
+
+ # setup generate_unique_id
+ generate_unique_id_functions: List[
+ Union[Callable[[APIRoute], str], DefaultPlaceholder]
+ ] = [self._route_generate_unique_id_function]
+ if self.router:
+ generate_unique_id_functions.append(self.router.generate_unique_id_function)
+ current_generate_unique_id_function = get_value_or_default(
+ *generate_unique_id_functions
+ )
+ self.generate_unique_id_function: Union[
+ Callable[[APIRoute], str], DefaultPlaceholder
+ ] = current_generate_unique_id_function
+
+ # setup responses
+ responses: Dict[Union[int, str], Dict[str, Any]] = {}
+ if self.router:
+ responses.update(self.router.responses)
+ responses.update(self._route_responses)
+ self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
+
+ # setup default_response_class
+ default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
+ self._route_response_class
+ ]
+ if self.router:
+ default_response_classes.append(self.router.default_response_class)
+ current_default_response_class = get_value_or_default(*default_response_classes)
+ self.response_class: Union[
+ Type[Response], DefaultPlaceholder
+ ] = current_default_response_class
+
+ # setup tags
+ self.tags: List[Union[str, Enum]] = []
+ if self.router:
+ self.tags.extend(self.router.tags)
+ self.tags.extend(self._route_tags)
+
+ # setup callbacks
+ callbacks: List[BaseRoute] = []
+ if self.router:
+ callbacks.extend(self.router.callbacks)
+ if self._route_callbacks:
+ callbacks.extend(self._route_callbacks)
+ self.callbacks = callbacks
+
+ # setup deprecated
+ self.deprecated = self._route_deprecated
+ if self.router:
+ self.deprecated = self._route_deprecated or self.router.deprecated
+
+ # setup include_in_schema
+ self.include_in_schema = self._route_include_in_schema
+ if self.router:
+ self.include_in_schema = (
+ self._route_include_in_schema and self.router.include_in_schema
+ )
+
+ _, self._route_full_path_format, _ = compile_path(self._route_full_path)
+
+ if isinstance(self.generate_unique_id_function, DefaultPlaceholder):
+ resolved_generate_unique_id: Callable[
["APIRoute"], str
- ] = generate_unique_id_function.value
+ ] = self.generate_unique_id_function.value
else:
- current_generate_unique_id = generate_unique_id_function
- self.unique_id = self.operation_id or current_generate_unique_id(self)
- # normalize enums e.g. http.HTTPStatus
- if isinstance(status_code, IntEnum):
- status_code = int(status_code)
- self.status_code = status_code
+ resolved_generate_unique_id = self.generate_unique_id_function
+ self.unique_id = self.operation_id or resolved_generate_unique_id(self)
+
if self.response_model:
assert (
- status_code not in STATUS_CODES_WITH_NO_BODY
- ), f"Status code {status_code} must not have a response body"
+ self.status_code not in STATUS_CODES_WITH_NO_BODY
+ ), f"Status code {self.status_code} must not have a response body"
response_name = "Response_" + self.unique_id
self.response_field = create_response_field(
name=response_name, type_=self.response_model
else:
self.response_field = None # type: ignore
self.secure_cloned_response_field = None
- if dependencies:
- self.dependencies = list(dependencies)
- else:
- self.dependencies = []
- self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
- # if a "form feed" character (page break) is found in the description text,
- # truncate description text to the content preceding the first "form feed"
- self.description = self.description.split("\f")[0]
+
response_fields = {}
for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict"
else:
self.response_fields = {}
- assert callable(endpoint), "An endpoint must be a callable"
- self.dependant = get_dependant(path=self.path_format, call=self.endpoint)
+ self.dependant = get_dependant(
+ path=self._route_full_path_format, call=self.endpoint
+ )
for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert(
0,
- get_parameterless_sub_dependant(depends=depends, path=self.path_format),
+ get_parameterless_sub_dependant(
+ depends=depends, path=self._route_full_path_format
+ ),
)
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
self.app = request_response(self.get_route_handler())
+ def copy(self: APIRouteType) -> APIRouteType:
+ return type(self)(
+ path=self.path,
+ endpoint=self.endpoint,
+ response_model=self.response_model,
+ status_code=self.status_code,
+ tags=self._route_tags,
+ dependencies=self._route_dependencies,
+ summary=self.summary,
+ description=self.description,
+ response_description=self.response_description,
+ responses=self._route_responses,
+ deprecated=self._route_deprecated,
+ name=self.name,
+ methods=self.methods,
+ operation_id=self.operation_id,
+ response_model_include=self.response_model_include,
+ response_model_exclude=self.response_model_exclude,
+ response_model_by_alias=self.response_model_by_alias,
+ response_model_exclude_unset=self.response_model_exclude_unset,
+ response_model_exclude_defaults=self.response_model_exclude_defaults,
+ response_model_exclude_none=self.response_model_exclude_none,
+ include_in_schema=self._route_include_in_schema,
+ response_class=self._route_response_class,
+ dependency_overrides_provider=self.dependency_overrides_provider,
+ callbacks=self._route_callbacks,
+ openapi_extra=self.openapi_extra,
+ generate_unique_id_function=self._route_generate_unique_id_function,
+ router=self.router,
+ )
+
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler(
dependant=self.dependant,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
+ parent_router: Optional["APIRouter"] = None,
) -> None:
super().__init__(
routes=routes, # type: ignore # in Starlette
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix
- self.tags: List[Union[str, Enum]] = tags or []
- self.dependencies = list(dependencies or []) or []
- self.deprecated = deprecated
- self.include_in_schema = include_in_schema
- self.responses = responses or {}
- self.callbacks = callbacks or []
self.dependency_overrides_provider = dependency_overrides_provider
self.route_class = route_class
- self.default_response_class = default_response_class
- self.generate_unique_id_function = generate_unique_id_function
+
+ self.parent_router = parent_router
+
+ # Attributes set in router used to compute resolved attributes
+ self._router_dependencies = list(dependencies or []) or []
+ self._router_generate_unique_id_function = generate_unique_id_function
+ self._router_responses = responses or {}
+ self._router_default_response_class = default_response_class
+ self._router_tags: List[Union[str, Enum]] = tags or []
+ self._router_callbacks = callbacks or []
+ self._router_deprecated = deprecated
+ self._router_include_in_schema = include_in_schema
+ self._router_has_empty_route = False
+ self._router_has_root_route = False
+ self.setup()
+
+ def setup(self) -> None:
+ # setup full path
+ self._router_full_path = self.prefix
+ if self.parent_router:
+ self._router_full_path = self.parent_router._router_full_path + self.prefix
+ # setup dependencies
+ self.dependencies: List[params.Depends] = []
+ if self.parent_router:
+ self.dependencies.extend(self.parent_router.dependencies)
+ self.dependencies.extend(self._router_dependencies)
+
+ # setup generate_unique_id
+ generate_unique_id_functions: List[
+ Union[Callable[[APIRoute], str], DefaultPlaceholder]
+ ] = [self._router_generate_unique_id_function]
+ if self.parent_router:
+ generate_unique_id_functions.append(
+ self.parent_router.generate_unique_id_function
+ )
+ current_generate_unique_id_function = get_value_or_default(
+ *generate_unique_id_functions
+ )
+ self.generate_unique_id_function: Union[
+ Callable[[APIRoute], str], DefaultPlaceholder
+ ] = current_generate_unique_id_function
+
+ # setup responses
+ responses: Dict[Union[int, str], Dict[str, Any]] = {}
+ if self.parent_router:
+ responses.update(self.parent_router.responses)
+ responses.update(self._router_responses)
+ self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
+
+ # setup default_response_class
+ default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
+ self._router_default_response_class
+ ]
+ if self.parent_router:
+ default_response_classes.append(self.parent_router.default_response_class)
+ current_default_response_class = get_value_or_default(*default_response_classes)
+ self.default_response_class: Union[
+ Type[Response], DefaultPlaceholder
+ ] = current_default_response_class
+
+ # setup tags
+ self.tags: List[Union[str, Enum]] = []
+ if self.parent_router:
+ self.tags.extend(self.parent_router.tags)
+ self.tags.extend(self._router_tags)
+
+ # setup callbacks
+ self.callbacks: List[BaseRoute] = []
+ if self.parent_router:
+ self.callbacks.extend(self.parent_router.callbacks)
+ self.callbacks.extend(self._router_callbacks)
+
+ # setup deprecated
+ self.deprecated = self._router_deprecated
+ if self.parent_router:
+ self.deprecated = self._router_deprecated or self.parent_router.deprecated
+
+ # setup include_in_schema
+ self.include_in_schema = self._router_include_in_schema
+ if self.parent_router:
+ self.include_in_schema = (
+ self._router_include_in_schema and self.parent_router.include_in_schema
+ )
+
+ # setup routes
+ for route in self.routes:
+ if isinstance(route, APIRoute):
+ route.router = self
+ route.setup()
+ elif isinstance(route, APIMount):
+ route.parent_router = self
+ route.setup()
+
+ def copy(self: APIRouterType) -> APIRouterType:
+ routes: List[routing.BaseRoute] = []
+ for route in self.routes:
+ if isinstance(route, APIRoute):
+ routes.append(route.copy())
+ elif isinstance(route, APIMount):
+ routes.append(route.copy())
+ else:
+ routes.append(route)
+ copied_router = type(self)(
+ prefix=self.prefix,
+ tags=self._router_tags,
+ dependencies=self._router_dependencies,
+ default_response_class=self._router_default_response_class,
+ responses=self._router_responses,
+ callbacks=self._router_callbacks,
+ routes=routes,
+ redirect_slashes=self.redirect_slashes,
+ default=self.default,
+ dependency_overrides_provider=self.dependency_overrides_provider,
+ route_class=self.route_class,
+ on_startup=self.on_startup,
+ on_shutdown=self.on_shutdown,
+ deprecated=self._router_deprecated,
+ include_in_schema=self._router_include_in_schema,
+ generate_unique_id_function=self._router_generate_unique_id_function,
+ parent_router=self.parent_router,
+ )
+ copied_router._router_has_empty_route = self._router_has_empty_route
+ copied_router._router_has_root_route = self._router_has_root_route
+ for route in copied_router.routes:
+ if isinstance(route, APIRoute):
+ route.router = copied_router
+ route.setup()
+ elif isinstance(route, Mount):
+ if isinstance(route.app, APIRouter):
+ route.app.setup()
+ return copied_router
+
+ def iter_all_routes(self) -> Iterator[routing.BaseRoute]:
+ for route in self.routes:
+ if isinstance(route, Mount):
+ if isinstance(route.app, APIRouter):
+ yield from route.app.iter_all_routes()
+ else:
+ yield route
+
+ def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None:
+ route = APIMount(router=router, name=name, parent_router=self)
+ self.routes.append(route)
def add_api_route(
self,
) -> None:
route_class = route_class_override or self.route_class
responses = responses or {}
- combined_responses = {**self.responses, **responses}
- current_response_class = get_value_or_default(
- response_class, self.default_response_class
- )
- current_tags = self.tags.copy()
- if tags:
- current_tags.extend(tags)
- current_dependencies = self.dependencies.copy()
- if dependencies:
- current_dependencies.extend(dependencies)
- current_callbacks = self.callbacks.copy()
- if callbacks:
- current_callbacks.extend(callbacks)
- current_generate_unique_id = get_value_or_default(
- generate_unique_id_function, self.generate_unique_id_function
- )
route = route_class(
- self.prefix + path,
+ path,
endpoint=endpoint,
response_model=response_model,
status_code=status_code,
- tags=current_tags,
- dependencies=current_dependencies,
+ tags=tags,
+ dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
- responses=combined_responses,
- deprecated=deprecated or self.deprecated,
+ responses=responses,
+ deprecated=deprecated,
methods=methods,
operation_id=operation_id,
response_model_include=response_model_include,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
- include_in_schema=include_in_schema and self.include_in_schema,
- response_class=current_response_class,
+ include_in_schema=include_in_schema,
+ response_class=response_class,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
- callbacks=current_callbacks,
+ callbacks=callbacks,
openapi_extra=openapi_extra,
- generate_unique_id_function=current_generate_unique_id,
+ generate_unique_id_function=generate_unique_id_function,
+ router=self,
)
self.routes.append(route)
+ if not path:
+ self._router_has_empty_route = True
+ if path == "/":
+ self._router_has_root_route = True
def api_route(
self,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
+ copy_flat_routes: Optional[bool] = None,
) -> None:
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"
+ resolved_copy_flat_routes = copy_flat_routes
+ if resolved_copy_flat_routes is None:
+ resolved_copy_flat_routes = not (prefix or router.prefix)
+ if not resolved_copy_flat_routes:
+ included_router = router.copy()
+ if (
+ prefix
+ or tags
+ or dependencies
+ or not isinstance(default_response_class, DefaultPlaceholder)
+ or responses
+ or callbacks
+ or deprecated is not None
+ or include_in_schema is not True
+ or not isinstance(generate_unique_id_function, DefaultPlaceholder)
+ ):
+ current_router = type(self)(
+ prefix=prefix,
+ tags=tags,
+ dependencies=dependencies,
+ default_response_class=default_response_class,
+ responses=responses,
+ callbacks=callbacks,
+ deprecated=deprecated,
+ include_in_schema=include_in_schema,
+ generate_unique_id_function=generate_unique_id_function,
+ parent_router=self,
+ )
+ # current_router.api_mount(included_router)
+ current_router.include_router(included_router)
+ if included_router._router_has_empty_route and not self.prefix:
+ current_router._router_has_empty_route = True
+ current_router._router_has_root_route = (
+ included_router._router_has_root_route
+ )
+ self.api_mount(current_router)
+ included_router.parent_router = current_router
+ else:
+ self.api_mount(included_router)
+ included_router.parent_router = self
+
+ included_router.setup()
else:
- for r in router.routes:
- path = getattr(r, "path")
- name = getattr(r, "name", "unknown")
- if path is not None and not path:
- raise Exception(
- f"Prefix and path cannot be both empty (path operation: {name})"
+ # TODO: remove this and its test, as a subrouter can mount another
+ # subrouter (done automatically of other things are overwritten) and both
+ # can omit a prefix, this would error out
+ # for r in router.routes:
+ # path = getattr(r, "path")
+ # name = getattr(r, "name", "unknown")
+ # if path is not None and not path:
+ # raise Exception(
+ # f"Prefix and path cannot be both empty (path operation: {name})"
+ # )
+ if responses is None:
+ responses = {}
+ for route in router.routes:
+ if isinstance(route, APIRoute):
+ combined_responses = {}
+ if route.router:
+ combined_responses.update(route.router.responses)
+ combined_responses.update(responses)
+ combined_responses.update(route.responses)
+
+ response_classes: List[
+ Union[Type[Response], DefaultPlaceholder]
+ ] = []
+ if route.router:
+ response_classes.append(route.router.default_response_class)
+ response_classes.extend(
+ [
+ route.response_class,
+ router.default_response_class,
+ default_response_class,
+ self.default_response_class,
+ ]
)
- if responses is None:
- responses = {}
- for route in router.routes:
- if isinstance(route, APIRoute):
- combined_responses = {**responses, **route.responses}
- use_response_class = get_value_or_default(
- route.response_class,
- router.default_response_class,
- default_response_class,
- self.default_response_class,
- )
- current_tags = []
- if tags:
- current_tags.extend(tags)
- if route.tags:
- current_tags.extend(route.tags)
- current_dependencies: List[params.Depends] = []
- if dependencies:
- current_dependencies.extend(dependencies)
- if route.dependencies:
- current_dependencies.extend(route.dependencies)
- current_callbacks = []
- if callbacks:
- current_callbacks.extend(callbacks)
- if route.callbacks:
- current_callbacks.extend(route.callbacks)
- current_generate_unique_id = get_value_or_default(
- route.generate_unique_id_function,
- router.generate_unique_id_function,
- generate_unique_id_function,
- self.generate_unique_id_function,
- )
- self.add_api_route(
- prefix + route.path,
- route.endpoint,
- response_model=route.response_model,
- status_code=route.status_code,
- tags=current_tags,
- dependencies=current_dependencies,
- summary=route.summary,
- description=route.description,
- response_description=route.response_description,
- responses=combined_responses,
- deprecated=route.deprecated or deprecated or self.deprecated,
- methods=route.methods,
- operation_id=route.operation_id,
- response_model_include=route.response_model_include,
- response_model_exclude=route.response_model_exclude,
- response_model_by_alias=route.response_model_by_alias,
- response_model_exclude_unset=route.response_model_exclude_unset,
- response_model_exclude_defaults=route.response_model_exclude_defaults,
- response_model_exclude_none=route.response_model_exclude_none,
- include_in_schema=route.include_in_schema
- and self.include_in_schema
- and include_in_schema,
- response_class=use_response_class,
- name=route.name,
- route_class_override=type(route),
- callbacks=current_callbacks,
- openapi_extra=route.openapi_extra,
- generate_unique_id_function=current_generate_unique_id,
- )
- elif isinstance(route, routing.Route):
- methods = list(route.methods or []) # type: ignore # in Starlette
- self.add_route(
- prefix + route.path,
- route.endpoint,
- methods=methods,
- include_in_schema=route.include_in_schema,
- name=route.name,
- )
- elif isinstance(route, APIWebSocketRoute):
- self.add_api_websocket_route(
- prefix + route.path, route.endpoint, name=route.name
- )
- elif isinstance(route, routing.WebSocketRoute):
- self.add_websocket_route(
- prefix + route.path, route.endpoint, name=route.name
- )
- for handler in router.on_startup:
- self.add_event_handler("startup", handler)
- for handler in router.on_shutdown:
- self.add_event_handler("shutdown", handler)
+ use_response_class = get_value_or_default(*response_classes)
+ current_tags = []
+ if route.router:
+ current_tags.extend(route.router.tags)
+ if tags:
+ current_tags.extend(tags)
+ if route.tags:
+ current_tags.extend(route.tags)
+ current_dependencies: List[params.Depends] = []
+ if route.router:
+ current_dependencies.extend(route.router.dependencies)
+ if dependencies:
+ current_dependencies.extend(dependencies)
+ if route.dependencies:
+ current_dependencies.extend(route.dependencies)
+ current_callbacks = []
+ if route.router:
+ current_callbacks.extend(route.router.callbacks)
+ if callbacks:
+ current_callbacks.extend(callbacks)
+ if route.callbacks:
+ current_callbacks.extend(route.callbacks)
+
+ generate_unique_id_functions: List[
+ Union[Callable[[APIRoute], str], DefaultPlaceholder]
+ ] = []
+ if route.router:
+ generate_unique_id_functions.append(
+ route.router.generate_unique_id_function
+ )
+ generate_unique_id_functions.extend(
+ [
+ route.generate_unique_id_function,
+ router.generate_unique_id_function,
+ generate_unique_id_function,
+ self.generate_unique_id_function,
+ ]
+ )
+ current_generate_unique_id_function = get_value_or_default(
+ *generate_unique_id_functions
+ )
+ path = prefix + route.path
+ if route.router:
+ path = prefix + route.router.prefix + path
+ self.add_api_route(
+ path,
+ route.endpoint,
+ response_model=route.response_model,
+ status_code=route.status_code,
+ tags=current_tags,
+ dependencies=current_dependencies,
+ summary=route.summary,
+ description=route.description,
+ response_description=route.response_description,
+ responses=combined_responses,
+ deprecated=route.deprecated or deprecated or self.deprecated,
+ methods=route.methods,
+ operation_id=route.operation_id,
+ response_model_include=route.response_model_include,
+ response_model_exclude=route.response_model_exclude,
+ response_model_by_alias=route.response_model_by_alias,
+ response_model_exclude_unset=route.response_model_exclude_unset,
+ response_model_exclude_defaults=route.response_model_exclude_defaults,
+ response_model_exclude_none=route.response_model_exclude_none,
+ include_in_schema=route.include_in_schema
+ and self.include_in_schema
+ and include_in_schema,
+ response_class=use_response_class,
+ name=route.name,
+ route_class_override=type(route),
+ callbacks=current_callbacks,
+ openapi_extra=route.openapi_extra,
+ generate_unique_id_function=current_generate_unique_id_function,
+ )
+ elif isinstance(route, APIMount):
+ self.include_router(
+ route.app,
+ prefix=prefix,
+ tags=tags,
+ dependencies=dependencies,
+ default_response_class=default_response_class,
+ responses=responses,
+ callbacks=callbacks,
+ deprecated=deprecated,
+ include_in_schema=include_in_schema,
+ generate_unique_id_function=generate_unique_id_function,
+ )
+ elif isinstance(route, routing.Route):
+ methods = list(route.methods or []) # type: ignore # in Starlette
+ self.add_route(
+ prefix + route.path,
+ route.endpoint,
+ methods=methods,
+ include_in_schema=route.include_in_schema,
+ name=route.name,
+ )
+ elif isinstance(route, APIWebSocketRoute):
+ self.add_api_websocket_route(
+ prefix + route.path, route.endpoint, name=route.name
+ )
+ elif isinstance(route, routing.WebSocketRoute):
+ self.add_websocket_route(
+ prefix + route.path, route.endpoint, name=route.name
+ )
+ for handler in router.on_startup:
+ self.add_event_handler("startup", handler)
+ for handler in router.on_shutdown:
+ self.add_event_handler("shutdown", handler)
def get(
self,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
)
+
+
+class APIMount(routing.Mount):
+ def __init__(
+ self,
+ router: APIRouter,
+ *,
+ name: Optional[str] = None,
+ parent_router: Optional[APIRouter] = None,
+ ) -> None:
+ self.name = name # type: ignore # in Starlette
+ self.parent_router = parent_router
+ self.router = router
+
+ self.setup()
+
+ def setup(self) -> None:
+ self.app: APIRouter = self.router.copy()
+ if self.parent_router:
+ self.app.parent_router = self.parent_router
+ self.app.setup()
+ self.path = self.app.prefix
+ self.path_regex, self.path_format, self.param_convertors = compile_path(
+ self.path + "/{path:path}"
+ )
+
+ # Add custom additional root without trailing slash for compatibility with
+ # include_router and possibly app migrations
+ # Ref: https://github.com/tiangolo/fastapi/issues/414
+ (
+ self._root_path_regex,
+ self._root_path_format,
+ self._root_param_convertors,
+ ) = compile_path(self.path)
+ (
+ self._root_path_regex_trailing,
+ self._root_path_format_trailing,
+ self._root_param_convertors_trailing,
+ ) = compile_path(self.path + "/")
+
+ def copy(self: APIMountType) -> APIMountType:
+ return type(self)(
+ router=self.router.copy(),
+ name=self.name,
+ parent_router=self.parent_router,
+ )
+
+ def matches(self, scope: Scope) -> Tuple[Match, Scope]:
+ if scope["type"] in ("http", "websocket"):
+ path = scope["path"]
+ if self.app._router_has_empty_route:
+ # Custom logic to support paths without trailing slash
+ # Ref: https://github.com/tiangolo/fastapi/issues/414
+ # This mixes the code in
+ # starlette.routing.Route.matches() and starlette.routing.Mount.matches()
+ match = self._root_path_regex.match(path)
+ if match:
+ matched_params = match.groupdict()
+ for key, value in matched_params.items():
+ matched_params[key] = self.param_convertors[key].convert(value)
+ path_params = dict(scope.get("path_params", {}))
+ path_params.update(matched_params)
+ root_path = scope.get("root_path", "")
+ child_scope = {
+ "path_params": path_params,
+ "app_root_path": scope.get("app_root_path", root_path),
+ "root_path": root_path,
+ "path": "",
+ "endpoint": self.app,
+ }
+ return Match.FULL, child_scope
+ if not self.app._router_has_root_route:
+ match = self._root_path_regex_trailing.match(path)
+ if match:
+ return Match.NONE, {}
+ # End of custom logic
+ # Duplicated code from Starlette
+ match = self.path_regex.match(path)
+ if match:
+ matched_params = match.groupdict()
+ for key, value in matched_params.items():
+ matched_params[key] = self.param_convertors[key].convert(value)
+ remaining_path = "/" + matched_params.pop("path")
+ matched_path = path[: -len(remaining_path)]
+ path_params = dict(scope.get("path_params", {}))
+ path_params.update(matched_params)
+ root_path = scope.get("root_path", "")
+ child_scope = {
+ "path_params": path_params,
+ "app_root_path": scope.get("app_root_path", root_path),
+ "root_path": root_path + matched_path,
+ "path": remaining_path,
+ "endpoint": self.app,
+ }
+ return Match.FULL, child_scope
+ return Match.NONE, {}
+ # End of duplicated code from Starlette