return response
- except Exception as exc:
+ except BaseException as exc:
response.close()
raise exc
request = next_request
history.append(response)
- except Exception as exc:
+ except BaseException as exc:
response.close()
raise exc
finally:
response.next_request = request
return response
- except Exception as exc:
+ except BaseException as exc:
response.close()
raise exc
return response
- except Exception as exc: # pragma: no cover
+ except BaseException as exc: # pragma: no cover
await response.aclose()
raise exc
request = next_request
history.append(response)
- except Exception as exc:
+ except BaseException as exc:
await response.aclose()
raise exc
finally:
response.next_request = request
return response
- except Exception as exc:
+ except BaseException as exc:
await response.aclose()
raise exc
assert response.text == "Hello, world!"
+@pytest.mark.usefixtures("async_environment")
+async def test_cancellation_during_stream():
+ """
+ If any BaseException is raised during streaming the response, then the
+ stream should be closed.
+
+ This includes:
+
+ * `asyncio.CancelledError` (A subclass of BaseException from Python 3.8 onwards.)
+ * `trio.Cancelled`
+ * `KeyboardInterrupt`
+ * `SystemExit`
+
+ See https://github.com/encode/httpx/issues/2139
+ """
+ stream_was_closed = False
+
+ def response_with_cancel_during_stream(request):
+ class CancelledStream(httpx.AsyncByteStream):
+ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+ yield b"Hello"
+ raise KeyboardInterrupt()
+ yield b", world" # pragma: nocover
+
+ async def aclose(self) -> None:
+ nonlocal stream_was_closed
+ stream_was_closed = True
+
+ return httpx.Response(
+ 200, headers={"Content-Length": "12"}, stream=CancelledStream()
+ )
+
+ transport = httpx.MockTransport(response_with_cancel_during_stream)
+
+ async with httpx.AsyncClient(transport=transport) as client:
+ with pytest.raises(KeyboardInterrupt):
+ await client.get("https://www.example.com")
+ assert stream_was_closed
+
+
@pytest.mark.usefixtures("async_environment")
async def test_server_extensions(server):
url = server.url