from fastapi.utils import warning_response_model_skip_defaults_deprecated
from starlette.applications import Starlette
from starlette.datastructures import State
-from starlette.exceptions import ExceptionMiddleware, HTTPException
-from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.exceptions import HTTPException
+from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
class FastAPI(Starlette):
def __init__(
self,
+ *,
debug: bool = False,
routes: List[BaseRoute] = None,
- template_directory: str = None,
title: str = "FastAPI",
description: str = "",
version: str = "0.1.0",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
swagger_ui_init_oauth: Optional[dict] = None,
+ middleware: Sequence[Middleware] = None,
+ exception_handlers: Dict[Union[int, Type[Exception]], Callable] = None,
+ on_startup: Sequence[Callable] = None,
+ on_shutdown: Sequence[Callable] = None,
**extra: Dict[str, Any],
) -> None:
self.default_response_class = default_response_class
self._debug = debug
self.state = State()
self.router: routing.APIRouter = routing.APIRouter(
- routes, dependency_overrides_provider=self
+ routes,
+ dependency_overrides_provider=self,
+ on_startup=on_startup,
+ on_shutdown=on_shutdown,
)
- self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
- self.error_middleware = ServerErrorMiddleware(
- self.exception_middleware, debug=debug
+ self.exception_handlers = (
+ {} if exception_handlers is None else dict(exception_handlers)
)
+ self.user_middleware = [] if middleware is None else list(middleware)
+ self.middleware_stack = self.build_middleware_stack()
+
self.title = title
self.description = description
self.version = version
dependency_overrides_provider: Any = None,
route_class: Type[APIRoute] = APIRoute,
default_response_class: Type[Response] = None,
+ on_startup: Sequence[Callable] = None,
+ on_shutdown: Sequence[Callable] = None,
) -> None:
super().__init__(
- routes=routes, redirect_slashes=redirect_slashes, default=default
+ routes=routes,
+ redirect_slashes=redirect_slashes,
+ default=default,
+ on_startup=on_startup,
+ on_shutdown=on_shutdown,
)
self.dependency_overrides_provider = dependency_overrides_provider
self.route_class = route_class
self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name
)
+ for handler in router.on_startup:
+ self.add_event_handler("startup", handler)
+ for handler in router.on_shutdown:
+ self.add_event_handler("shutdown", handler)
def get(
self,
"Topic :: Internet :: WWW/HTTP",
]
requires = [
- "starlette >=0.12.9,<=0.12.9",
+ "starlette ==0.13.2",
"pydantic >=0.32.2,<2.0.0"
]
description-file = "README.md"
def test_use_empty():
with client:
response = client.get("/prefix")
+ assert response.status_code == 200
assert response.json() == ["OK"]
response = client.get("/prefix/")
- assert response.status_code == 404
+ assert response.status_code == 200
+ assert response.json() == ["OK"]
def test_include_empty():
--- /dev/null
+from fastapi import APIRouter, FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+
+class State(BaseModel):
+ app_startup: bool = False
+ app_shutdown: bool = False
+ router_startup: bool = False
+ router_shutdown: bool = False
+ sub_router_startup: bool = False
+ sub_router_shutdown: bool = False
+
+
+state = State()
+
+app = FastAPI()
+
+
+@app.on_event("startup")
+def app_startup():
+ state.app_startup = True
+
+
+@app.on_event("shutdown")
+def app_shutdown():
+ state.app_shutdown = True
+
+
+router = APIRouter()
+
+
+@router.on_event("startup")
+def router_startup():
+ state.router_startup = True
+
+
+@router.on_event("shutdown")
+def router_shutdown():
+ state.router_shutdown = True
+
+
+sub_router = APIRouter()
+
+
+@sub_router.on_event("startup")
+def sub_router_startup():
+ state.sub_router_startup = True
+
+
+@sub_router.on_event("shutdown")
+def sub_router_shutdown():
+ state.sub_router_shutdown = True
+
+
+@sub_router.get("/")
+def main():
+ return {"message": "Hello World"}
+
+
+router.include_router(sub_router)
+app.include_router(router)
+
+
+def test_router_events():
+ assert state.app_startup is False
+ assert state.router_startup is False
+ assert state.sub_router_startup is False
+ assert state.app_shutdown is False
+ assert state.router_shutdown is False
+ assert state.sub_router_shutdown is False
+ with TestClient(app) as client:
+ assert state.app_startup is True
+ assert state.router_startup is True
+ assert state.sub_router_startup is True
+ assert state.app_shutdown is False
+ assert state.router_shutdown is False
+ assert state.sub_router_shutdown is False
+ response = client.get("/")
+ assert response.status_code == 200
+ assert response.json() == {"message": "Hello World"}
+ assert state.app_startup is True
+ assert state.router_startup is True
+ assert state.sub_router_startup is True
+ assert state.app_shutdown is True
+ assert state.router_shutdown is True
+ assert state.sub_router_shutdown is True