From: Ramandeep Singh Date: Sun, 13 Apr 2025 13:38:18 +0000 (-0700) Subject: Prevents reraising of exception from BaseHttpMiddleware (#2911) X-Git-Tag: 0.46.2~1 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=48f7cd7f28b9e4b9a5a6b19159139205907aec41;p=thirdparty%2Fstarlette.git Prevents reraising of exception from BaseHttpMiddleware (#2911) --- diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 2a59337e..b49ab611 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -104,6 +104,7 @@ class BaseHTTPMiddleware: wrapped_receive = request.wrapped_receive response_sent = anyio.Event() app_exc: Exception | None = None + exception_already_raised = False async def call_next(request: Request) -> Response: async def receive_or_disconnect() -> Message: @@ -150,6 +151,8 @@ class BaseHTTPMiddleware: message = await recv_stream.receive() except anyio.EndOfStream: if app_exc is not None: + nonlocal exception_already_raised + exception_already_raised = True raise app_exc raise RuntimeError("No response returned.") @@ -176,8 +179,7 @@ class BaseHTTPMiddleware: await response(scope, wrapped_receive, send) response_sent.set() recv_stream.close() - - if app_exc is not None: + if app_exc is not None and not exception_already_raised: raise app_exc async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index e4e82077..427ec44a 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -320,6 +320,27 @@ def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFa client.get("/") +def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None: + async def error_endpoint(_: Request) -> None: + raise ValueError("TEST") + + async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response: + try: + return await call_next(request) + except ValueError as exc: + return PlainTextResponse(content=str(exc), status_code=400) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)], + routes=[Route("/", error_endpoint)], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 400 + assert response.text == "TEST" + + @pytest.mark.anyio async def test_do_not_block_on_background_tasks() -> None: response_complete = anyio.Event()