]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:pushpin: Upgrade Starlette version (#1057)
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 29 Feb 2020 20:28:23 +0000 (21:28 +0100)
committerGitHub <noreply@github.com>
Sat, 29 Feb 2020 20:28:23 +0000 (21:28 +0100)
fastapi/applications.py
fastapi/routing.py
pyproject.toml
tests/test_empty_router.py
tests/test_router_events.py [new file with mode: 0644]

index ff2eb52c8606543ad2e5e9e2590f4093751e3700..8270e54fdf4d0a8b3b0529adaf3a00aacc2c8423 100644 (file)
@@ -18,8 +18,8 @@ from fastapi.params import Depends
 from fastapi.utils import warning_response_model_skip_defaults_deprecated
 from starlette.applications import Starlette
 from starlette.datastructures import State
-from starlette.exceptions import ExceptionMiddleware, HTTPException
-from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.exceptions import HTTPException
+from starlette.middleware import Middleware
 from starlette.requests import Request
 from starlette.responses import HTMLResponse, JSONResponse, Response
 from starlette.routing import BaseRoute
@@ -29,9 +29,9 @@ from starlette.types import Receive, Scope, Send
 class FastAPI(Starlette):
     def __init__(
         self,
+        *,
         debug: bool = False,
         routes: List[BaseRoute] = None,
-        template_directory: str = None,
         title: str = "FastAPI",
         description: str = "",
         version: str = "0.1.0",
@@ -42,19 +42,28 @@ class FastAPI(Starlette):
         redoc_url: Optional[str] = "/redoc",
         swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
         swagger_ui_init_oauth: Optional[dict] = None,
+        middleware: Sequence[Middleware] = None,
+        exception_handlers: Dict[Union[int, Type[Exception]], Callable] = None,
+        on_startup: Sequence[Callable] = None,
+        on_shutdown: Sequence[Callable] = None,
         **extra: Dict[str, Any],
     ) -> None:
         self.default_response_class = default_response_class
         self._debug = debug
         self.state = State()
         self.router: routing.APIRouter = routing.APIRouter(
-            routes, dependency_overrides_provider=self
+            routes,
+            dependency_overrides_provider=self,
+            on_startup=on_startup,
+            on_shutdown=on_shutdown,
         )
-        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
-        self.error_middleware = ServerErrorMiddleware(
-            self.exception_middleware, debug=debug
+        self.exception_handlers = (
+            {} if exception_handlers is None else dict(exception_handlers)
         )
 
+        self.user_middleware = [] if middleware is None else list(middleware)
+        self.middleware_stack = self.build_middleware_stack()
+
         self.title = title
         self.description = description
         self.version = version
index d5211c4892d6bfb621cadfa15172156b46b07448..7f3dc49774cf593150b21969837d1d73e0fabe8b 100644 (file)
@@ -346,9 +346,15 @@ class APIRouter(routing.Router):
         dependency_overrides_provider: Any = None,
         route_class: Type[APIRoute] = APIRoute,
         default_response_class: Type[Response] = None,
+        on_startup: Sequence[Callable] = None,
+        on_shutdown: Sequence[Callable] = None,
     ) -> None:
         super().__init__(
-            routes=routes, redirect_slashes=redirect_slashes, default=default
+            routes=routes,
+            redirect_slashes=redirect_slashes,
+            default=default,
+            on_startup=on_startup,
+            on_shutdown=on_shutdown,
         )
         self.dependency_overrides_provider = dependency_overrides_provider
         self.route_class = route_class
@@ -552,6 +558,10 @@ class APIRouter(routing.Router):
                 self.add_websocket_route(
                     prefix + route.path, route.endpoint, name=route.name
                 )
+        for handler in router.on_startup:
+            self.add_event_handler("startup", handler)
+        for handler in router.on_shutdown:
+            self.add_event_handler("shutdown", handler)
 
     def get(
         self,
index 7761a54c9fdb10d605cbe460b031691e0c33a73a..3805104ea504d39a85ef940325d85b16d1675081 100644 (file)
@@ -32,7 +32,7 @@ classifiers = [
     "Topic :: Internet :: WWW/HTTP",
 ]
 requires = [
-    "starlette >=0.12.9,<=0.12.9",
+    "starlette ==0.13.2",
     "pydantic >=0.32.2,<2.0.0"
 ]
 description-file = "README.md"
index 57dd006faa03e61af3c185fb3bf3961c5e19073f..c38fae8551e2774c7228a1f0551bcad2ba6f07f6 100644 (file)
@@ -21,10 +21,12 @@ client = TestClient(app)
 def test_use_empty():
     with client:
         response = client.get("/prefix")
+        assert response.status_code == 200
         assert response.json() == ["OK"]
 
         response = client.get("/prefix/")
-        assert response.status_code == 404
+        assert response.status_code == 200
+        assert response.json() == ["OK"]
 
 
 def test_include_empty():
diff --git a/tests/test_router_events.py b/tests/test_router_events.py
new file mode 100644 (file)
index 0000000..3a499b1
--- /dev/null
@@ -0,0 +1,87 @@
+from fastapi import APIRouter, FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+
+class State(BaseModel):
+    app_startup: bool = False
+    app_shutdown: bool = False
+    router_startup: bool = False
+    router_shutdown: bool = False
+    sub_router_startup: bool = False
+    sub_router_shutdown: bool = False
+
+
+state = State()
+
+app = FastAPI()
+
+
+@app.on_event("startup")
+def app_startup():
+    state.app_startup = True
+
+
+@app.on_event("shutdown")
+def app_shutdown():
+    state.app_shutdown = True
+
+
+router = APIRouter()
+
+
+@router.on_event("startup")
+def router_startup():
+    state.router_startup = True
+
+
+@router.on_event("shutdown")
+def router_shutdown():
+    state.router_shutdown = True
+
+
+sub_router = APIRouter()
+
+
+@sub_router.on_event("startup")
+def sub_router_startup():
+    state.sub_router_startup = True
+
+
+@sub_router.on_event("shutdown")
+def sub_router_shutdown():
+    state.sub_router_shutdown = True
+
+
+@sub_router.get("/")
+def main():
+    return {"message": "Hello World"}
+
+
+router.include_router(sub_router)
+app.include_router(router)
+
+
+def test_router_events():
+    assert state.app_startup is False
+    assert state.router_startup is False
+    assert state.sub_router_startup is False
+    assert state.app_shutdown is False
+    assert state.router_shutdown is False
+    assert state.sub_router_shutdown is False
+    with TestClient(app) as client:
+        assert state.app_startup is True
+        assert state.router_startup is True
+        assert state.sub_router_startup is True
+        assert state.app_shutdown is False
+        assert state.router_shutdown is False
+        assert state.sub_router_shutdown is False
+        response = client.get("/")
+        assert response.status_code == 200
+        assert response.json() == {"message": "Hello World"}
+    assert state.app_startup is True
+    assert state.router_startup is True
+    assert state.sub_router_startup is True
+    assert state.app_shutdown is True
+    assert state.router_shutdown is True
+    assert state.sub_router_shutdown is True