async def __aexit__(self, et, exc, tb):
self._exiting = True
- propagate_cancellation_error = None
if (exc is not None and
self._is_base_error(exc) and
self._base_error is None):
self._base_error = exc
- if et is not None:
- if et is exceptions.CancelledError:
- if self._parent_cancel_requested and not self._parent_task.uncancel():
- # Do nothing, i.e. swallow the error.
- pass
- else:
- propagate_cancellation_error = exc
+ propagate_cancellation_error = \
+ exc if et is exceptions.CancelledError else None
+ if self._parent_cancel_requested:
+ # If this flag is set we *must* call uncancel().
+ if self._parent_task.uncancel() == 0:
+ # If there are no pending cancellations left,
+ # don't propagate CancelledError.
+ propagate_cancellation_error = None
+ if et is not None:
if not self._aborting:
# Our parent task is being cancelled:
#
import asyncio
import contextvars
-
+import contextlib
from asyncio import taskgroups
import unittest
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
+ async def test_taskgroup_context_manager_exit_raises(self):
+ # See https://github.com/python/cpython/issues/95289
+ class CustomException(Exception):
+ pass
+
+ async def raise_exc():
+ raise CustomException
+
+ @contextlib.asynccontextmanager
+ async def database():
+ try:
+ yield
+ finally:
+ raise CustomException
+
+ async def main():
+ task = asyncio.current_task()
+ try:
+ async with taskgroups.TaskGroup() as tg:
+ async with database():
+ tg.create_task(raise_exc())
+ await asyncio.sleep(1)
+ except* CustomException as err:
+ self.assertEqual(task.cancelling(), 0)
+ self.assertEqual(len(err.exceptions), 2)
+
+ else:
+ self.fail('CustomException not raised')
+
+ await asyncio.create_task(main())
+
if __name__ == "__main__":
unittest.main()