From: Thomas Grainger Date: Sun, 29 Dec 2024 14:08:08 +0000 (+0000) Subject: collapse only one level of excg X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a4687d7c5fce21a164925629838b5bccf4f3acdd;p=thirdparty%2Fstarlette.git collapse only one level of excg --- diff --git a/starlette/_utils.py b/starlette/_utils.py index 0c389dcb..e9325016 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -13,12 +13,14 @@ if sys.version_info >= (3, 10): # pragma: no cover else: # pragma: no cover from typing_extensions import TypeGuard -has_exceptiongroups = True if sys.version_info < (3, 11): # pragma: no cover try: from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] except ImportError: - has_exceptiongroups = False + + class BaseExceptionGroup(BaseException): # type: ignore[no-redef] + pass + T = typing.TypeVar("T") AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] @@ -74,12 +76,23 @@ class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]): def collapse_excgroups() -> typing.Generator[None, None, None]: try: yield - except BaseException as exc: - if has_exceptiongroups: # pragma: no cover - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: - exc = exc.exceptions[0] - - raise exc + except BaseExceptionGroup as excs: + if len(excs.exceptions) != 1: + raise + + exc = excs.exceptions[0] + context = exc.__context__ + tb = exc.__traceback__ + cause = exc.__cause__ + sc = exc.__suppress_context__ + try: + raise exc + finally: + exc.__traceback__ = tb + exc.__context__ = context + exc.__cause__ = cause + exc.__suppress_context__ = sc + del exc, cause, tb, context def get_route_path(scope: Scope) -> str: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 7232cfd1..c2cecf48 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -8,6 +8,7 @@ from typing import Any import anyio import pytest from anyio.abc import TaskStatus +from exceptiongroup import ExceptionGroup from starlette.applications import Starlette from starlette.background import BackgroundTask @@ -41,6 +42,10 @@ def exc(request: Request) -> None: raise Exception("Exc") +def eg(request: Request) -> None: + raise ExceptionGroup("my exception group", [ValueError("TEST")]) + + def exc_stream(request: Request) -> StreamingResponse: return StreamingResponse(_generate_faulty_stream()) @@ -76,6 +81,7 @@ app = Starlette( routes=[ Route("/", endpoint=homepage), Route("/exc", endpoint=exc), + Route("/eg", endpoint=eg), Route("/exc-stream", endpoint=exc_stream), Route("/no-response", endpoint=NoResponse), WebSocketRoute("/ws", endpoint=websocket_endpoint), @@ -89,13 +95,16 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None: response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception) as ctx: + with pytest.raises(Exception) as ctx1: response = client.get("/exc") - assert str(ctx.value) == "Exc" + assert str(ctx1.value) == "Exc" - with pytest.raises(Exception) as ctx: + with pytest.raises(Exception) as ctx2: response = client.get("/exc-stream") - assert str(ctx.value) == "Faulty Stream" + assert str(ctx2.value) == "Faulty Stream" + + with pytest.raises(ExceptionGroup, match=r"my exception group \(1 sub-exception\)"): + client.get("/eg") with pytest.raises(RuntimeError): response = client.get("/no-response")