# Adapted with permission from the EdgeDB project;
# license: PSFL.
+import weakref
import sys
import gc
import asyncio
return [coro]
-class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
+def set_gc_state(enabled):
+ was_enabled = gc.isenabled()
+ if enabled:
+ gc.enable()
+ else:
+ gc.disable()
+ return was_enabled
+
+
+@contextlib.contextmanager
+def disable_gc():
+ was_enabled = set_gc_state(enabled=False)
+ try:
+ yield
+ finally:
+ set_gc_state(enabled=was_enabled)
+
+
+class BaseTestTaskGroup:
async def test_taskgroup_01(self):
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
- def test_coro_closed_when_tg_closed(self):
+ async def test_coro_closed_when_tg_closed(self):
async def run_coro_after_tg_closes():
async with taskgroups.TaskGroup() as tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)
- loop = asyncio.get_event_loop()
- loop.run_until_complete(run_coro_after_tg_closes())
+
+ await run_coro_after_tg_closes()
async def test_cancelling_level_preserved(self):
async def raise_after(t, e):
self.assertIsInstance(exc, _Done)
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
+
+ async def test_exception_refcycles_parent_task_wr(self):
+ """Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
+ tg = asyncio.TaskGroup()
+ exc = None
+
+ class _Done(Exception):
+ pass
+
+ async def coro_fn():
+ async with tg:
+ raise _Done
+
+ with disable_gc():
+ try:
+ async with asyncio.TaskGroup() as tg2:
+ task_wr = weakref.ref(tg2.create_task(coro_fn()))
+ except* _Done as excs:
+ exc = excs.exceptions[0].exceptions[0]
+
+ self.assertIsNone(task_wr())
+ self.assertIsInstance(exc, _Done)
+ self.assertListEqual(gc.get_referrers(exc), no_other_refs())
+
async def test_exception_refcycles_propagate_cancellation_error(self):
"""Test that TaskGroup deletes propagate_cancellation_error"""
tg = asyncio.TaskGroup()
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
+class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
+ loop_factory = asyncio.EventLoop
+
+class TestEagerTaskTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
+ @staticmethod
+ def loop_factory():
+ loop = asyncio.EventLoop()
+ loop.set_task_factory(asyncio.eager_task_factory)
+ return loop
+
+
if __name__ == "__main__":
unittest.main()