]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] gh-124309: Modernize the `staggered_race` implementation to support e… (#124574)
authorKumar Aditya <kumaraditya@python.org>
Thu, 26 Sep 2024 05:39:46 +0000 (11:09 +0530)
committerGitHub <noreply@github.com>
Thu, 26 Sep 2024 05:39:46 +0000 (05:39 +0000)
gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Co-authored-by: Carol Willing <carolcode@willingconsulting.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
(cherry picked from commit de929f353c413459834a2a37b2d9b0240673d874)

Co-authored-by: Peter Bierma <zintensitydev@gmail.com>
Lib/asyncio/base_events.py
Lib/asyncio/staggered.py
Lib/test/test_asyncio/test_eager_task_factory.py
Lib/test/test_asyncio/test_staggered.py [new file with mode: 0644]
Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst [new file with mode: 0644]

index cb037fd472c5aa3bed4b64cfa401cc9889768402..02b900891949b380e3070b6d0c07cc16e0d87efe 100644 (file)
@@ -1110,7 +1110,7 @@ class BaseEventLoop(events.AbstractEventLoop):
                     (functools.partial(self._connect_sock,
                                        exceptions, addrinfo, laddr_infos)
                      for addrinfo in infos),
-                    happy_eyeballs_delay, loop=self)
+                    happy_eyeballs_delay)
 
             if sock is None:
                 exceptions = [exc for sub in exceptions for exc in sub]
index 451a53a16f3831611b6fc9b55f9eff5468924793..4458d01dece0e6c99027bbfec06b585cde167925 100644 (file)
@@ -3,24 +3,15 @@
 __all__ = 'staggered_race',
 
 import contextlib
-import typing
 
-from . import events
-from . import exceptions as exceptions_mod
 from . import locks
 from . import tasks
+from . import taskgroups
 
+class _Done(Exception):
+    pass
 
-async def staggered_race(
-        coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
-        delay: typing.Optional[float],
-        *,
-        loop: events.AbstractEventLoop = None,
-) -> typing.Tuple[
-    typing.Any,
-    typing.Optional[int],
-    typing.List[typing.Optional[Exception]]
-]:
+async def staggered_race(coro_fns, delay):
     """Run coroutines with staggered start times and take the first to finish.
 
     This method takes an iterable of coroutine functions. The first one is
@@ -52,8 +43,6 @@ async def staggered_race(
         delay: amount of time, in seconds, between starting coroutines. If
             ``None``, the coroutines will run sequentially.
 
-        loop: the event loop to use.
-
     Returns:
         tuple *(winner_result, winner_index, exceptions)* where
 
@@ -72,37 +61,11 @@ async def staggered_race(
 
     """
     # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
-    loop = loop or events.get_running_loop()
-    enum_coro_fns = enumerate(coro_fns)
     winner_result = None
     winner_index = None
     exceptions = []
-    running_tasks = []
-
-    async def run_one_coro(
-            previous_failed: typing.Optional[locks.Event]) -> None:
-        # Wait for the previous task to finish, or for delay seconds
-        if previous_failed is not None:
-            with contextlib.suppress(exceptions_mod.TimeoutError):
-                # Use asyncio.wait_for() instead of asyncio.wait() here, so
-                # that if we get cancelled at this point, Event.wait() is also
-                # cancelled, otherwise there will be a "Task destroyed but it is
-                # pending" later.
-                await tasks.wait_for(previous_failed.wait(), delay)
-        # Get the next coroutine to run
-        try:
-            this_index, coro_fn = next(enum_coro_fns)
-        except StopIteration:
-            return
-        # Start task that will run the next coroutine
-        this_failed = locks.Event()
-        next_task = loop.create_task(run_one_coro(this_failed))
-        running_tasks.append(next_task)
-        assert len(running_tasks) == this_index + 2
-        # Prepare place to put this coroutine's exceptions if not won
-        exceptions.append(None)
-        assert len(exceptions) == this_index + 1
 
+    async def run_one_coro(this_index, coro_fn, this_failed):
         try:
             result = await coro_fn()
         except (SystemExit, KeyboardInterrupt):
@@ -116,34 +79,17 @@ async def staggered_race(
             assert winner_index is None
             winner_index = this_index
             winner_result = result
-            # Cancel all other tasks. We take care to not cancel the current
-            # task as well. If we do so, then since there is no `await` after
-            # here and CancelledError are usually thrown at one, we will
-            # encounter a curious corner case where the current task will end
-            # up as done() == True, cancelled() == False, exception() ==
-            # asyncio.CancelledError. This behavior is specified in
-            # https://bugs.python.org/issue30048
-            for i, t in enumerate(running_tasks):
-                if i != this_index:
-                    t.cancel()
-
-    first_task = loop.create_task(run_one_coro(None))
-    running_tasks.append(first_task)
+            raise _Done
+
     try:
-        # Wait for a growing list of tasks to all finish: poor man's version of
-        # curio's TaskGroup or trio's nursery
-        done_count = 0
-        while done_count != len(running_tasks):
-            done, _ = await tasks.wait(running_tasks)
-            done_count = len(done)
-            # If run_one_coro raises an unhandled exception, it's probably a
-            # programming error, and I want to see it.
-            if __debug__:
-                for d in done:
-                    if d.done() and not d.cancelled() and d.exception():
-                        raise d.exception()
-        return winner_result, winner_index, exceptions
-    finally:
-        # Make sure no tasks are left running if we leave this function
-        for t in running_tasks:
-            t.cancel()
+        async with taskgroups.TaskGroup() as tg:
+            for this_index, coro_fn in enumerate(coro_fns):
+                this_failed = locks.Event()
+                exceptions.append(None)
+                tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
+                with contextlib.suppress(TimeoutError):
+                    await tasks.wait_for(this_failed.wait(), delay)
+    except* _Done:
+        pass
+
+    return winner_result, winner_index, exceptions
index 58c06287bc3c5de35df8a3b4a9cc9fb38653c01d..ed74c6ecbd83f436470f5aaed3c7a4e73040db9c 100644 (file)
@@ -218,6 +218,53 @@ class EagerTaskFactoryLoopTests:
 
         self.run_coro(run())
 
+    def test_staggered_race_with_eager_tasks(self):
+        # See https://github.com/python/cpython/issues/124309
+
+        async def fail():
+            await asyncio.sleep(0)
+            raise ValueError("no good")
+
+        async def run():
+            winner, index, excs = await asyncio.staggered.staggered_race(
+                [
+                    lambda: asyncio.sleep(2, result="sleep2"),
+                    lambda: asyncio.sleep(1, result="sleep1"),
+                    lambda: fail()
+                ],
+                delay=0.25
+            )
+            self.assertEqual(winner, 'sleep1')
+            self.assertEqual(index, 1)
+            self.assertIsNone(excs[index])
+            self.assertIsInstance(excs[0], asyncio.CancelledError)
+            self.assertIsInstance(excs[2], ValueError)
+
+        self.run_coro(run())
+
+    def test_staggered_race_with_eager_tasks_no_delay(self):
+        # See https://github.com/python/cpython/issues/124309
+        async def fail():
+            raise ValueError("no good")
+
+        async def run():
+            winner, index, excs = await asyncio.staggered.staggered_race(
+                [
+                    lambda: fail(),
+                    lambda: asyncio.sleep(1, result="sleep1"),
+                    lambda: asyncio.sleep(0, result="sleep0"),
+                ],
+                delay=None
+            )
+            self.assertEqual(winner, 'sleep1')
+            self.assertEqual(index, 1)
+            self.assertIsNone(excs[index])
+            self.assertIsInstance(excs[0], ValueError)
+            self.assertEqual(len(excs), 2)
+
+        self.run_coro(run())
+
+
 
 class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
     Task = tasks._PyTask
diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py
new file mode 100644 (file)
index 0000000..21a39b3
--- /dev/null
@@ -0,0 +1,126 @@
+import asyncio
+import unittest
+from asyncio.staggered import staggered_race
+
+from test import support
+
+support.requires_working_socket(module=True)
+
+
+def tearDownModule():
+    asyncio.set_event_loop_policy(None)
+
+
+class StaggeredTests(unittest.IsolatedAsyncioTestCase):
+    async def test_empty(self):
+        winner, index, excs = await staggered_race(
+            [],
+            delay=None,
+        )
+
+        self.assertIs(winner, None)
+        self.assertIs(index, None)
+        self.assertEqual(excs, [])
+
+    async def test_one_successful(self):
+        async def coro(index):
+            return f'Res: {index}'
+
+        winner, index, excs = await staggered_race(
+            [
+                lambda: coro(0),
+                lambda: coro(1),
+            ],
+            delay=None,
+        )
+
+        self.assertEqual(winner, 'Res: 0')
+        self.assertEqual(index, 0)
+        self.assertEqual(excs, [None])
+
+    async def test_first_error_second_successful(self):
+        async def coro(index):
+            if index == 0:
+                raise ValueError(index)
+            return f'Res: {index}'
+
+        winner, index, excs = await staggered_race(
+            [
+                lambda: coro(0),
+                lambda: coro(1),
+            ],
+            delay=None,
+        )
+
+        self.assertEqual(winner, 'Res: 1')
+        self.assertEqual(index, 1)
+        self.assertEqual(len(excs), 2)
+        self.assertIsInstance(excs[0], ValueError)
+        self.assertIs(excs[1], None)
+
+    async def test_first_timeout_second_successful(self):
+        async def coro(index):
+            if index == 0:
+                await asyncio.sleep(10)  # much bigger than delay
+            return f'Res: {index}'
+
+        winner, index, excs = await staggered_race(
+            [
+                lambda: coro(0),
+                lambda: coro(1),
+            ],
+            delay=0.1,
+        )
+
+        self.assertEqual(winner, 'Res: 1')
+        self.assertEqual(index, 1)
+        self.assertEqual(len(excs), 2)
+        self.assertIsInstance(excs[0], asyncio.CancelledError)
+        self.assertIs(excs[1], None)
+
+    async def test_none_successful(self):
+        async def coro(index):
+            raise ValueError(index)
+
+        for delay in [None, 0, 0.1, 1]:
+            with self.subTest(delay=delay):
+                winner, index, excs = await staggered_race(
+                    [
+                        lambda: coro(0),
+                        lambda: coro(1),
+                    ],
+                    delay=delay,
+                )
+
+                self.assertIs(winner, None)
+                self.assertIs(index, None)
+                self.assertEqual(len(excs), 2)
+                self.assertIsInstance(excs[0], ValueError)
+                self.assertIsInstance(excs[1], ValueError)
+
+    async def test_long_delay_early_failure(self):
+        async def coro(index):
+            await asyncio.sleep(0)  # Dummy coroutine for the 1 case
+            if index == 0:
+                await asyncio.sleep(0.1)  # Dummy coroutine
+                raise ValueError(index)
+
+            return f'Res: {index}'
+
+        winner, index, excs = await staggered_race(
+            [
+                lambda: coro(0),
+                lambda: coro(1),
+            ],
+            delay=10,
+        )
+
+        self.assertEqual(winner, 'Res: 1')
+        self.assertEqual(index, 1)
+        self.assertEqual(len(excs), 2)
+        self.assertIsInstance(excs[0], ValueError)
+        self.assertIsNone(excs[1])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst
new file mode 100644 (file)
index 0000000..89610fa
--- /dev/null
@@ -0,0 +1 @@
+Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.