wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None
+ exception_already_raised = False
async def call_next(request: Request) -> Response:
async def receive_or_disconnect() -> Message:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
+ nonlocal exception_already_raised
+ exception_already_raised = True
raise app_exc
raise RuntimeError("No response returned.")
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()
-
- if app_exc is not None:
+ if app_exc is not None and not exception_already_raised:
raise app_exc
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
client.get("/")
+def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None:
+ async def error_endpoint(_: Request) -> None:
+ raise ValueError("TEST")
+
+ async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response:
+ try:
+ return await call_next(request)
+ except ValueError as exc:
+ return PlainTextResponse(content=str(exc), status_code=400)
+
+ app = Starlette(
+ middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)],
+ routes=[Route("/", error_endpoint)],
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.status_code == 400
+ assert response.text == "TEST"
+
+
@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
response_complete = anyio.Event()