-from typing import Any, Dict, Optional, Sequence, Type, Union
+from typing import Any, Dict, Optional, Sequence, Type, TypedDict, Union
from annotated_doc import Doc
from pydantic import BaseModel, create_model
from typing_extensions import Annotated
+class EndpointContext(TypedDict, total=False):
+ function: str
+ path: str
+ file: str
+ line: int
+
+
class HTTPException(StarletteHTTPException):
"""
An HTTP exception you can raise in your own code to show errors to the client.
class ValidationException(Exception):
- def __init__(self, errors: Sequence[Any]) -> None:
+ def __init__(
+ self,
+ errors: Sequence[Any],
+ *,
+ endpoint_ctx: Optional[EndpointContext] = None,
+ ) -> None:
self._errors = errors
+ self.endpoint_ctx = endpoint_ctx
+
+ ctx = endpoint_ctx or {}
+ self.endpoint_function = ctx.get("function")
+ self.endpoint_path = ctx.get("path")
+ self.endpoint_file = ctx.get("file")
+ self.endpoint_line = ctx.get("line")
def errors(self) -> Sequence[Any]:
return self._errors
+ def _format_endpoint_context(self) -> str:
+ if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
+ if self.endpoint_path:
+ return f"\n Endpoint: {self.endpoint_path}"
+ return ""
+
+ context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
+ if self.endpoint_path:
+ context += f"\n {self.endpoint_path}"
+ return context
+
+ def __str__(self) -> str:
+ message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
+ for err in self._errors:
+ message += f" {err}\n"
+ message += self._format_endpoint_context()
+ return message.rstrip()
+
class RequestValidationError(ValidationException):
- def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
- super().__init__(errors)
+ def __init__(
+ self,
+ errors: Sequence[Any],
+ *,
+ body: Any = None,
+ endpoint_ctx: Optional[EndpointContext] = None,
+ ) -> None:
+ super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
class WebSocketRequestValidationError(ValidationException):
- pass
+ def __init__(
+ self,
+ errors: Sequence[Any],
+ *,
+ endpoint_ctx: Optional[EndpointContext] = None,
+ ) -> None:
+ super().__init__(errors, endpoint_ctx=endpoint_ctx)
class ResponseValidationError(ValidationException):
- def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
- super().__init__(errors)
+ def __init__(
+ self,
+ errors: Sequence[Any],
+ *,
+ body: Any = None,
+ endpoint_ctx: Optional[EndpointContext] = None,
+ ) -> None:
+ super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
-
- def __str__(self) -> str:
- message = f"{len(self._errors)} validation errors:\n"
- for err in self._errors:
- message += f" {err}\n"
- return message
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import (
+ EndpointContext,
FastAPIError,
RequestValidationError,
ResponseValidationError,
return merged_lifespan # type: ignore[return-value]
+# Cache for endpoint context to avoid re-extracting on every request
+_endpoint_context_cache: Dict[int, EndpointContext] = {}
+
+
+def _extract_endpoint_context(func: Any) -> EndpointContext:
+ """Extract endpoint context with caching to avoid repeated file I/O."""
+ func_id = id(func)
+
+ if func_id in _endpoint_context_cache:
+ return _endpoint_context_cache[func_id]
+
+ try:
+ ctx: EndpointContext = {}
+
+ if (source_file := inspect.getsourcefile(func)) is not None:
+ ctx["file"] = source_file
+ if (line_number := inspect.getsourcelines(func)[1]) is not None:
+ ctx["line"] = line_number
+ if (func_name := getattr(func, "__name__", None)) is not None:
+ ctx["function"] = func_name
+ except Exception:
+ ctx = EndpointContext()
+
+ _endpoint_context_cache[func_id] = ctx
+ return ctx
+
+
async def serialize_response(
*,
field: Optional[ModelField] = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
is_coroutine: bool = True,
+ endpoint_ctx: Optional[EndpointContext] = None,
) -> Any:
if field:
errors = []
elif errors_:
errors.append(errors_)
if errors:
+ ctx = endpoint_ctx or EndpointContext()
raise ResponseValidationError(
- errors=_normalize_errors(errors), body=response_content
+ errors=_normalize_errors(errors),
+ body=response_content,
+ endpoint_ctx=ctx,
)
if hasattr(field, "serialize"):
"fastapi_middleware_astack not found in request scope"
)
+ # Extract endpoint context for error messages
+ endpoint_ctx = (
+ _extract_endpoint_context(dependant.call)
+ if dependant.call
+ else EndpointContext()
+ )
+
+ if dependant.path:
+ # For mounted sub-apps, include the mount path prefix
+ mount_path = request.scope.get("root_path", "").rstrip("/")
+ endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}"
+
# Read body and auto-close files
try:
body: Any = None
}
],
body=e.doc,
+ endpoint_ctx=endpoint_ctx,
)
raise validation_error from e
except HTTPException:
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
+ endpoint_ctx=endpoint_ctx,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
- _normalize_errors(errors), body=body
+ _normalize_errors(errors), body=body, endpoint_ctx=endpoint_ctx
)
raise validation_error
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
+ endpoint_ctx = (
+ _extract_endpoint_context(dependant.call)
+ if dependant.call
+ else EndpointContext()
+ )
+ if dependant.path:
+ # For mounted sub-apps, include the mount path prefix
+ mount_path = websocket.scope.get("root_path", "").rstrip("/")
+ endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}"
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
if solved_result.errors:
raise WebSocketRequestValidationError(
- _normalize_errors(solved_result.errors)
+ _normalize_errors(solved_result.errors),
+ endpoint_ctx=endpoint_ctx,
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values)
--- /dev/null
+from fastapi import FastAPI, Request, WebSocket
+from fastapi.exceptions import (
+ RequestValidationError,
+ ResponseValidationError,
+ WebSocketRequestValidationError,
+)
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+
+
+class Item(BaseModel):
+ id: int
+ name: str
+
+
+class ExceptionCapture:
+ def __init__(self):
+ self.exception = None
+
+ def capture(self, exc):
+ self.exception = exc
+ return exc
+
+
+app = FastAPI()
+sub_app = FastAPI()
+captured_exception = ExceptionCapture()
+
+app.mount(path="/sub", app=sub_app)
+
+
+@app.exception_handler(RequestValidationError)
+@sub_app.exception_handler(RequestValidationError)
+async def request_validation_handler(request: Request, exc: RequestValidationError):
+ captured_exception.capture(exc)
+ raise exc
+
+
+@app.exception_handler(ResponseValidationError)
+@sub_app.exception_handler(ResponseValidationError)
+async def response_validation_handler(_: Request, exc: ResponseValidationError):
+ captured_exception.capture(exc)
+ raise exc
+
+
+@app.exception_handler(WebSocketRequestValidationError)
+@sub_app.exception_handler(WebSocketRequestValidationError)
+async def websocket_validation_handler(
+ websocket: WebSocket, exc: WebSocketRequestValidationError
+):
+ captured_exception.capture(exc)
+ raise exc
+
+
+@app.get("/users/{user_id}")
+def get_user(user_id: int):
+ return {"user_id": user_id} # pragma: no cover
+
+
+@app.get("/items/", response_model=Item)
+def get_item():
+ return {"name": "Widget"}
+
+
+@sub_app.get("/items/", response_model=Item)
+def get_sub_item():
+ return {"name": "Widget"} # pragma: no cover
+
+
+@app.websocket("/ws/{item_id}")
+async def websocket_endpoint(websocket: WebSocket, item_id: int):
+ await websocket.accept() # pragma: no cover
+ await websocket.send_text(f"Item: {item_id}") # pragma: no cover
+ await websocket.close() # pragma: no cover
+
+
+@sub_app.websocket("/ws/{item_id}")
+async def subapp_websocket_endpoint(websocket: WebSocket, item_id: int):
+ await websocket.accept() # pragma: no cover
+ await websocket.send_text(f"Item: {item_id}") # pragma: no cover
+ await websocket.close() # pragma: no cover
+
+
+client = TestClient(app)
+
+
+def test_request_validation_error_includes_endpoint_context():
+ captured_exception.exception = None
+ try:
+ client.get("/users/invalid")
+ except Exception:
+ pass
+
+ assert captured_exception.exception is not None
+ error_str = str(captured_exception.exception)
+ assert "get_user" in error_str
+ assert "/users/" in error_str
+
+
+def test_response_validation_error_includes_endpoint_context():
+ captured_exception.exception = None
+ try:
+ client.get("/items/")
+ except Exception:
+ pass
+
+ assert captured_exception.exception is not None
+ error_str = str(captured_exception.exception)
+ assert "get_item" in error_str
+ assert "/items/" in error_str
+
+
+def test_websocket_validation_error_includes_endpoint_context():
+ captured_exception.exception = None
+ try:
+ with client.websocket_connect("/ws/invalid"):
+ pass # pragma: no cover
+ except Exception:
+ pass
+
+ assert captured_exception.exception is not None
+ error_str = str(captured_exception.exception)
+ assert "websocket_endpoint" in error_str
+ assert "/ws/" in error_str
+
+
+def test_subapp_request_validation_error_includes_endpoint_context():
+ captured_exception.exception = None
+ try:
+ client.get("/sub/items/")
+ except Exception:
+ pass
+
+ assert captured_exception.exception is not None
+ error_str = str(captured_exception.exception)
+ assert "get_sub_item" in error_str
+ assert "/sub/items/" in error_str
+
+
+def test_subapp_websocket_validation_error_includes_endpoint_context():
+ captured_exception.exception = None
+ try:
+ with client.websocket_connect("/sub/ws/invalid"):
+ pass # pragma: no cover
+ except Exception:
+ pass
+
+ assert captured_exception.exception is not None
+ error_str = str(captured_exception.exception)
+ assert "subapp_websocket_endpoint" in error_str
+ assert "/sub/ws/" in error_str
+
+
+def test_validation_error_with_only_path():
+ errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
+ exc = RequestValidationError(errors, endpoint_ctx={"path": "GET /api/test"})
+ error_str = str(exc)
+ assert "Endpoint: GET /api/test" in error_str
+ assert 'File "' not in error_str
+
+
+def test_validation_error_with_no_context():
+ errors = [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]
+ exc = RequestValidationError(errors, endpoint_ctx={})
+ error_str = str(exc)
+ assert "1 validation error:" in error_str
+ assert "Endpoint" not in error_str
+ assert 'File "' not in error_str