]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🚸 Improve tracebacks by adding endpoint metadata (#14306)
authorSavannah Ostrowski <savannah@python.org>
Sat, 6 Dec 2025 12:21:57 +0000 (04:21 -0800)
committerGitHub <noreply@github.com>
Sat, 6 Dec 2025 12:21:57 +0000 (12:21 +0000)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/exceptions.py
fastapi/routing.py
tests/test_validation_error_context.py [new file with mode: 0644]

index 0620428be749bf5f5be47ba950d68c3536735fb0..a46e823506c64b22a666868d76f00ff934f149d1 100644 (file)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Type, Union
+from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union
 
 from annotated_doc import Doc
 from pydantic import BaseModel, create_model
@@ -7,6 +7,13 @@ from starlette.exceptions import WebSocketException as StarletteWebSocketExcepti
 from typing_extensions import Annotated
 
 
+class EndpointContext(TypedDict, total=False):
+    function: str
+    path: str
+    file: str
+    line: int
+
+
 class HTTPException(StarletteHTTPException):
     """
     An HTTP exception you can raise in your own code to show errors to the client.
@@ -155,30 +162,72 @@ class DependencyScopeError(FastAPIError):
 
 
 class ValidationException(Exception):
-    def __init__(self, errors: Sequence[Any]) -> None:
+    def __init__(
+        self,
+        errors: Sequence[Any],
+        *,
+        endpoint_ctx: Optional[EndpointContext] = None,
+    ) -> None:
         self._errors = errors
+        self.endpoint_ctx = endpoint_ctx
+
+        ctx = endpoint_ctx or {}
+        self.endpoint_function = ctx.get("function")
+        self.endpoint_path = ctx.get("path")
+        self.endpoint_file = ctx.get("file")
+        self.endpoint_line = ctx.get("line")
 
     def errors(self) -> Sequence[Any]:
         return self._errors
 
+    def _format_endpoint_context(self) -> str:
+        if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
+            if self.endpoint_path:
+                return f"\n  Endpoint: {self.endpoint_path}"
+            return ""
+
+        context = f'\n  File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
+        if self.endpoint_path:
+            context += f"\n    {self.endpoint_path}"
+        return context
+
+    def __str__(self) -> str:
+        message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
+        for err in self._errors:
+            message += f"  {err}\n"
+        message += self._format_endpoint_context()
+        return message.rstrip()
+
 
 class RequestValidationError(ValidationException):
-    def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
-        super().__init__(errors)
+    def __init__(
+        self,
+        errors: Sequence[Any],
+        *,
+        body: Any = None,
+        endpoint_ctx: Optional[EndpointContext] = None,
+    ) -> None:
+        super().__init__(errors, endpoint_ctx=endpoint_ctx)
         self.body = body
 
 
 class WebSocketRequestValidationError(ValidationException):
-    pass
+    def __init__(
+        self,
+        errors: Sequence[Any],
+        *,
+        endpoint_ctx: Optional[EndpointContext] = None,
+    ) -> None:
+        super().__init__(errors, endpoint_ctx=endpoint_ctx)
 
 
 class ResponseValidationError(ValidationException):
-    def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
-        super().__init__(errors)
+    def __init__(
+        self,
+        errors: Sequence[Any],
+        *,
+        body: Any = None,
+        endpoint_ctx: Optional[EndpointContext] = None,
+    ) -> None:
+        super().__init__(errors, endpoint_ctx=endpoint_ctx)
         self.body = body
-
-    def __str__(self) -> str:
-        message = f"{len(self._errors)} validation errors:\n"
-        for err in self._errors:
-            message += f"  {err}\n"
-        return message
index c10175b1614291471f505158d2db88822c802a23..9be2b44bc1beb2c2ff3b67c93733439ab6688450 100644 (file)
@@ -46,6 +46,7 @@ from fastapi.dependencies.utils import (
 )
 from fastapi.encoders import jsonable_encoder
 from fastapi.exceptions import (
+    EndpointContext,
     FastAPIError,
     RequestValidationError,
     ResponseValidationError,
@@ -212,6 +213,33 @@ def _merge_lifespan_context(
     return merged_lifespan  # type: ignore[return-value]
 
 
+# Cache for endpoint context to avoid re-extracting on every request
+_endpoint_context_cache: Dict[int, EndpointContext] = {}
+
+
+def _extract_endpoint_context(func: Any) -> EndpointContext:
+    """Extract endpoint context with caching to avoid repeated file I/O."""
+    func_id = id(func)
+
+    if func_id in _endpoint_context_cache:
+        return _endpoint_context_cache[func_id]
+
+    try:
+        ctx: EndpointContext = {}
+
+        if (source_file := inspect.getsourcefile(func)) is not None:
+            ctx["file"] = source_file
+        if (line_number := inspect.getsourcelines(func)[1]) is not None:
+            ctx["line"] = line_number
+        if (func_name := getattr(func, "__name__", None)) is not None:
+            ctx["function"] = func_name
+    except Exception:
+        ctx = EndpointContext()
+
+    _endpoint_context_cache[func_id] = ctx
+    return ctx
+
+
 async def serialize_response(
     *,
     field: Optional[ModelField] = None,
@@ -223,6 +251,7 @@ async def serialize_response(
     exclude_defaults: bool = False,
     exclude_none: bool = False,
     is_coroutine: bool = True,
+    endpoint_ctx: Optional[EndpointContext] = None,
 ) -> Any:
     if field:
         errors = []
@@ -245,8 +274,11 @@ async def serialize_response(
         elif errors_:
             errors.append(errors_)
         if errors:
+            ctx = endpoint_ctx or EndpointContext()
             raise ResponseValidationError(
-                errors=_normalize_errors(errors), body=response_content
+                errors=_normalize_errors(errors),
+                body=response_content,
+                endpoint_ctx=ctx,
             )
 
         if hasattr(field, "serialize"):
@@ -318,6 +350,18 @@ def get_request_handler(
             "fastapi_middleware_astack not found in request scope"
         )
 
+        # Extract endpoint context for error messages
+        endpoint_ctx = (
+            _extract_endpoint_context(dependant.call)
+            if dependant.call
+            else EndpointContext()
+        )
+
+        if dependant.path:
+            # For mounted sub-apps, include the mount path prefix
+            mount_path = request.scope.get("root_path", "").rstrip("/")
+            endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}"
+
         # Read body and auto-close files
         try:
             body: Any = None
@@ -355,6 +399,7 @@ def get_request_handler(
                     }
                 ],
                 body=e.doc,
+                endpoint_ctx=endpoint_ctx,
             )
             raise validation_error from e
         except HTTPException:
@@ -414,6 +459,7 @@ def get_request_handler(
                     exclude_defaults=response_model_exclude_defaults,
                     exclude_none=response_model_exclude_none,
                     is_coroutine=is_coroutine,
+                    endpoint_ctx=endpoint_ctx,
                 )
                 response = actual_response_class(content, **response_args)
                 if not is_body_allowed_for_status_code(response.status_code):
@@ -421,7 +467,7 @@ def get_request_handler(
                 response.headers.raw.extend(solved_result.response.headers.raw)
         if errors:
             validation_error = RequestValidationError(
-                _normalize_errors(errors), body=body
+                _normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx
             )
             raise validation_error
 
@@ -438,6 +484,15 @@ def get_websocket_app(
     embed_body_fields: bool = False,
 ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
     async def app(websocket: WebSocket) -> None:
+        endpoint_ctx = (
+            _extract_endpoint_context(dependant.call)
+            if dependant.call
+            else EndpointContext()
+        )
+        if dependant.path:
+            # For mounted sub-apps, include the mount path prefix
+            mount_path = websocket.scope.get("root_path", "").rstrip("/")
+            endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}"
         async_exit_stack = websocket.scope.get("fastapi_inner_astack")
         assert isinstance(async_exit_stack, AsyncExitStack), (
             "fastapi_inner_astack not found in request scope"
@@ -451,7 +506,8 @@ def get_websocket_app(
         )
         if solved_result.errors:
             raise WebSocketRequestValidationError(
-                _normalize_errors(solved_result.errors)
+                _normalize_errors(solved_result.errors),
+                endpoint_ctx=endpoint_ctx,
             )
         assert dependant.call is not None, "dependant.call must be a function"
         await dependant.call(**solved_result.values)
diff --git a/tests/test_validation_error_context.py b/tests/test_validation_error_context.py
new file mode 100644 (file)
index 0000000..844b8a6
--- /dev/null
@@ -0,0 +1,168 @@
+from fastapi import FastAPI, Request, WebSocket
+from fastapi.exceptions import (
+    RequestValidationError,
+    ResponseValidationError,
+    WebSocketRequestValidationError,
+)
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+
+
+class Item(BaseModel):
+    id: int
+    name: str
+
+
+class ExceptionCapture:
+    def __init__(self):
+        self.exception = None
+
+    def capture(self, exc):
+        self.exception = exc
+        return exc
+
+
+app = FastAPI()
+sub_app = FastAPI()
+captured_exception = ExceptionCapture()
+
+app.mount(path="/sub", app=sub_app)
+
+
+@app.exception_handler(RequestValidationError)
+@sub_app.exception_handler(RequestValidationError)
+async def request_validation_handler(request: Request, exc: RequestValidationError):
+    captured_exception.capture(exc)
+    raise exc
+
+
+@app.exception_handler(ResponseValidationError)
+@sub_app.exception_handler(ResponseValidationError)
+async def response_validation_handler(_: Request, exc: ResponseValidationError):
+    captured_exception.capture(exc)
+    raise exc
+
+
+@app.exception_handler(WebSocketRequestValidationError)
+@sub_app.exception_handler(WebSocketRequestValidationError)
+async def websocket_validation_handler(
+    websocket: WebSocket, exc: WebSocketRequestValidationError
+):
+    captured_exception.capture(exc)
+    raise exc
+
+
+@app.get("/users/{user_id}")
+def get_user(user_id: int):
+    return {"user_id": user_id}  # pragma: no cover
+
+
+@app.get("/items/", response_model=Item)
+def get_item():
+    return {"name": "Widget"}
+
+
+@sub_app.get("/items/", response_model=Item)
+def get_sub_item():
+    return {"name": "Widget"}  # pragma: no cover
+
+
+@app.websocket("/ws/{item_id}")
+async def websocket_endpoint(websocket: WebSocket, item_id: int):
+    await websocket.accept()  # pragma: no cover
+    await websocket.send_text(f"Item: {item_id}")  # pragma: no cover
+    await websocket.close()  # pragma: no cover
+
+
+@sub_app.websocket("/ws/{item_id}")
+async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int):
+    await websocket.accept()  # pragma: no cover
+    await websocket.send_text(f"Item: {item_id}")  # pragma: no cover
+    await websocket.close()  # pragma: no cover
+
+
+client = TestClient(app)
+
+
+def test_request_validation_error_includes_endpoint_context():
+    captured_exception.exception = None
+    try:
+        client.get("/users/invalid")
+    except Exception:
+        pass
+
+    assert captured_exception.exception is not None
+    error_str = str(captured_exception.exception)
+    assert "get_user" in error_str
+    assert "/users/" in error_str
+
+
+def test_response_validation_error_includes_endpoint_context():
+    captured_exception.exception = None
+    try:
+        client.get("/items/")
+    except Exception:
+        pass
+
+    assert captured_exception.exception is not None
+    error_str = str(captured_exception.exception)
+    assert "get_item" in error_str
+    assert "/items/" in error_str
+
+
+def test_websocket_validation_error_includes_endpoint_context():
+    captured_exception.exception = None
+    try:
+        with client.websocket_connect("/ws/invalid"):
+            pass  # pragma: no cover
+    except Exception:
+        pass
+
+    assert captured_exception.exception is not None
+    error_str = str(captured_exception.exception)
+    assert "websocket_endpoint" in error_str
+    assert "/ws/" in error_str
+
+
+def test_subapp_request_validation_error_includes_endpoint_context():
+    captured_exception.exception = None
+    try:
+        client.get("/sub/items/")
+    except Exception:
+        pass
+
+    assert captured_exception.exception is not None
+    error_str = str(captured_exception.exception)
+    assert "get_sub_item" in error_str
+    assert "/sub/items/" in error_str
+
+
+def test_subapp_websocket_validation_error_includes_endpoint_context():
+    captured_exception.exception = None
+    try:
+        with client.websocket_connect("/sub/ws/invalid"):
+            pass  # pragma: no cover
+    except Exception:
+        pass
+
+    assert captured_exception.exception is not None
+    error_str = str(captured_exception.exception)
+    assert "subapp_websocket_endpoint" in error_str
+    assert "/sub/ws/" in error_str
+
+
+def test_validation_error_with_only_path():
+    errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
+    exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"})
+    error_str = str(exc)
+    assert "Endpoint: GET /api/test" in error_str
+    assert 'File "' not in error_str
+
+
+def test_validation_error_with_no_context():
+    errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
+    exc = RequestValidationError(errors, endpoint_ctx={})
+    error_str = str(exc)
+    assert "1 validation error:" in error_str
+    assert "Endpoint" not in error_str
+    assert 'File "' not in error_str