]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] gh-124309: fix staggered race on eager tasks (GH-124847) (#125340)
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Sat, 12 Oct 2024 03:12:11 +0000 (05:12 +0200)
committerGitHub <noreply@github.com>
Sat, 12 Oct 2024 03:12:11 +0000 (20:12 -0700)
gh-124309: fix staggered race on eager tasks (GH-124847)

This patch is entirely by Thomas and Peter

(cherry picked from commit 979c0df7c0adfb744159a5fc184043dc733d8534)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Co-authored-by: Peter Bierma <zintensitydev@gmail.com>
Lib/asyncio/staggered.py
Lib/test/test_asyncio/test_eager_task_factory.py
Lib/test/test_asyncio/test_staggered.py
Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst [new file with mode: 0644]

index c3a7441a7b091d3ac9967488828dc814948ad935..7aafcea4d885eb5819f7b9efdb7f00bb8e984487 100644 (file)
@@ -69,7 +69,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
     exceptions = []
     running_tasks = []
 
-    async def run_one_coro(previous_failed) -> None:
+    async def run_one_coro(ok_to_start, previous_failed) -> None:
+        # in eager tasks this waits for the calling task to append this task
+        # to running_tasks, in regular tasks this wait is a no-op that does
+        # not yield a future. See gh-124309.
+        await ok_to_start.wait()
         # Wait for the previous task to finish, or for delay seconds
         if previous_failed is not None:
             with contextlib.suppress(exceptions_mod.TimeoutError):
@@ -85,8 +89,12 @@ async def staggered_race(coro_fns, delay, *, loop=None):
             return
         # Start task that will run the next coroutine
         this_failed = locks.Event()
-        next_task = loop.create_task(run_one_coro(this_failed))
+        next_ok_to_start = locks.Event()
+        next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
         running_tasks.append(next_task)
+        # next_task has been appended to running_tasks so next_task is ok to
+        # start.
+        next_ok_to_start.set()
         assert len(running_tasks) == this_index + 2
         # Prepare place to put this coroutine's exceptions if not won
         exceptions.append(None)
@@ -116,8 +124,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
                 if i != this_index:
                     t.cancel()
 
-    first_task = loop.create_task(run_one_coro(None))
+    ok_to_start = locks.Event()
+    first_task = loop.create_task(run_one_coro(ok_to_start, None))
     running_tasks.append(first_task)
+    # first_task has been appended to running_tasks so first_task is ok to start.
+    ok_to_start.set()
     try:
         # Wait for a growing list of tasks to all finish: poor man's version of
         # curio's TaskGroup or trio's nursery
index 58c06287bc3c5de35df8a3b4a9cc9fb38653c01d..b06832e02f00d6a545eccf6e107a632274063bf2 100644 (file)
@@ -218,6 +218,52 @@ class EagerTaskFactoryLoopTests:
 
         self.run_coro(run())
 
+    def test_staggered_race_with_eager_tasks(self):
+        # See https://github.com/python/cpython/issues/124309
+
+        async def fail():
+            await asyncio.sleep(0)
+            raise ValueError("no good")
+
+        async def run():
+            winner, index, excs = await asyncio.staggered.staggered_race(
+                [
+                    lambda: asyncio.sleep(2, result="sleep2"),
+                    lambda: asyncio.sleep(1, result="sleep1"),
+                    lambda: fail()
+                ],
+                delay=0.25
+            )
+            self.assertEqual(winner, 'sleep1')
+            self.assertEqual(index, 1)
+            self.assertIsNone(excs[index])
+            self.assertIsInstance(excs[0], asyncio.CancelledError)
+            self.assertIsInstance(excs[2], ValueError)
+
+        self.run_coro(run())
+
+    def test_staggered_race_with_eager_tasks_no_delay(self):
+        # See https://github.com/python/cpython/issues/124309
+        async def fail():
+            raise ValueError("no good")
+
+        async def run():
+            winner, index, excs = await asyncio.staggered.staggered_race(
+                [
+                    lambda: fail(),
+                    lambda: asyncio.sleep(1, result="sleep1"),
+                    lambda: asyncio.sleep(0, result="sleep0"),
+                ],
+                delay=None
+            )
+            self.assertEqual(winner, 'sleep1')
+            self.assertEqual(index, 1)
+            self.assertIsNone(excs[index])
+            self.assertIsInstance(excs[0], ValueError)
+            self.assertEqual(len(excs), 2)
+
+        self.run_coro(run())
+
 
 class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
     Task = tasks._PyTask
index e6e32f7dbbbcba016c71df8b7e404ce153971ad6..74941f704c48901c62e7e5304f62b5f196289312 100644 (file)
@@ -95,3 +95,30 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
         self.assertEqual(len(excs), 2)
         self.assertIsInstance(excs[0], ValueError)
         self.assertIsInstance(excs[1], ValueError)
+
+
+    async def test_multiple_winners(self):
+        event = asyncio.Event()
+
+        async def coro(index):
+            await event.wait()
+            return index
+
+        async def do_set():
+            event.set()
+            await asyncio.Event().wait()
+
+        winner, index, excs = await staggered_race(
+            [
+                lambda: coro(0),
+                lambda: coro(1),
+                do_set,
+            ],
+            delay=0.1,
+        )
+        self.assertIs(winner, 0)
+        self.assertIs(index, 0)
+        self.assertEqual(len(excs), 3)
+        self.assertIsNone(excs[0], None)
+        self.assertIsInstance(excs[1], asyncio.CancelledError)
+        self.assertIsInstance(excs[2], asyncio.CancelledError)
diff --git a/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst b/Misc/NEWS.d/next/Library/2024-10-01-13-46-58.gh-issue-124390.dK1Zcm.rst
new file mode 100644 (file)
index 0000000..89610fa
--- /dev/null
@@ -0,0 +1 @@
+Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.