]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] gh-124958: fix asyncio.TaskGroup and _PyFuture refcycles (#124959) (#125466)
authorThomas Grainger <tagrain@gmail.com>
Thu, 17 Oct 2024 04:45:59 +0000 (05:45 +0100)
committerGitHub <noreply@github.com>
Thu, 17 Oct 2024 04:45:59 +0000 (21:45 -0700)
gh-124958: fix asyncio.TaskGroup and _PyFuture refcycles (#124959)

Lib/asyncio/futures.py
Lib/asyncio/taskgroups.py
Lib/test/test_asyncio/test_futures.py
Lib/test/test_asyncio/test_taskgroups.py
Misc/NEWS.d/next/Library/2024-10-04-08-46-00.gh-issue-124958.rea9-x.rst [new file with mode: 0644]

index fd486f02c67c8e066273d5c30361976bac2b2e0d..0c530bbdbcf2d8e4d6836447c7c2553a900e8a21 100644 (file)
@@ -194,8 +194,7 @@ class Future:
         the future is done and has an exception set, this exception is raised.
         """
         if self._state == _CANCELLED:
-            exc = self._make_cancelled_error()
-            raise exc
+            raise self._make_cancelled_error()
         if self._state != _FINISHED:
             raise exceptions.InvalidStateError('Result is not ready.')
         self.__log_traceback = False
@@ -212,8 +211,7 @@ class Future:
         InvalidStateError.
         """
         if self._state == _CANCELLED:
-            exc = self._make_cancelled_error()
-            raise exc
+            raise self._make_cancelled_error()
         if self._state != _FINISHED:
             raise exceptions.InvalidStateError('Exception is not set.')
         self.__log_traceback = False
index d264e51f1fd4e67d82cd29199fa1c30faf231a5d..aada3ffa8e0f292ad76da4105ebd8db01898ba31 100644 (file)
@@ -66,6 +66,20 @@ class TaskGroup:
         return self
 
     async def __aexit__(self, et, exc, tb):
+        tb = None
+        try:
+            return await self._aexit(et, exc)
+        finally:
+            # Exceptions are heavy objects that can have object
+            # cycles (bad for GC); let's not keep a reference to
+            # a bunch of them. It would be nicer to use a try/finally
+            # in __aexit__ directly but that introduced some diff noise
+            self._parent_task = None
+            self._errors = None
+            self._base_error = None
+            exc = None
+
+    async def _aexit(self, et, exc):
         self._exiting = True
 
         if (exc is not None and
@@ -126,25 +140,34 @@ class TaskGroup:
         assert not self._tasks
 
         if self._base_error is not None:
-            raise self._base_error
+            try:
+                raise self._base_error
+            finally:
+                exc = None
 
         # Propagate CancelledError if there is one, except if there
         # are other errors -- those have priority.
-        if propagate_cancellation_error and not self._errors:
-            raise propagate_cancellation_error
+        try:
+            if propagate_cancellation_error and not self._errors:
+                try:
+                    raise propagate_cancellation_error
+                finally:
+                    exc = None
+        finally:
+            propagate_cancellation_error = None
 
         if et is not None and et is not exceptions.CancelledError:
             self._errors.append(exc)
 
         if self._errors:
-            # Exceptions are heavy objects that can have object
-            # cycles (bad for GC); let's not keep a reference to
-            # a bunch of them.
             try:
-                me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
-                raise me from None
+                raise BaseExceptionGroup(
+                    'unhandled errors in a TaskGroup',
+                    self._errors,
+                ) from None
             finally:
-                self._errors = None
+                exc = None
+
 
     def create_task(self, coro, *, name=None, context=None):
         """Create a new task in this group and return it.
index 47daa0e9f410a871e3207207f952352acdb38b54..050d33f4fab3ed7e430e27388a759e18afd77045 100644 (file)
@@ -640,6 +640,28 @@ class BaseFutureTests:
             fut = self._new_future(loop=self.loop)
             fut.set_result(Evil())
 
+    def test_future_cancelled_result_refcycles(self):
+        f = self._new_future(loop=self.loop)
+        f.cancel()
+        exc = None
+        try:
+            f.result()
+        except asyncio.CancelledError as e:
+            exc = e
+        self.assertIsNotNone(exc)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
+    def test_future_cancelled_exception_refcycles(self):
+        f = self._new_future(loop=self.loop)
+        f.cancel()
+        exc = None
+        try:
+            f.exception()
+        except asyncio.CancelledError as e:
+            exc = e
+        self.assertIsNotNone(exc)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
 
 @unittest.skipUnless(hasattr(futures, '_CFuture'),
                      'requires the C _asyncio module')
index 7a18362b54e4695a463c84548a682d30a15383f2..236bfaaccf88faeaa0dc2e98a685ed7a5b69bc97 100644 (file)
@@ -1,7 +1,7 @@
 # Adapted with permission from the EdgeDB project;
 # license: PSFL.
 
-
+import gc
 import asyncio
 import contextvars
 import contextlib
@@ -10,7 +10,6 @@ import unittest
 
 from test.test_asyncio.utils import await_without_task
 
-
 # To prevent a warning "test altered the execution environment"
 def tearDownModule():
     asyncio.set_event_loop_policy(None)
@@ -824,6 +823,95 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
         # We still have to await coro to avoid a warning
         await coro
 
+    async def test_exception_refcycles_direct(self):
+        """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
+        tg = asyncio.TaskGroup()
+        exc = None
+
+        class _Done(Exception):
+            pass
+
+        try:
+            async with tg:
+                raise _Done
+        except ExceptionGroup as e:
+            exc = e
+
+        self.assertIsNotNone(exc)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
+
+    async def test_exception_refcycles_errors(self):
+        """Test that TaskGroup deletes self._errors, and __aexit__ args"""
+        tg = asyncio.TaskGroup()
+        exc = None
+
+        class _Done(Exception):
+            pass
+
+        try:
+            async with tg:
+                raise _Done
+        except* _Done as excs:
+            exc = excs.exceptions[0]
+
+        self.assertIsInstance(exc, _Done)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
+
+    async def test_exception_refcycles_parent_task(self):
+        """Test that TaskGroup deletes self._parent_task"""
+        tg = asyncio.TaskGroup()
+        exc = None
+
+        class _Done(Exception):
+            pass
+
+        async def coro_fn():
+            async with tg:
+                raise _Done
+
+        try:
+            async with asyncio.TaskGroup() as tg2:
+                tg2.create_task(coro_fn())
+        except* _Done as excs:
+            exc = excs.exceptions[0].exceptions[0]
+
+        self.assertIsInstance(exc, _Done)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
+    async def test_exception_refcycles_propagate_cancellation_error(self):
+        """Test that TaskGroup deletes propagate_cancellation_error"""
+        tg = asyncio.TaskGroup()
+        exc = None
+
+        try:
+            async with asyncio.timeout(-1):
+                async with tg:
+                    await asyncio.sleep(0)
+        except TimeoutError as e:
+            exc = e.__cause__
+
+        self.assertIsInstance(exc, asyncio.CancelledError)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
+    async def test_exception_refcycles_base_error(self):
+        """Test that TaskGroup deletes self._base_error"""
+        class MyKeyboardInterrupt(KeyboardInterrupt):
+            pass
+
+        tg = asyncio.TaskGroup()
+        exc = None
+
+        try:
+            async with tg:
+                raise MyKeyboardInterrupt
+        except MyKeyboardInterrupt as e:
+            exc = e
+
+        self.assertIsNotNone(exc)
+        self.assertListEqual(gc.get_referrers(exc), [])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2024-10-04-08-46-00.gh-issue-124958.rea9-x.rst b/Misc/NEWS.d/next/Library/2024-10-04-08-46-00.gh-issue-124958.rea9-x.rst
new file mode 100644 (file)
index 0000000..534d5bb
--- /dev/null
@@ -0,0 +1 @@
+Fix refcycles in exceptions raised from :class:`asyncio.TaskGroup` and the python implementation of :class:`asyncio.Future`