From e1117f75505bbdb2d42321a009dbf26c9c2b8b6d Mon Sep 17 00:00:00 2001 From: Savannah Ostrowski Date: Sat, 6 Dec 2025 04:21:57 -0800 Subject: [PATCH] =?utf8?q?=F0=9F=9A=B8=20=20Improve=20tracebacks=20by=20ad?= =?utf8?q?ding=20endpoint=20metadata=20(#14306)?= MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez --- fastapi/exceptions.py | 75 +++++++++-- fastapi/routing.py | 62 ++++++++- tests/test_validation_error_context.py | 168 +++++++++++++++++++++++++ 3 files changed, 289 insertions(+), 16 deletions(-) create mode 100644 tests/test_validation_error_context.py diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 0620428be..a46e82350 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union from annotated_doc import Doc from pydantic import BaseModel, create_model @@ -7,6 +7,13 @@ from starlette.exceptions import WebSocketException as StarletteWebSocketExcepti from typing_extensions import Annotated +class EndpointContext(TypedDict, total=False): + function: str + path: str + file: str + line: int + + class HTTPException(StarletteHTTPException): """ An HTTP exception you can raise in your own code to show errors to the client. @@ -155,30 +162,72 @@ class DependencyScopeError(FastAPIError): class ValidationException(Exception): - def __init__(self, errors: Sequence[Any]) -> None: + def __init__( + self, + errors: Sequence[Any], + *, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: self._errors = errors + self.endpoint_ctx = endpoint_ctx + + ctx = endpoint_ctx or {} + self.endpoint_function = ctx.get("function") + self.endpoint_path = ctx.get("path") + self.endpoint_file = ctx.get("file") + self.endpoint_line = ctx.get("line") def errors(self) -> Sequence[Any]: return self._errors + def _format_endpoint_context(self) -> str: + if not (self.endpoint_file and self.endpoint_line and self.endpoint_function): + if self.endpoint_path: + return f"\n Endpoint: {self.endpoint_path}" + return "" + + context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}' + if self.endpoint_path: + context += f"\n {self.endpoint_path}" + return context + + def __str__(self) -> str: + message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n" + for err in self._errors: + message += f" {err}\n" + message += self._format_endpoint_context() + return message.rstrip() + class RequestValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: - super().__init__(errors) + def __init__( + self, + errors: Sequence[Any], + *, + body: Any = None, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) self.body = body class WebSocketRequestValidationError(ValidationException): - pass + def __init__( + self, + errors: Sequence[Any], + *, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) class ResponseValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: - super().__init__(errors) + def __init__( + self, + errors: Sequence[Any], + *, + body: Any = None, + endpoint_ctx: Optional[EndpointContext] = None, + ) -> None: + super().__init__(errors, endpoint_ctx=endpoint_ctx) self.body = body - - def __str__(self) -> str: - message = f"{len(self._errors)} validation errors:\n" - for err in self._errors: - message += f" {err}\n" - return message diff --git a/fastapi/routing.py b/fastapi/routing.py index c10175b16..9be2b44bc 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -46,6 +46,7 @@ from fastapi.dependencies.utils import ( ) from fastapi.encoders import jsonable_encoder from fastapi.exceptions import ( + EndpointContext, FastAPIError, RequestValidationError, ResponseValidationError, @@ -212,6 +213,33 @@ def _merge_lifespan_context( return merged_lifespan # type: ignore[return-value] +# Cache for endpoint context to avoid re-extracting on every request +_endpoint_context_cache: Dict[int, EndpointContext] = {} + + +def _extract_endpoint_context(func: Any) -> EndpointContext: + """Extract endpoint context with caching to avoid repeated file I/O.""" + func_id = id(func) + + if func_id in _endpoint_context_cache: + return _endpoint_context_cache[func_id] + + try: + ctx: EndpointContext = {} + + if (source_file := inspect.getsourcefile(func)) is not None: + ctx["file"] = source_file + if (line_number := inspect.getsourcelines(func)[1]) is not None: + ctx["line"] = line_number + if (func_name := getattr(func, "__name__", None)) is not None: + ctx["function"] = func_name + except Exception: + ctx = EndpointContext() + + _endpoint_context_cache[func_id] = ctx + return ctx + + async def serialize_response( *, field: Optional[ModelField] = None, @@ -223,6 +251,7 @@ async def serialize_response( exclude_defaults: bool = False, exclude_none: bool = False, is_coroutine: bool = True, + endpoint_ctx: Optional[EndpointContext] = None, ) -> Any: if field: errors = [] @@ -245,8 +274,11 @@ async def serialize_response( elif errors_: errors.append(errors_) if errors: + ctx = endpoint_ctx or EndpointContext() raise ResponseValidationError( - errors=_normalize_errors(errors), body=response_content + errors=_normalize_errors(errors), + body=response_content, + endpoint_ctx=ctx, ) if hasattr(field, "serialize"): @@ -318,6 +350,18 @@ def get_request_handler( "fastapi_middleware_astack not found in request scope" ) + # Extract endpoint context for error messages + endpoint_ctx = ( + _extract_endpoint_context(dependant.call) + if dependant.call + else EndpointContext() + ) + + if dependant.path: + # For mounted sub-apps, include the mount path prefix + mount_path = request.scope.get("root_path", "").rstrip("/") + endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}" + # Read body and auto-close files try: body: Any = None @@ -355,6 +399,7 @@ def get_request_handler( } ], body=e.doc, + endpoint_ctx=endpoint_ctx, ) raise validation_error from e except HTTPException: @@ -414,6 +459,7 @@ def get_request_handler( exclude_defaults=response_model_exclude_defaults, exclude_none=response_model_exclude_none, is_coroutine=is_coroutine, + endpoint_ctx=endpoint_ctx, ) response = actual_response_class(content, **response_args) if not is_body_allowed_for_status_code(response.status_code): @@ -421,7 +467,7 @@ def get_request_handler( response.headers.raw.extend(solved_result.response.headers.raw) if errors: validation_error = RequestValidationError( - _normalize_errors(errors), body=body + _normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx ) raise validation_error @@ -438,6 +484,15 @@ def get_websocket_app( embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: + endpoint_ctx = ( + _extract_endpoint_context(dependant.call) + if dependant.call + else EndpointContext() + ) + if dependant.path: + # For mounted sub-apps, include the mount path prefix + mount_path = websocket.scope.get("root_path", "").rstrip("/") + endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}" async_exit_stack = websocket.scope.get("fastapi_inner_astack") assert isinstance(async_exit_stack, AsyncExitStack), ( "fastapi_inner_astack not found in request scope" @@ -451,7 +506,8 @@ def get_websocket_app( ) if solved_result.errors: raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) + _normalize_errors(solved_result.errors), + endpoint_ctx=endpoint_ctx, ) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**solved_result.values) diff --git a/tests/test_validation_error_context.py b/tests/test_validation_error_context.py new file mode 100644 index 000000000..844b8a64f --- /dev/null +++ b/tests/test_validation_error_context.py @@ -0,0 +1,168 @@ +from fastapi import FastAPI, Request, WebSocket +from fastapi.exceptions import ( + RequestValidationError, + ResponseValidationError, + WebSocketRequestValidationError, +) +from fastapi.testclient import TestClient +from pydantic import BaseModel + + +class Item(BaseModel): + id: int + name: str + + +class ExceptionCapture: + def __init__(self): + self.exception = None + + def capture(self, exc): + self.exception = exc + return exc + + +app = FastAPI() +sub_app = FastAPI() +captured_exception = ExceptionCapture() + +app.mount(path="/sub", app=sub_app) + + +@app.exception_handler(RequestValidationError) +@sub_app.exception_handler(RequestValidationError) +async def request_validation_handler(request: Request, exc: RequestValidationError): + captured_exception.capture(exc) + raise exc + + +@app.exception_handler(ResponseValidationError) +@sub_app.exception_handler(ResponseValidationError) +async def response_validation_handler(_: Request, exc: ResponseValidationError): + captured_exception.capture(exc) + raise exc + + +@app.exception_handler(WebSocketRequestValidationError) +@sub_app.exception_handler(WebSocketRequestValidationError) +async def websocket_validation_handler( + websocket: WebSocket, exc: WebSocketRequestValidationError +): + captured_exception.capture(exc) + raise exc + + +@app.get("/users/{user_id}") +def get_user(user_id: int): + return {"user_id": user_id} # pragma: no cover + + +@app.get("/items/", response_model=Item) +def get_item(): + return {"name": "Widget"} + + +@sub_app.get("/items/", response_model=Item) +def get_sub_item(): + return {"name": "Widget"} # pragma: no cover + + +@app.websocket("/ws/{item_id}") +async def websocket_endpoint(websocket: WebSocket, item_id: int): + await websocket.accept() # pragma: no cover + await websocket.send_text(f"Item: {item_id}") # pragma: no cover + await websocket.close() # pragma: no cover + + +@sub_app.websocket("/ws/{item_id}") +async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int): + await websocket.accept() # pragma: no cover + await websocket.send_text(f"Item: {item_id}") # pragma: no cover + await websocket.close() # pragma: no cover + + +client = TestClient(app) + + +def test_request_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + client.get("/users/invalid") + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "get_user" in error_str + assert "/users/" in error_str + + +def test_response_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + client.get("/items/") + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "get_item" in error_str + assert "/items/" in error_str + + +def test_websocket_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + with client.websocket_connect("/ws/invalid"): + pass # pragma: no cover + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "websocket_endpoint" in error_str + assert "/ws/" in error_str + + +def test_subapp_request_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + client.get("/sub/items/") + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "get_sub_item" in error_str + assert "/sub/items/" in error_str + + +def test_subapp_websocket_validation_error_includes_endpoint_context(): + captured_exception.exception = None + try: + with client.websocket_connect("/sub/ws/invalid"): + pass # pragma: no cover + except Exception: + pass + + assert captured_exception.exception is not None + error_str = str(captured_exception.exception) + assert "subapp_websocket_endpoint" in error_str + assert "/sub/ws/" in error_str + + +def test_validation_error_with_only_path(): + errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}] + exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"}) + error_str = str(exc) + assert "Endpoint: GET /api/test" in error_str + assert 'File "' not in error_str + + +def test_validation_error_with_no_context(): + errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}] + exc = RequestValidationError(errors, endpoint_ctx={}) + error_str = str(exc) + assert "1 validation error:" in error_str + assert "Endpoint" not in error_str + assert 'File "' not in error_str -- 2.47.3