else:
task = self._loop.create_task(coro, context=context)
tasks._set_task_name(task, name)
- # optimization: Immediately call the done callback if the task is
+
+ # Always schedule the done callback even if the task is
# already done (e.g. if the coro was able to complete eagerly),
- # and skip scheduling a done callback
- if task.done():
- self._on_task_done(task)
- else:
- self._tasks.add(task)
- task.add_done_callback(self._on_task_done)
- return task
+ # otherwise if the task completes with an exception then it will cancel
+ # the current task too early. gh-128550, gh-128588
+
+ self._tasks.add(task)
+ task.add_done_callback(self._on_task_done)
+ try:
+ return task
+ finally:
+ # gh-128552: prevent a refcycle of
+ # task.exception().__traceback__->TaskGroup.create_task->task
+ del task
# Since Python 3.8 Tasks propagate all exceptions correctly,
# except for KeyboardInterrupt and SystemExit which are
# Adapted with permission from the EdgeDB project;
# license: PSFL.
+import weakref
+import sys
import gc
import asyncio
import contextvars
return {type(exc) for exc in eg.exceptions}
-class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
+def set_gc_state(enabled):
+ was_enabled = gc.isenabled()
+ if enabled:
+ gc.enable()
+ else:
+ gc.disable()
+ return was_enabled
+
+
+@contextlib.contextmanager
+def disable_gc():
+ was_enabled = set_gc_state(enabled=False)
+ try:
+ yield
+ finally:
+ set_gc_state(enabled=was_enabled)
+
+
+class BaseTestTaskGroup:
async def test_taskgroup_01(self):
self.assertIsInstance(exc, _Done)
self.assertListEqual(gc.get_referrers(exc), [])
+
+ async def test_exception_refcycles_parent_task_wr(self):
+ """Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
+ tg = asyncio.TaskGroup()
+ exc = None
+
+ class _Done(Exception):
+ pass
+
+ async def coro_fn():
+ async with tg:
+ raise _Done
+
+ with disable_gc():
+ try:
+ async with asyncio.TaskGroup() as tg2:
+ task_wr = weakref.ref(tg2.create_task(coro_fn()))
+ except* _Done as excs:
+ exc = excs.exceptions[0].exceptions[0]
+
+ self.assertIsNone(task_wr())
+ self.assertIsInstance(exc, _Done)
+ self.assertListEqual(gc.get_referrers(exc), [])
+
async def test_exception_refcycles_propagate_cancellation_error(self):
"""Test that TaskGroup deletes propagate_cancellation_error"""
tg = asyncio.TaskGroup()
self.assertIsNotNone(exc)
self.assertListEqual(gc.get_referrers(exc), [])
+ async def test_cancels_task_if_created_during_creation(self):
+ # regression test for gh-128550
+ ran = False
+ class MyError(Exception):
+ pass
+
+ exc = None
+ try:
+ async with asyncio.TaskGroup() as tg:
+ async def third_task():
+ raise MyError("third task failed")
+
+ async def second_task():
+ nonlocal ran
+ tg.create_task(third_task())
+ with self.assertRaises(asyncio.CancelledError):
+ await asyncio.sleep(0) # eager tasks cancel here
+ await asyncio.sleep(0) # lazy tasks cancel here
+ ran = True
+
+ tg.create_task(second_task())
+ except* MyError as excs:
+ exc = excs.exceptions[0]
+
+ self.assertTrue(ran)
+ self.assertIsInstance(exc, MyError)
+
+ async def test_cancellation_does_not_leak_out_of_tg(self):
+ class MyError(Exception):
+ pass
+
+ async def throw_error():
+ raise MyError
+
+ try:
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(throw_error())
+ except* MyError:
+ pass
+ else:
+ self.fail("should have raised one MyError in group")
+
+ # if this test fails this current task will be cancelled
+ # outside the task group and inside unittest internals
+ # we yield to the event loop with sleep(0) so that
+ # cancellation happens here and error is more understandable
+ await asyncio.sleep(0)
+
+
+if sys.platform == "win32":
+ EventLoop = asyncio.ProactorEventLoop
+else:
+ EventLoop = asyncio.SelectorEventLoop
+
+
+class IsolatedAsyncioTestCase(unittest.IsolatedAsyncioTestCase):
+ loop_factory = None
+
+ def _setupAsyncioRunner(self):
+ assert self._asyncioRunner is None, 'asyncio runner is already initialized'
+ runner = asyncio.Runner(debug=True, loop_factory=self.loop_factory)
+ self._asyncioRunner = runner
+
+
+class TestTaskGroup(BaseTestTaskGroup, IsolatedAsyncioTestCase):
+ loop_factory = EventLoop
+
+
+class TestEagerTaskTaskGroup(BaseTestTaskGroup, IsolatedAsyncioTestCase):
+ @staticmethod
+ def loop_factory():
+ loop = EventLoop()
+ loop.set_task_factory(asyncio.eager_task_factory)
+ return loop
+
if __name__ == "__main__":
unittest.main()