]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Do not pollute exception context in Middleware (#2976)
authorAdam Sikora <42934708+adam-sikora@users.noreply.github.com>
Tue, 28 Oct 2025 08:03:33 +0000 (09:03 +0100)
committerGitHub <noreply@github.com>
Tue, 28 Oct 2025 08:03:33 +0000 (08:03 +0000)
starlette/middleware/base.py
tests/middleware/test_base.py

index 577918eb91b18183018b5e48078257e4ae695502..dc353a26dd3873aefc9435b92acfb0e1145ae233 100644 (file)
@@ -156,7 +156,16 @@ class BaseHTTPMiddleware:
                 if app_exc is not None:
                     nonlocal exception_already_raised
                     exception_already_raised = True
-                    raise app_exc
+                    # Prevent `anyio.EndOfStream` from polluting app exception context.
+                    # If both cause and context are None then the context is suppressed
+                    # and `anyio.EndOfStream` is not present in the exception traceback.
+                    # If exception cause is not None then it is propagated with
+                    # reraising here.
+                    # If exception has no cause but has context set then the context is
+                    # propagated as a cause with the reraise. This is necessary in order
+                    # to prevent `anyio.EndOfStream` from polluting the exception
+                    # context.
+                    raise app_exc from app_exc.__cause__ or app_exc.__context__
                 raise RuntimeError("No response returned.")
 
             assert message["type"] == "http.response.start"
index d4548e66b679f0e22528972a81928225ea28718b..1b0b9476039a7b5f4fba723a6ddc9c7bcbb16642 100644 (file)
@@ -1243,3 +1243,62 @@ async def test_asgi_pathsend_events(tmpdir: Path) -> None:
     assert len(events) == 2
     assert events[0]["type"] == "http.response.start"
     assert events[1]["type"] == "http.response.pathsend"
+
+
+def test_error_context_propagation(test_client_factory: TestClientFactory) -> None:
+    class PassthroughMiddleware(BaseHTTPMiddleware):
+        async def dispatch(
+            self,
+            request: Request,
+            call_next: RequestResponseEndpoint,
+        ) -> Response:
+            return await call_next(request)
+
+    def exception_without_context(request: Request) -> None:
+        raise Exception("Exception")
+
+    def exception_with_context(request: Request) -> None:
+        try:
+            raise Exception("Inner exception")
+        except Exception:
+            raise Exception("Outer exception")
+
+    def exception_with_cause(request: Request) -> None:
+        try:
+            raise Exception("Inner exception")
+        except Exception as e:
+            raise Exception("Outer exception") from e
+
+    app = Starlette(
+        routes=[
+            Route("/exception-without-context", endpoint=exception_without_context),
+            Route("/exception-with-context", endpoint=exception_with_context),
+            Route("/exception-with-cause", endpoint=exception_with_cause),
+        ],
+        middleware=[Middleware(PassthroughMiddleware)],
+    )
+    client = test_client_factory(app)
+
+    # For exceptions without context the context is filled with the `anyio.EndOfStream`
+    # but it is suppressed therefore not propagated to traceback.
+    with pytest.raises(Exception) as ctx:
+        client.get("/exception-without-context")
+    assert str(ctx.value) == "Exception"
+    assert ctx.value.__cause__ is None
+    assert ctx.value.__context__ is not None
+    assert ctx.value.__suppress_context__ is True
+
+    # For exceptions with context the context is propagated as a cause to avoid
+    # `anyio.EndOfStream` error from overwriting it.
+    with pytest.raises(Exception) as ctx:
+        client.get("/exception-with-context")
+    assert str(ctx.value) == "Outer exception"
+    assert ctx.value.__cause__ is not None
+    assert str(ctx.value.__cause__) == "Inner exception"
+
+    # For exceptions with cause check that it gets correctly propagated.
+    with pytest.raises(Exception) as ctx:
+        client.get("/exception-with-cause")
+    assert str(ctx.value) == "Outer exception"
+    assert ctx.value.__cause__ is not None
+    assert str(ctx.value.__cause__) == "Inner exception"