It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
-The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
+The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
end
dep ->> operation: Run dependency, e.g. DB session
opt raise
- operation -->> handler: Raise HTTPException
+ operation -->> dep: Raise HTTPException
+ dep -->> handler: Auto forward exception
handler -->> client: HTTP error response
operation -->> dep: Raise other exception
+ dep -->> handler: Auto forward exception
end
operation ->> client: Return response to client
Note over client,operation: Response is already sent, can't change it anymore
After one of those responses is sent, no other response can be sent.
!!! tip
- This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code.
+ This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.
- But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency.
+ If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.
## Context Managers
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
from fastapi import routing
-from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
)
from fastapi.exceptions import RequestValidationError
from fastapi.logger import logger
+from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
from fastapi.types import DecoratedCallable
from starlette.applications import Starlette
from starlette.datastructures import State
-from starlette.exceptions import HTTPException
+from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware import Middleware
+from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()
+ def build_middleware_stack(self) -> ASGIApp:
+ # Duplicate/override from Starlette to add AsyncExitStackMiddleware
+ # inside of ExceptionMiddleware, inside of custom user middlewares
+ debug = self.debug
+ error_handler = None
+ exception_handlers = {}
+
+ for key, value in self.exception_handlers.items():
+ if key in (500, Exception):
+ error_handler = value
+ else:
+ exception_handlers[key] = value
+
+ middleware = (
+ [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ + self.user_middleware
+ + [
+ Middleware(
+ ExceptionMiddleware, handlers=exception_handlers, debug=debug
+ ),
+ # Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
+ # contextvars.
+ # This needs to happen after user middlewares because those create a
+ # new contextvars context copy by using a new AnyIO task group.
+ # The initial part of dependencies with yield is executed in the
+ # FastAPI code, inside all the middlewares, but the teardown part
+ # (after yield) is executed in the AsyncExitStack in this middleware,
+ # if the AsyncExitStack lived outside of the custom middlewares and
+ # contextvars were set in a dependency with yield in that internal
+ # contextvars context, the values would not be available in the
+ # outside context of the AsyncExitStack.
+ # By putting the middleware and the AsyncExitStack here, inside all
+ # user middlewares, the code before and after yield in dependencies
+ # with yield is executed in the same contextvars context, so all values
+ # set in contextvars before yield is still available after yield as
+ # would be expected.
+ # Additionally, by having this AsyncExitStack here, after the
+ # ExceptionMiddleware, now dependencies can catch handled exceptions,
+ # e.g. HTTPException, to customize the teardown code (e.g. DB session
+ # rollback).
+ Middleware(AsyncExitStackMiddleware),
+ ]
+ )
+
+ app = self.router
+ for cls, options in reversed(middleware):
+ app = cls(app=app, **options)
+ return app
+
def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.root_path:
scope["root_path"] = self.root_path
- if AsyncExitStack:
- async with AsyncExitStack() as stack:
- scope["fastapi_astack"] = stack
- await super().__call__(scope, receive, send)
- else:
- await super().__call__(scope, receive, send) # pragma: no cover
+ await super().__call__(scope, receive, send)
def add_api_route(
self,
--- /dev/null
+from typing import Optional
+
+from fastapi.concurrency import AsyncExitStack
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class AsyncExitStackMiddleware:
+ def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
+ self.app = app
+ self.context_name = context_name
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if AsyncExitStack:
+ dependency_exception: Optional[Exception] = None
+ async with AsyncExitStack() as stack:
+ scope[self.context_name] = stack
+ try:
+ await self.app(scope, receive, send)
+ except Exception as e:
+ dependency_exception = e
+ raise e
+ if dependency_exception:
+ # This exception was possibly handled by the dependency but it should
+ # still bubble up so that the ServerErrorMiddleware can return a 500
+ # or the ExceptionMiddleware can catch and handle any other exceptions
+ raise dependency_exception
+ else:
+ await self.app(scope, receive, send) # pragma: no cover
assert "/sync_raise" not in errors
-def test_async_raise():
+def test_async_raise_raises():
+ with pytest.raises(AsyncDependencyError):
+ client.get("/async_raise")
+ assert state["/async_raise"] == "asyncgen raise finalized"
+ assert "/async_raise" in errors
+ errors.clear()
+
+
+def test_async_raise_server_error():
+ client = TestClient(app, raise_server_exceptions=False)
response = client.get("/async_raise")
assert response.status_code == 500, response.text
assert state["/async_raise"] == "asyncgen raise finalized"
assert state["bg"] == "bg set - b: started b - a: started a"
-def test_sync_raise():
+def test_sync_raise_raises():
+ with pytest.raises(SyncDependencyError):
+ client.get("/sync_raise")
+ assert state["/sync_raise"] == "generator raise finalized"
+ assert "/sync_raise" in errors
+ errors.clear()
+
+
+def test_sync_raise_server_error():
+ client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_raise")
assert response.status_code == 500, response.text
assert state["/sync_raise"] == "generator raise finalized"
assert "/sync_raise" not in errors
-def test_sync_async_raise():
+def test_sync_async_raise_raises():
+ with pytest.raises(AsyncDependencyError):
+ client.get("/sync_async_raise")
+ assert state["/async_raise"] == "asyncgen raise finalized"
+ assert "/async_raise" in errors
+ errors.clear()
+
+
+def test_sync_async_raise_server_error():
+ client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_async_raise")
assert response.status_code == 500, response.text
assert state["/async_raise"] == "asyncgen raise finalized"
errors.clear()
-def test_sync_sync_raise():
+def test_sync_sync_raise_raises():
+ with pytest.raises(SyncDependencyError):
+ client.get("/sync_sync_raise")
+ assert state["/sync_raise"] == "generator raise finalized"
+ assert "/sync_raise" in errors
+ errors.clear()
+
+
+def test_sync_sync_raise_server_error():
+ client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_sync_raise")
assert response.status_code == 500, response.text
assert state["/sync_raise"] == "generator raise finalized"
--- /dev/null
+from contextvars import ContextVar
+from typing import Any, Awaitable, Callable, Dict, Optional
+
+from fastapi import Depends, FastAPI, Request, Response
+from fastapi.testclient import TestClient
+
+legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
+ "legacy_request_state_context_var", default=None
+)
+
+app = FastAPI()
+
+
+async def set_up_request_state_dependency():
+ request_state = {"user": "deadpond"}
+ contextvar_token = legacy_request_state_context_var.set(request_state)
+ yield request_state
+ legacy_request_state_context_var.reset(contextvar_token)
+
+
+@app.middleware("http")
+async def custom_middleware(
+ request: Request, call_next: Callable[[Request], Awaitable[Response]]
+):
+ response = await call_next(request)
+ response.headers["custom"] = "foo"
+ return response
+
+
+@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
+def get_user():
+ request_state = legacy_request_state_context_var.get()
+ assert request_state
+ return request_state["user"]
+
+
+client = TestClient(app)
+
+
+def test_dependency_contextvars():
+ """
+ Check that custom middlewares don't affect the contextvar context for dependencies.
+
+ The code before yield and the code after yield should be run in the same contextvar
+ context, so that request_state_context_var.reset(contextvar_token).
+
+ If they are run in a different context, that raises an error.
+ """
+ response = client.get("/user")
+ assert response.json() == "deadpond"
+ assert response.headers["custom"] == "foo"
--- /dev/null
+import pytest
+from fastapi import Body, Depends, FastAPI, HTTPException
+from fastapi.testclient import TestClient
+
+initial_fake_database = {"rick": "Rick Sanchez"}
+
+fake_database = initial_fake_database.copy()
+
+initial_state = {"except": False, "finally": False}
+
+state = initial_state.copy()
+
+app = FastAPI()
+
+
+async def get_database():
+ temp_database = fake_database.copy()
+ try:
+ yield temp_database
+ fake_database.update(temp_database)
+ except HTTPException:
+ state["except"] = True
+ finally:
+ state["finally"] = True
+
+
+@app.put("/invalid-user/{user_id}")
+def put_invalid_user(
+ user_id: str, name: str = Body(...), db: dict = Depends(get_database)
+):
+ db[user_id] = name
+ raise HTTPException(status_code=400, detail="Invalid user")
+
+
+@app.put("/user/{user_id}")
+def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
+ db[user_id] = name
+ return {"message": "OK"}
+
+
+@pytest.fixture(autouse=True)
+def reset_state_and_db():
+ global fake_database
+ global state
+ fake_database = initial_fake_database.copy()
+ state = initial_state.copy()
+
+
+client = TestClient(app)
+
+
+def test_dependency_gets_exception():
+ assert state["except"] is False
+ assert state["finally"] is False
+ response = client.put("/invalid-user/rick", json="Morty")
+ assert response.status_code == 400, response.text
+ assert response.json() == {"detail": "Invalid user"}
+ assert state["except"] is True
+ assert state["finally"] is True
+ assert fake_database["rick"] == "Rick Sanchez"
+
+
+def test_dependency_no_exception():
+ assert state["except"] is False
+ assert state["finally"] is False
+ response = client.put("/user/rick", json="Morty")
+ assert response.status_code == 200, response.text
+ assert response.json() == {"message": "OK"}
+ assert state["except"] is False
+ assert state["finally"] is True
+ assert fake_database["rick"] == "Morty"
+import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.testclient import TestClient
return JSONResponse({"exception": "request-validation"})
+def server_error_exception_handler(request, exception):
+ return JSONResponse(status_code=500, content={"exception": "server-error"})
+
+
app = FastAPI(
exception_handlers={
HTTPException: http_exception_handler,
RequestValidationError: request_validation_exception_handler,
+ Exception: server_error_exception_handler,
}
)
pass # pragma: no cover
+@app.get("/server-error")
+def route_with_server_error():
+ raise RuntimeError("Oops!")
+
+
def test_override_http_exception():
response = client.get("/http-exception")
assert response.status_code == 200
response = client.get("/request-validation/invalid")
assert response.status_code == 200
assert response.json() == {"exception": "request-validation"}
+
+
+def test_override_server_error_exception_raises():
+ with pytest.raises(RuntimeError):
+ client.get("/server-error")
+
+
+def test_override_server_error_exception_response():
+ client = TestClient(app, raise_server_exceptions=False)
+ response = client.get("/server-error")
+ assert response.status_code == 500
+ assert response.json() == {"exception": "server-error"}