]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-128479: fix asyncio staggered race leaking tasks, and logging unhandled exception...
authorThomas Grainger <tagrain@gmail.com>
Thu, 23 Jan 2025 15:53:53 +0000 (15:53 +0000)
committerGitHub <noreply@github.com>
Thu, 23 Jan 2025 15:53:53 +0000 (16:53 +0100)
Co-authored-by: Peter Bierma <zintensitydev@gmail.com>
Lib/asyncio/staggered.py
Lib/test/test_asyncio/test_staggered.py
Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst [new file with mode: 0644]

index 0f4df8855a80b91b13c85054e3c8f713b2b690f0..0afed64fdf9c0f23a6e0add752da18b0b02f0f3f 100644 (file)
@@ -66,8 +66,27 @@ async def staggered_race(coro_fns, delay, *, loop=None):
     enum_coro_fns = enumerate(coro_fns)
     winner_result = None
     winner_index = None
+    unhandled_exceptions = []
     exceptions = []
-    running_tasks = []
+    running_tasks = set()
+    on_completed_fut = None
+
+    def task_done(task):
+        running_tasks.discard(task)
+        if (
+            on_completed_fut is not None
+            and not on_completed_fut.done()
+            and not running_tasks
+        ):
+            on_completed_fut.set_result(None)
+
+        if task.cancelled():
+            return
+
+        exc = task.exception()
+        if exc is None:
+            return
+        unhandled_exceptions.append(exc)
 
     async def run_one_coro(ok_to_start, previous_failed) -> None:
         # in eager tasks this waits for the calling task to append this task
@@ -91,11 +110,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
         this_failed = locks.Event()
         next_ok_to_start = locks.Event()
         next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
-        running_tasks.append(next_task)
+        running_tasks.add(next_task)
+        next_task.add_done_callback(task_done)
         # next_task has been appended to running_tasks so next_task is ok to
         # start.
         next_ok_to_start.set()
-        assert len(running_tasks) == this_index + 2
         # Prepare place to put this coroutine's exceptions if not won
         exceptions.append(None)
         assert len(exceptions) == this_index + 1
@@ -120,31 +139,36 @@ async def staggered_race(coro_fns, delay, *, loop=None):
             # up as done() == True, cancelled() == False, exception() ==
             # asyncio.CancelledError. This behavior is specified in
             # https://bugs.python.org/issue30048
-            for i, t in enumerate(running_tasks):
-                if i != this_index:
+            current_task = tasks.current_task(loop)
+            for t in running_tasks:
+                if t is not current_task:
                     t.cancel()
 
-    ok_to_start = locks.Event()
-    first_task = loop.create_task(run_one_coro(ok_to_start, None))
-    running_tasks.append(first_task)
-    # first_task has been appended to running_tasks so first_task is ok to start.
-    ok_to_start.set()
+    propagate_cancellation_error = None
     try:
-        # Wait for a growing list of tasks to all finish: poor man's version of
-        # curio's TaskGroup or trio's nursery
-        done_count = 0
-        while done_count != len(running_tasks):
-            done, _ = await tasks.wait(running_tasks)
-            done_count = len(done)
+        ok_to_start = locks.Event()
+        first_task = loop.create_task(run_one_coro(ok_to_start, None))
+        running_tasks.add(first_task)
+        first_task.add_done_callback(task_done)
+        # first_task has been appended to running_tasks so first_task is ok to start.
+        ok_to_start.set()
+        propagate_cancellation_error = None
+        # Make sure no tasks are left running if we leave this function
+        while running_tasks:
+            on_completed_fut = loop.create_future()
+            try:
+                await on_completed_fut
+            except exceptions_mod.CancelledError as ex:
+                propagate_cancellation_error = ex
+                for task in running_tasks:
+                    task.cancel(*ex.args)
+            on_completed_fut = None
+        if __debug__ and unhandled_exceptions:
             # If run_one_coro raises an unhandled exception, it's probably a
             # programming error, and I want to see it.
-            if __debug__:
-                for d in done:
-                    if d.done() and not d.cancelled() and d.exception():
-                        raise d.exception()
+            raise ExceptionGroup("staggered race failed", unhandled_exceptions)
+        if propagate_cancellation_error is not None:
+            raise propagate_cancellation_error
         return winner_result, winner_index, exceptions
     finally:
-        del exceptions
-        # Make sure no tasks are left running if we leave this function
-        for t in running_tasks:
-            t.cancel()
+        del exceptions, propagate_cancellation_error, unhandled_exceptions
index 3c81b6296935961cd5394ae7fca4aa6a620e2c4c..ad34aa6da01f546dc97cfd1e0cb4215ebafcc6e2 100644 (file)
@@ -122,3 +122,30 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
         self.assertIsNone(excs[0], None)
         self.assertIsInstance(excs[1], asyncio.CancelledError)
         self.assertIsInstance(excs[2], asyncio.CancelledError)
+
+
+    async def test_cancelled(self):
+        log = []
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(None) as cs_outer, asyncio.timeout(None) as cs_inner:
+                async def coro_fn():
+                    cs_inner.reschedule(-1)
+                    await asyncio.sleep(0)
+                    try:
+                        await asyncio.sleep(0)
+                    except asyncio.CancelledError:
+                        log.append("cancelled 1")
+
+                    cs_outer.reschedule(-1)
+                    await asyncio.sleep(0)
+                    try:
+                        await asyncio.sleep(0)
+                    except asyncio.CancelledError:
+                        log.append("cancelled 2")
+                try:
+                    await staggered_race([coro_fn], delay=None)
+                except asyncio.CancelledError:
+                    log.append("cancelled 3")
+                    raise
+
+        self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])
diff --git a/Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst b/Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst
new file mode 100644 (file)
index 0000000..fc3b4d5
--- /dev/null
@@ -0,0 +1 @@
+Fix :func:`!asyncio.staggered.staggered_race` leaking tasks and issuing an unhandled exception.