]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Prevents reraising of exception from BaseHttpMiddleware (#2911)
authorRamandeep Singh <ramannanda9@gmail.com>
Sun, 13 Apr 2025 13:38:18 +0000 (06:38 -0700)
committerGitHub <noreply@github.com>
Sun, 13 Apr 2025 13:38:18 +0000 (15:38 +0200)
starlette/middleware/base.py
tests/middleware/test_base.py

index 2a59337e524c7eee2b4e92e2499fb3d2d7979bef..b49ab611f4d25e217e11ee7a53013082b182afce 100644 (file)
@@ -104,6 +104,7 @@ class BaseHTTPMiddleware:
         wrapped_receive = request.wrapped_receive
         response_sent = anyio.Event()
         app_exc: Exception | None = None
+        exception_already_raised = False
 
         async def call_next(request: Request) -> Response:
             async def receive_or_disconnect() -> Message:
@@ -150,6 +151,8 @@ class BaseHTTPMiddleware:
                     message = await recv_stream.receive()
             except anyio.EndOfStream:
                 if app_exc is not None:
+                    nonlocal exception_already_raised
+                    exception_already_raised = True
                     raise app_exc
                 raise RuntimeError("No response returned.")
 
@@ -176,8 +179,7 @@ class BaseHTTPMiddleware:
                 await response(scope, wrapped_receive, send)
                 response_sent.set()
                 recv_stream.close()
-
-        if app_exc is not None:
+        if app_exc is not None and not exception_already_raised:
             raise app_exc
 
     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
index e4e82077f19659f73efb5c104ac4b0d345c7c148..427ec44ac1f66ae5fd31dc40698febac60b9a96c 100644 (file)
@@ -320,6 +320,27 @@ def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFa
         client.get("/")
 
 
+def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None:
+    async def error_endpoint(_: Request) -> None:
+        raise ValueError("TEST")
+
+    async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response:
+        try:
+            return await call_next(request)
+        except ValueError as exc:
+            return PlainTextResponse(content=str(exc), status_code=400)
+
+    app = Starlette(
+        middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)],
+        routes=[Route("/", error_endpoint)],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.status_code == 400
+    assert response.text == "TEST"
+
+
 @pytest.mark.anyio
 async def test_do_not_block_on_background_tasks() -> None:
     response_complete = anyio.Event()