]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Preserve traceback when exception is raised in sync dependency with `yield` (#5823)
authorAbdullah Hashim <sombek990@hotmail.com>
Tue, 3 Dec 2024 22:37:12 +0000 (01:37 +0300)
committerGitHub <noreply@github.com>
Tue, 3 Dec 2024 22:37:12 +0000 (23:37 +0100)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
fastapi/concurrency.py
tests/test_exception_handlers.py

index 894bd3ed11873b1eecb2a5c0b85ad575b6f7de71..3202c70789ad37f934a67bfe19d3bc8f45efe513 100644 (file)
@@ -1,7 +1,7 @@
 from contextlib import asynccontextmanager as asynccontextmanager
 from typing import AsyncGenerator, ContextManager, TypeVar
 
-import anyio
+import anyio.to_thread
 from anyio import CapacityLimiter
 from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool  # noqa
 from starlette.concurrency import run_in_threadpool as run_in_threadpool  # noqa
@@ -28,7 +28,7 @@ async def contextmanager_in_threadpool(
     except Exception as e:
         ok = bool(
             await anyio.to_thread.run_sync(
-                cm.__exit__, type(e), e, None, limiter=exit_limiter
+                cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter
             )
         )
         if not ok:
index 67a4becec7adcd86ba48961949dd088a63f5fddf..6a3cbd830d0bc0fc19449f12b689b598da211460 100644 (file)
@@ -1,5 +1,5 @@
 import pytest
-from fastapi import FastAPI, HTTPException
+from fastapi import Depends, FastAPI, HTTPException
 from fastapi.exceptions import RequestValidationError
 from fastapi.testclient import TestClient
 from starlette.responses import JSONResponse
@@ -28,6 +28,18 @@ app = FastAPI(
 client = TestClient(app)
 
 
+def raise_value_error():
+    raise ValueError()
+
+
+def dependency_with_yield():
+    yield raise_value_error()
+
+
+@app.get("/dependency-with-yield", dependencies=[Depends(dependency_with_yield)])
+def with_yield(): ...
+
+
 @app.get("/http-exception")
 def route_with_http_exception():
     raise HTTPException(status_code=400)
@@ -65,3 +77,12 @@ def test_override_server_error_exception_response():
     response = client.get("/server-error")
     assert response.status_code == 500
     assert response.json() == {"exception": "server-error"}
+
+
+def test_traceback_for_dependency_with_yield():
+    client = TestClient(app, raise_server_exceptions=True)
+    with pytest.raises(ValueError) as exc_info:
+        client.get("/dependency-with-yield")
+    last_frame = exc_info.traceback[-1]
+    assert str(last_frame.path) == __file__
+    assert last_frame.lineno == raise_value_error.__code__.co_firstlineno