From: Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> Date: Sun, 29 Sep 2024 03:40:41 +0000 (+0200) Subject: [3.12] GH-124639: add back loop param to staggered_race (GH-124700) (#124744) X-Git-Tag: v3.12.7~17 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=65103adadd5ba23a3591b6a0de227889a330ebdb;p=thirdparty%2FPython%2Fcpython.git [3.12] GH-124639: add back loop param to staggered_race (GH-124700) (#124744) GH-124639: add back loop param to staggered_race (GH-124700) (cherry picked from commit e0a41a5dd12cb6e9277b05abebac5c70be684dd7) Co-authored-by: Kumar Aditya --- diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 4458d01dece0..6ccf5c3c269f 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -11,7 +11,7 @@ from . import taskgroups class _Done(Exception): pass -async def staggered_race(coro_fns, delay): +async def staggered_race(coro_fns, delay, *, loop=None): """Run coroutines with staggered start times and take the first to finish. This method takes an iterable of coroutine functions. The first one is @@ -82,7 +82,13 @@ async def staggered_race(coro_fns, delay): raise _Done try: - async with taskgroups.TaskGroup() as tg: + tg = taskgroups.TaskGroup() + # Intentionally override the loop in the TaskGroup to avoid + # using the running loop, preserving backwards compatibility + # TaskGroup only starts using `_loop` after `__aenter__` + # so overriding it here is safe. + tg._loop = loop + async with tg: for this_index, coro_fn in enumerate(coro_fns): this_failed = locks.Event() exceptions.append(None) diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index 21a39b3f9117..8cd98394aea8 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -121,6 +121,25 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase): self.assertIsInstance(excs[0], ValueError) self.assertIsNone(excs[1]) + def test_loop_argument(self): + loop = asyncio.new_event_loop() + async def coro(): + self.assertEqual(loop, asyncio.get_running_loop()) + return 'coro' + + async def main(): + winner, index, excs = await staggered_race( + [coro], + delay=0.1, + loop=loop + ) + + self.assertEqual(winner, 'coro') + self.assertEqual(index, 0) + + loop.run_until_complete(main()) + loop.close() + if __name__ == "__main__": unittest.main()