]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
collapse only one level of excg
authorThomas Grainger <tagrain@gmail.com>
Sun, 29 Dec 2024 14:08:08 +0000 (14:08 +0000)
committerThomas Grainger <tagrain@gmail.com>
Sun, 29 Dec 2024 14:10:09 +0000 (14:10 +0000)
starlette/_utils.py
tests/middleware/test_base.py

index 0c389dcb29a0b021e013c5a79d1da8e0eb9eb47f..e9325016310ee5a786ea39047bc8011c5471ac01 100644 (file)
@@ -13,12 +13,14 @@ if sys.version_info >= (3, 10):  # pragma: no cover
 else:  # pragma: no cover
     from typing_extensions import TypeGuard
 
-has_exceptiongroups = True
 if sys.version_info < (3, 11):  # pragma: no cover
     try:
         from exceptiongroup import BaseExceptionGroup  # type: ignore[unused-ignore,import-not-found]
     except ImportError:
-        has_exceptiongroups = False
+
+        class BaseExceptionGroup(BaseException):  # type: ignore[no-redef]
+            pass
+
 
 T = typing.TypeVar("T")
 AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
@@ -74,12 +76,23 @@ class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
 def collapse_excgroups() -> typing.Generator[None, None, None]:
     try:
         yield
-    except BaseException as exc:
-        if has_exceptiongroups:  # pragma: no cover
-            while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
-                exc = exc.exceptions[0]
-
-        raise exc
+    except BaseExceptionGroup as excs:
+        if len(excs.exceptions) != 1:
+            raise
+
+        exc = excs.exceptions[0]
+        context = exc.__context__
+        tb = exc.__traceback__
+        cause = exc.__cause__
+        sc = exc.__suppress_context__
+        try:
+            raise exc
+        finally:
+            exc.__traceback__ = tb
+            exc.__context__ = context
+            exc.__cause__ = cause
+            exc.__suppress_context__ = sc
+            del exc, cause, tb, context
 
 
 def get_route_path(scope: Scope) -> str:
index 7232cfd18de02a33056988ffe835a8afe774415d..c2cecf48a87fc7fd94bd2b2718493b6169081a0a 100644 (file)
@@ -8,6 +8,7 @@ from typing import Any
 import anyio
 import pytest
 from anyio.abc import TaskStatus
+from exceptiongroup import ExceptionGroup
 
 from starlette.applications import Starlette
 from starlette.background import BackgroundTask
@@ -41,6 +42,10 @@ def exc(request: Request) -> None:
     raise Exception("Exc")
 
 
+def eg(request: Request) -> None:
+    raise ExceptionGroup("my exception group", [ValueError("TEST")])
+
+
 def exc_stream(request: Request) -> StreamingResponse:
     return StreamingResponse(_generate_faulty_stream())
 
@@ -76,6 +81,7 @@ app = Starlette(
     routes=[
         Route("/", endpoint=homepage),
         Route("/exc", endpoint=exc),
+        Route("/eg", endpoint=eg),
         Route("/exc-stream", endpoint=exc_stream),
         Route("/no-response", endpoint=NoResponse),
         WebSocketRoute("/ws", endpoint=websocket_endpoint),
@@ -89,13 +95,16 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
     response = client.get("/")
     assert response.headers["Custom-Header"] == "Example"
 
-    with pytest.raises(Exception) as ctx:
+    with pytest.raises(Exception) as ctx1:
         response = client.get("/exc")
-    assert str(ctx.value) == "Exc"
+    assert str(ctx1.value) == "Exc"
 
-    with pytest.raises(Exception) as ctx:
+    with pytest.raises(Exception) as ctx2:
         response = client.get("/exc-stream")
-    assert str(ctx.value) == "Faulty Stream"
+    assert str(ctx2.value) == "Faulty Stream"
+
+    with pytest.raises(ExceptionGroup, match=r"my exception group \(1 sub-exception\)"):
+        client.get("/eg")
 
     with pytest.raises(RuntimeError):
         response = client.get("/no-response")