import contextlib
+from . import events
+from . import exceptions as exceptions_mod
from . import locks
from . import tasks
-from . import taskgroups
-class _Done(Exception):
- pass
async def staggered_race(coro_fns, delay, *, loop=None):
"""Run coroutines with staggered start times and take the first to finish.
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
+ loop: the event loop to use.
+
Returns:
tuple *(winner_result, winner_index, exceptions)* where
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
+ loop = loop or events.get_running_loop()
+ enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
exceptions = []
+ running_tasks = []
+
+ async def run_one_coro(previous_failed) -> None:
+ # Wait for the previous task to finish, or for delay seconds
+ if previous_failed is not None:
+ with contextlib.suppress(exceptions_mod.TimeoutError):
+ # Use asyncio.wait_for() instead of asyncio.wait() here, so
+ # that if we get cancelled at this point, Event.wait() is also
+ # cancelled, otherwise there will be a "Task destroyed but it is
+ # pending" later.
+ await tasks.wait_for(previous_failed.wait(), delay)
+ # Get the next coroutine to run
+ try:
+ this_index, coro_fn = next(enum_coro_fns)
+ except StopIteration:
+ return
+ # Start task that will run the next coroutine
+ this_failed = locks.Event()
+ next_task = loop.create_task(run_one_coro(this_failed))
+ running_tasks.append(next_task)
+ assert len(running_tasks) == this_index + 2
+ # Prepare place to put this coroutine's exceptions if not won
+ exceptions.append(None)
+ assert len(exceptions) == this_index + 1
- async def run_one_coro(this_index, coro_fn, this_failed):
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
assert winner_index is None
winner_index = this_index
winner_result = result
- raise _Done
-
+ # Cancel all other tasks. We take care to not cancel the current
+ # task as well. If we do so, then since there is no `await` after
+ # here and CancelledError are usually thrown at one, we will
+ # encounter a curious corner case where the current task will end
+ # up as done() == True, cancelled() == False, exception() ==
+ # asyncio.CancelledError. This behavior is specified in
+ # https://bugs.python.org/issue30048
+ for i, t in enumerate(running_tasks):
+ if i != this_index:
+ t.cancel()
+
+ first_task = loop.create_task(run_one_coro(None))
+ running_tasks.append(first_task)
try:
- tg = taskgroups.TaskGroup()
- # Intentionally override the loop in the TaskGroup to avoid
- # using the running loop, preserving backwards compatibility
- # TaskGroup only starts using `_loop` after `__aenter__`
- # so overriding it here is safe.
- tg._loop = loop
- async with tg:
- for this_index, coro_fn in enumerate(coro_fns):
- this_failed = locks.Event()
- exceptions.append(None)
- tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
- with contextlib.suppress(TimeoutError):
- await tasks.wait_for(this_failed.wait(), delay)
- except* _Done:
- pass
-
- return winner_result, winner_index, exceptions
+ # Wait for a growing list of tasks to all finish: poor man's version of
+ # curio's TaskGroup or trio's nursery
+ done_count = 0
+ while done_count != len(running_tasks):
+ done, _ = await tasks.wait(running_tasks)
+ done_count = len(done)
+ # If run_one_coro raises an unhandled exception, it's probably a
+ # programming error, and I want to see it.
+ if __debug__:
+ for d in done:
+ if d.done() and not d.cancelled() and d.exception():
+ raise d.exception()
+ return winner_result, winner_index, exceptions
+ finally:
+ # Make sure no tasks are left running if we leave this function
+ for t in running_tasks:
+ t.cancel()
self.run_coro(run())
- def test_staggered_race_with_eager_tasks(self):
- # See https://github.com/python/cpython/issues/124309
-
- async def fail():
- await asyncio.sleep(0)
- raise ValueError("no good")
-
- async def run():
- winner, index, excs = await asyncio.staggered.staggered_race(
- [
- lambda: asyncio.sleep(2, result="sleep2"),
- lambda: asyncio.sleep(1, result="sleep1"),
- lambda: fail()
- ],
- delay=0.25
- )
- self.assertEqual(winner, 'sleep1')
- self.assertEqual(index, 1)
- self.assertIsNone(excs[index])
- self.assertIsInstance(excs[0], asyncio.CancelledError)
- self.assertIsInstance(excs[2], ValueError)
-
- self.run_coro(run())
-
- def test_staggered_race_with_eager_tasks_no_delay(self):
- # See https://github.com/python/cpython/issues/124309
- async def fail():
- raise ValueError("no good")
-
- async def run():
- winner, index, excs = await asyncio.staggered.staggered_race(
- [
- lambda: fail(),
- lambda: asyncio.sleep(1, result="sleep1"),
- lambda: asyncio.sleep(0, result="sleep0"),
- ],
- delay=None
- )
- self.assertEqual(winner, 'sleep1')
- self.assertEqual(index, 1)
- self.assertIsNone(excs[index])
- self.assertIsInstance(excs[0], ValueError)
- self.assertEqual(len(excs), 2)
-
- self.run_coro(run())
-
-
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
async def coro(index):
raise ValueError(index)
- for delay in [None, 0, 0.1, 1]:
- with self.subTest(delay=delay):
- winner, index, excs = await staggered_race(
- [
- lambda: coro(0),
- lambda: coro(1),
- ],
- delay=delay,
- )
-
- self.assertIs(winner, None)
- self.assertIs(index, None)
- self.assertEqual(len(excs), 2)
- self.assertIsInstance(excs[0], ValueError)
- self.assertIsInstance(excs[1], ValueError)
-
- async def test_long_delay_early_failure(self):
- async def coro(index):
- await asyncio.sleep(0) # Dummy coroutine for the 1 case
- if index == 0:
- await asyncio.sleep(0.1) # Dummy coroutine
- raise ValueError(index)
-
- return f'Res: {index}'
-
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
- delay=10,
+ delay=None,
)
- self.assertEqual(winner, 'Res: 1')
- self.assertEqual(index, 1)
+ self.assertIs(winner, None)
+ self.assertIs(index, None)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
- self.assertIsNone(excs[1])
-
- def test_loop_argument(self):
- loop = asyncio.new_event_loop()
- async def coro():
- self.assertEqual(loop, asyncio.get_running_loop())
- return 'coro'
-
- async def main():
- winner, index, excs = await staggered_race(
- [coro],
- delay=0.1,
- loop=loop
- )
-
- self.assertEqual(winner, 'coro')
- self.assertEqual(index, 0)
-
- loop.run_until_complete(main())
- loop.close()
-
-
-if __name__ == "__main__":
- unittest.main()
+ self.assertIsInstance(excs[1], ValueError)