]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-124309: Revert eager task factory fix to prevent breaking downstream (#124810)
authorPeter Bierma <zintensitydev@gmail.com>
Tue, 1 Oct 2024 01:37:27 +0000 (21:37 -0400)
committerGitHub <noreply@github.com>
Tue, 1 Oct 2024 01:37:27 +0000 (18:37 -0700)
* Revert "GH-124639: add back loop param to staggered_race (#124700)"

This reverts commit e0a41a5dd12cb6e9277b05abebac5c70be684dd7.

* Revert "gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390)"

This reverts commit de929f353c413459834a2a37b2d9b0240673d874.

Lib/asyncio/base_events.py
Lib/asyncio/staggered.py
Lib/test/test_asyncio/test_eager_task_factory.py
Lib/test/test_asyncio/test_staggered.py
Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst [deleted file]

index ffcc0174e1e245e1ab6f33328a0315b9d6483aca..000647f57dd9e30c5f24ce14e406bcd82ff54f08 100644 (file)
@@ -1144,7 +1144,7 @@ class BaseEventLoop(events.AbstractEventLoop):
                     (functools.partial(self._connect_sock,
                                        exceptions, addrinfo, laddr_infos)
                      for addrinfo in infos),
-                    happy_eyeballs_delay)
+                    happy_eyeballs_delay, loop=self)
 
             if sock is None:
                 exceptions = [exc for sub in exceptions for exc in sub]
index 6ccf5c3c269ff0578283b28f847503ba170480cc..c3a7441a7b091d3ac9967488828dc814948ad935 100644 (file)
@@ -4,12 +4,11 @@ __all__ = 'staggered_race',
 
 import contextlib
 
+from . import events
+from . import exceptions as exceptions_mod
 from . import locks
 from . import tasks
-from . import taskgroups
 
-class _Done(Exception):
-    pass
 
 async def staggered_race(coro_fns, delay, *, loop=None):
     """Run coroutines with staggered start times and take the first to finish.
@@ -43,6 +42,8 @@ async def staggered_race(coro_fns, delay, *, loop=None):
         delay: amount of time, in seconds, between starting coroutines. If
             ``None``, the coroutines will run sequentially.
 
+        loop: the event loop to use.
+
     Returns:
         tuple *(winner_result, winner_index, exceptions)* where
 
@@ -61,11 +62,36 @@ async def staggered_race(coro_fns, delay, *, loop=None):
 
     """
     # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
+    loop = loop or events.get_running_loop()
+    enum_coro_fns = enumerate(coro_fns)
     winner_result = None
     winner_index = None
     exceptions = []
+    running_tasks = []
+
+    async def run_one_coro(previous_failed) -> None:
+        # Wait for the previous task to finish, or for delay seconds
+        if previous_failed is not None:
+            with contextlib.suppress(exceptions_mod.TimeoutError):
+                # Use asyncio.wait_for() instead of asyncio.wait() here, so
+                # that if we get cancelled at this point, Event.wait() is also
+                # cancelled, otherwise there will be a "Task destroyed but it is
+                # pending" later.
+                await tasks.wait_for(previous_failed.wait(), delay)
+        # Get the next coroutine to run
+        try:
+            this_index, coro_fn = next(enum_coro_fns)
+        except StopIteration:
+            return
+        # Start task that will run the next coroutine
+        this_failed = locks.Event()
+        next_task = loop.create_task(run_one_coro(this_failed))
+        running_tasks.append(next_task)
+        assert len(running_tasks) == this_index + 2
+        # Prepare place to put this coroutine's exceptions if not won
+        exceptions.append(None)
+        assert len(exceptions) == this_index + 1
 
-    async def run_one_coro(this_index, coro_fn, this_failed):
         try:
             result = await coro_fn()
         except (SystemExit, KeyboardInterrupt):
@@ -79,23 +105,34 @@ async def staggered_race(coro_fns, delay, *, loop=None):
             assert winner_index is None
             winner_index = this_index
             winner_result = result
-            raise _Done
-
+            # Cancel all other tasks. We take care to not cancel the current
+            # task as well. If we do so, then since there is no `await` after
+            # here and CancelledError are usually thrown at one, we will
+            # encounter a curious corner case where the current task will end
+            # up as done() == True, cancelled() == False, exception() ==
+            # asyncio.CancelledError. This behavior is specified in
+            # https://bugs.python.org/issue30048
+            for i, t in enumerate(running_tasks):
+                if i != this_index:
+                    t.cancel()
+
+    first_task = loop.create_task(run_one_coro(None))
+    running_tasks.append(first_task)
     try:
-        tg = taskgroups.TaskGroup()
-        # Intentionally override the loop in the TaskGroup to avoid
-        # using the running loop, preserving backwards compatibility
-        # TaskGroup only starts using `_loop` after `__aenter__`
-        # so overriding it here is safe.
-        tg._loop = loop
-        async with tg:
-            for this_index, coro_fn in enumerate(coro_fns):
-                this_failed = locks.Event()
-                exceptions.append(None)
-                tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
-                with contextlib.suppress(TimeoutError):
-                    await tasks.wait_for(this_failed.wait(), delay)
-    except* _Done:
-        pass
-
-    return winner_result, winner_index, exceptions
+        # Wait for a growing list of tasks to all finish: poor man's version of
+        # curio's TaskGroup or trio's nursery
+        done_count = 0
+        while done_count != len(running_tasks):
+            done, _ = await tasks.wait(running_tasks)
+            done_count = len(done)
+            # If run_one_coro raises an unhandled exception, it's probably a
+            # programming error, and I want to see it.
+            if __debug__:
+                for d in done:
+                    if d.done() and not d.cancelled() and d.exception():
+                        raise d.exception()
+        return winner_result, winner_index, exceptions
+    finally:
+        # Make sure no tasks are left running if we leave this function
+        for t in running_tasks:
+            t.cancel()
index 1579ad1188d725e827946f649bdbd896944da5f6..0777f39b57248605964fed372055d3391fa4415a 100644 (file)
@@ -213,53 +213,6 @@ class EagerTaskFactoryLoopTests:
 
         self.run_coro(run())
 
-    def test_staggered_race_with_eager_tasks(self):
-        # See https://github.com/python/cpython/issues/124309
-
-        async def fail():
-            await asyncio.sleep(0)
-            raise ValueError("no good")
-
-        async def run():
-            winner, index, excs = await asyncio.staggered.staggered_race(
-                [
-                    lambda: asyncio.sleep(2, result="sleep2"),
-                    lambda: asyncio.sleep(1, result="sleep1"),
-                    lambda: fail()
-                ],
-                delay=0.25
-            )
-            self.assertEqual(winner, 'sleep1')
-            self.assertEqual(index, 1)
-            self.assertIsNone(excs[index])
-            self.assertIsInstance(excs[0], asyncio.CancelledError)
-            self.assertIsInstance(excs[2], ValueError)
-
-        self.run_coro(run())
-
-    def test_staggered_race_with_eager_tasks_no_delay(self):
-        # See https://github.com/python/cpython/issues/124309
-        async def fail():
-            raise ValueError("no good")
-
-        async def run():
-            winner, index, excs = await asyncio.staggered.staggered_race(
-                [
-                    lambda: fail(),
-                    lambda: asyncio.sleep(1, result="sleep1"),
-                    lambda: asyncio.sleep(0, result="sleep0"),
-                ],
-                delay=None
-            )
-            self.assertEqual(winner, 'sleep1')
-            self.assertEqual(index, 1)
-            self.assertIsNone(excs[index])
-            self.assertIsInstance(excs[0], ValueError)
-            self.assertEqual(len(excs), 2)
-
-        self.run_coro(run())
-
-
 
 class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
     Task = tasks._PyTask
index 8cd98394aea8f8a49f24b0224231cfc1bfb7921b..e6e32f7dbbbcba016c71df8b7e404ce153971ad6 100644 (file)
@@ -82,64 +82,16 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
         async def coro(index):
             raise ValueError(index)
 
-        for delay in [None, 0, 0.1, 1]:
-            with self.subTest(delay=delay):
-                winner, index, excs = await staggered_race(
-                    [
-                        lambda: coro(0),
-                        lambda: coro(1),
-                    ],
-                    delay=delay,
-                )
-
-                self.assertIs(winner, None)
-                self.assertIs(index, None)
-                self.assertEqual(len(excs), 2)
-                self.assertIsInstance(excs[0], ValueError)
-                self.assertIsInstance(excs[1], ValueError)
-
-    async def test_long_delay_early_failure(self):
-        async def coro(index):
-            await asyncio.sleep(0)  # Dummy coroutine for the 1 case
-            if index == 0:
-                await asyncio.sleep(0.1)  # Dummy coroutine
-                raise ValueError(index)
-
-            return f'Res: {index}'
-
         winner, index, excs = await staggered_race(
             [
                 lambda: coro(0),
                 lambda: coro(1),
             ],
-            delay=10,
+            delay=None,
         )
 
-        self.assertEqual(winner, 'Res: 1')
-        self.assertEqual(index, 1)
+        self.assertIs(winner, None)
+        self.assertIs(index, None)
         self.assertEqual(len(excs), 2)
         self.assertIsInstance(excs[0], ValueError)
-        self.assertIsNone(excs[1])
-
-    def test_loop_argument(self):
-        loop = asyncio.new_event_loop()
-        async def coro():
-            self.assertEqual(loop, asyncio.get_running_loop())
-            return 'coro'
-
-        async def main():
-            winner, index, excs = await staggered_race(
-                [coro],
-                delay=0.1,
-                loop=loop
-            )
-
-            self.assertEqual(winner, 'coro')
-            self.assertEqual(index, 0)
-
-        loop.run_until_complete(main())
-        loop.close()
-
-
-if __name__ == "__main__":
-    unittest.main()
+        self.assertIsInstance(excs[1], ValueError)
diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst
deleted file mode 100644 (file)
index 89610fa..0000000
+++ /dev/null
@@ -1 +0,0 @@
-Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.