]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Update internal `AsyncExitStack` to fix context for dependencies with `yield` ...
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 17 Feb 2022 12:40:12 +0000 (13:40 +0100)
committerGitHub <noreply@github.com>
Thu, 17 Feb 2022 12:40:12 +0000 (13:40 +0100)
docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
fastapi/applications.py
fastapi/middleware/asyncexitstack.py [new file with mode: 0644]
tests/test_dependency_contextmanager.py
tests/test_dependency_contextvars.py [new file with mode: 0644]
tests/test_dependency_normal_exceptions.py [new file with mode: 0644]
tests/test_exception_handlers.py

index 82553afae52dd22ae9ff7823c45801cffd1b823a..ac2e9cb8cb4c4a8a043d98548f65a13e4e308e64 100644 (file)
@@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca
 
 It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
 
-The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
+The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
 
 So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
 
@@ -138,9 +138,11 @@ participant tasks as Background tasks
     end
     dep ->> operation: Run dependency, e.g. DB session
     opt raise
-        operation -->> handler: Raise HTTPException
+        operation -->> dep: Raise HTTPException
+        dep -->> handler: Auto forward exception
         handler -->> client: HTTP error response
         operation -->> dep: Raise other exception
+        dep -->> handler: Auto forward exception
     end
     operation ->> client: Return response to client
     Note over client,operation: Response is already sent, can't change it anymore
@@ -162,9 +164,9 @@ participant tasks as Background tasks
     After one of those responses is sent, no other response can be sent.
 
 !!! tip
-    This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code.
+    This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.
 
-    But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency.
+    If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.
 
 ## Context Managers
 
index dbfd76fb9f3cd4a774bcf0fa5867c3d4a1b50108..9fb78719c31c40c80b42f2b740b10aaf59950c48 100644 (file)
@@ -2,7 +2,6 @@ from enum import Enum
 from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
 
 from fastapi import routing
-from fastapi.concurrency import AsyncExitStack
 from fastapi.datastructures import Default, DefaultPlaceholder
 from fastapi.encoders import DictIntStrAny, SetIntStr
 from fastapi.exception_handlers import (
@@ -11,6 +10,7 @@ from fastapi.exception_handlers import (
 )
 from fastapi.exceptions import RequestValidationError
 from fastapi.logger import logger
+from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
 from fastapi.openapi.docs import (
     get_redoc_html,
     get_swagger_ui_html,
@@ -21,8 +21,9 @@ from fastapi.params import Depends
 from fastapi.types import DecoratedCallable
 from starlette.applications import Starlette
 from starlette.datastructures import State
-from starlette.exceptions import HTTPException
+from starlette.exceptions import ExceptionMiddleware, HTTPException
 from starlette.middleware import Middleware
+from starlette.middleware.errors import ServerErrorMiddleware
 from starlette.requests import Request
 from starlette.responses import HTMLResponse, JSONResponse, Response
 from starlette.routing import BaseRoute
@@ -134,6 +135,55 @@ class FastAPI(Starlette):
         self.openapi_schema: Optional[Dict[str, Any]] = None
         self.setup()
 
+    def build_middleware_stack(self) -> ASGIApp:
+        # Duplicate/override from Starlette to add AsyncExitStackMiddleware
+        # inside of ExceptionMiddleware, inside of custom user middlewares
+        debug = self.debug
+        error_handler = None
+        exception_handlers = {}
+
+        for key, value in self.exception_handlers.items():
+            if key in (500, Exception):
+                error_handler = value
+            else:
+                exception_handlers[key] = value
+
+        middleware = (
+            [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+            + self.user_middleware
+            + [
+                Middleware(
+                    ExceptionMiddleware, handlers=exception_handlers, debug=debug
+                ),
+                # Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
+                # contextvars.
+                # This needs to happen after user middlewares because those create a
+                # new contextvars context copy by using a new AnyIO task group.
+                # The initial part of dependencies with yield is executed in the
+                # FastAPI code, inside all the middlewares, but the teardown part
+                # (after yield) is executed in the AsyncExitStack in this middleware,
+                # if the AsyncExitStack lived outside of the custom middlewares and
+                # contextvars were set in a dependency with yield in that internal
+                # contextvars context, the values would not be available in the
+                # outside context of the AsyncExitStack.
+                # By putting the middleware and the AsyncExitStack here, inside all
+                # user middlewares, the code before and after yield in dependencies
+                # with yield is executed in the same contextvars context, so all values
+                # set in contextvars before yield is still available after yield as
+                # would be expected.
+                # Additionally, by having this AsyncExitStack here, after the
+                # ExceptionMiddleware, now dependencies can catch handled exceptions,
+                # e.g. HTTPException, to customize the teardown code (e.g. DB session
+                # rollback).
+                Middleware(AsyncExitStackMiddleware),
+            ]
+        )
+
+        app = self.router
+        for cls, options in reversed(middleware):
+            app = cls(app=app, **options)
+        return app
+
     def openapi(self) -> Dict[str, Any]:
         if not self.openapi_schema:
             self.openapi_schema = get_openapi(
@@ -206,12 +256,7 @@ class FastAPI(Starlette):
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if self.root_path:
             scope["root_path"] = self.root_path
-        if AsyncExitStack:
-            async with AsyncExitStack() as stack:
-                scope["fastapi_astack"] = stack
-                await super().__call__(scope, receive, send)
-        else:
-            await super().__call__(scope, receive, send)  # pragma: no cover
+        await super().__call__(scope, receive, send)
 
     def add_api_route(
         self,
diff --git a/fastapi/middleware/asyncexitstack.py b/fastapi/middleware/asyncexitstack.py
new file mode 100644 (file)
index 0000000..503a68a
--- /dev/null
@@ -0,0 +1,28 @@
+from typing import Optional
+
+from fastapi.concurrency import AsyncExitStack
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class AsyncExitStackMiddleware:
+    def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
+        self.app = app
+        self.context_name = context_name
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if AsyncExitStack:
+            dependency_exception: Optional[Exception] = None
+            async with AsyncExitStack() as stack:
+                scope[self.context_name] = stack
+                try:
+                    await self.app(scope, receive, send)
+                except Exception as e:
+                    dependency_exception = e
+                    raise e
+            if dependency_exception:
+                # This exception was possibly handled by the dependency but it should
+                # still bubble up so that the ServerErrorMiddleware can return a 500
+                # or the ExceptionMiddleware can catch and handle any other exceptions
+                raise dependency_exception
+        else:
+            await self.app(scope, receive, send)  # pragma: no cover
index 3e42b47f7edf3a6e011f4e1fa6ad873af1cee6e1..03ef56c4d7e5bfedf4f4cdf94f18f3ec8b95af40 100644 (file)
@@ -235,7 +235,16 @@ def test_sync_raise_other():
     assert "/sync_raise" not in errors
 
 
-def test_async_raise():
+def test_async_raise_raises():
+    with pytest.raises(AsyncDependencyError):
+        client.get("/async_raise")
+    assert state["/async_raise"] == "asyncgen raise finalized"
+    assert "/async_raise" in errors
+    errors.clear()
+
+
+def test_async_raise_server_error():
+    client = TestClient(app, raise_server_exceptions=False)
     response = client.get("/async_raise")
     assert response.status_code == 500, response.text
     assert state["/async_raise"] == "asyncgen raise finalized"
@@ -270,7 +279,16 @@ def test_background_tasks():
     assert state["bg"] == "bg set - b: started b - a: started a"
 
 
-def test_sync_raise():
+def test_sync_raise_raises():
+    with pytest.raises(SyncDependencyError):
+        client.get("/sync_raise")
+    assert state["/sync_raise"] == "generator raise finalized"
+    assert "/sync_raise" in errors
+    errors.clear()
+
+
+def test_sync_raise_server_error():
+    client = TestClient(app, raise_server_exceptions=False)
     response = client.get("/sync_raise")
     assert response.status_code == 500, response.text
     assert state["/sync_raise"] == "generator raise finalized"
@@ -306,7 +324,16 @@ def test_sync_sync_raise_other():
     assert "/sync_raise" not in errors
 
 
-def test_sync_async_raise():
+def test_sync_async_raise_raises():
+    with pytest.raises(AsyncDependencyError):
+        client.get("/sync_async_raise")
+    assert state["/async_raise"] == "asyncgen raise finalized"
+    assert "/async_raise" in errors
+    errors.clear()
+
+
+def test_sync_async_raise_server_error():
+    client = TestClient(app, raise_server_exceptions=False)
     response = client.get("/sync_async_raise")
     assert response.status_code == 500, response.text
     assert state["/async_raise"] == "asyncgen raise finalized"
@@ -314,7 +341,16 @@ def test_sync_async_raise():
     errors.clear()
 
 
-def test_sync_sync_raise():
+def test_sync_sync_raise_raises():
+    with pytest.raises(SyncDependencyError):
+        client.get("/sync_sync_raise")
+    assert state["/sync_raise"] == "generator raise finalized"
+    assert "/sync_raise" in errors
+    errors.clear()
+
+
+def test_sync_sync_raise_server_error():
+    client = TestClient(app, raise_server_exceptions=False)
     response = client.get("/sync_sync_raise")
     assert response.status_code == 500, response.text
     assert state["/sync_raise"] == "generator raise finalized"
diff --git a/tests/test_dependency_contextvars.py b/tests/test_dependency_contextvars.py
new file mode 100644 (file)
index 0000000..076802d
--- /dev/null
@@ -0,0 +1,51 @@
+from contextvars import ContextVar
+from typing import Any, Awaitable, Callable, Dict, Optional
+
+from fastapi import Depends, FastAPI, Request, Response
+from fastapi.testclient import TestClient
+
+legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
+    "legacy_request_state_context_var", default=None
+)
+
+app = FastAPI()
+
+
+async def set_up_request_state_dependency():
+    request_state = {"user": "deadpond"}
+    contextvar_token = legacy_request_state_context_var.set(request_state)
+    yield request_state
+    legacy_request_state_context_var.reset(contextvar_token)
+
+
+@app.middleware("http")
+async def custom_middleware(
+    request: Request, call_next: Callable[[Request], Awaitable[Response]]
+):
+    response = await call_next(request)
+    response.headers["custom"] = "foo"
+    return response
+
+
+@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
+def get_user():
+    request_state = legacy_request_state_context_var.get()
+    assert request_state
+    return request_state["user"]
+
+
+client = TestClient(app)
+
+
+def test_dependency_contextvars():
+    """
+    Check that custom middlewares don't affect the contextvar context for dependencies.
+
+    The code before yield and the code after yield should be run in the same contextvar
+    context, so that request_state_context_var.reset(contextvar_token).
+
+    If they are run in a different context, that raises an error.
+    """
+    response = client.get("/user")
+    assert response.json() == "deadpond"
+    assert response.headers["custom"] == "foo"
diff --git a/tests/test_dependency_normal_exceptions.py b/tests/test_dependency_normal_exceptions.py
new file mode 100644 (file)
index 0000000..49a19f4
--- /dev/null
@@ -0,0 +1,71 @@
+import pytest
+from fastapi import Body, Depends, FastAPI, HTTPException
+from fastapi.testclient import TestClient
+
+initial_fake_database = {"rick": "Rick Sanchez"}
+
+fake_database = initial_fake_database.copy()
+
+initial_state = {"except": False, "finally": False}
+
+state = initial_state.copy()
+
+app = FastAPI()
+
+
+async def get_database():
+    temp_database = fake_database.copy()
+    try:
+        yield temp_database
+        fake_database.update(temp_database)
+    except HTTPException:
+        state["except"] = True
+    finally:
+        state["finally"] = True
+
+
+@app.put("/invalid-user/{user_id}")
+def put_invalid_user(
+    user_id: str, name: str = Body(...), db: dict = Depends(get_database)
+):
+    db[user_id] = name
+    raise HTTPException(status_code=400, detail="Invalid user")
+
+
+@app.put("/user/{user_id}")
+def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
+    db[user_id] = name
+    return {"message": "OK"}
+
+
+@pytest.fixture(autouse=True)
+def reset_state_and_db():
+    global fake_database
+    global state
+    fake_database = initial_fake_database.copy()
+    state = initial_state.copy()
+
+
+client = TestClient(app)
+
+
+def test_dependency_gets_exception():
+    assert state["except"] is False
+    assert state["finally"] is False
+    response = client.put("/invalid-user/rick", json="Morty")
+    assert response.status_code == 400, response.text
+    assert response.json() == {"detail": "Invalid user"}
+    assert state["except"] is True
+    assert state["finally"] is True
+    assert fake_database["rick"] == "Rick Sanchez"
+
+
+def test_dependency_no_exception():
+    assert state["except"] is False
+    assert state["finally"] is False
+    response = client.put("/user/rick", json="Morty")
+    assert response.status_code == 200, response.text
+    assert response.json() == {"message": "OK"}
+    assert state["except"] is False
+    assert state["finally"] is True
+    assert fake_database["rick"] == "Morty"
index 6153f7ab925ed8108166eb44b2f177306358593e..67a4becec7adcd86ba48961949dd088a63f5fddf 100644 (file)
@@ -1,3 +1,4 @@
+import pytest
 from fastapi import FastAPI, HTTPException
 from fastapi.exceptions import RequestValidationError
 from fastapi.testclient import TestClient
@@ -12,10 +13,15 @@ def request_validation_exception_handler(request, exception):
     return JSONResponse({"exception": "request-validation"})
 
 
+def server_error_exception_handler(request, exception):
+    return JSONResponse(status_code=500, content={"exception": "server-error"})
+
+
 app = FastAPI(
     exception_handlers={
         HTTPException: http_exception_handler,
         RequestValidationError: request_validation_exception_handler,
+        Exception: server_error_exception_handler,
     }
 )
 
@@ -32,6 +38,11 @@ def route_with_request_validation_exception(param: int):
     pass  # pragma: no cover
 
 
+@app.get("/server-error")
+def route_with_server_error():
+    raise RuntimeError("Oops!")
+
+
 def test_override_http_exception():
     response = client.get("/http-exception")
     assert response.status_code == 200
@@ -42,3 +53,15 @@ def test_override_request_validation_exception():
     response = client.get("/request-validation/invalid")
     assert response.status_code == 200
     assert response.json() == {"exception": "request-validation"}
+
+
+def test_override_server_error_exception_raises():
+    with pytest.raises(RuntimeError):
+        client.get("/server-error")
+
+
+def test_override_server_error_exception_response():
+    client = TestClient(app, raise_server_exceptions=False)
+    response = client.get("/server-error")
+    assert response.status_code == 500
+    assert response.json() == {"exception": "server-error"}