"""Synchronization primitives."""
-__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
+__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
+ 'BoundedSemaphore', 'Barrier')
import collections
+import enum
from . import exceptions
from . import mixins
from . import tasks
-
class _ContextManagerMixin:
async def __aenter__(self):
await self.acquire()
if self._value >= self._bound_value:
raise ValueError('BoundedSemaphore released too many times')
super().release()
+
+
+
+class _BarrierState(enum.Enum):
+ FILLING = 'filling'
+ DRAINING = 'draining'
+ RESETTING = 'resetting'
+ BROKEN = 'broken'
+
+
+class Barrier(mixins._LoopBoundMixin):
+ """Asyncio equivalent to threading.Barrier
+
+ Implements a Barrier primitive.
+ Useful for synchronizing a fixed number of tasks at known synchronization
+ points. Tasks block on 'wait()' and are simultaneously awoken once they
+ have all made their call.
+ """
+
+ def __init__(self, parties):
+ """Create a barrier, initialised to 'parties' tasks."""
+ if parties < 1:
+ raise ValueError('parties must be > 0')
+
+ self._cond = Condition() # notify all tasks when state changes
+
+ self._parties = parties
+ self._state = _BarrierState.FILLING
+ self._count = 0 # count tasks in Barrier
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = f'{self._state.value}'
+ if not self.broken:
+ extra += f', waiters:{self.n_waiting}/{self.parties}'
+ return f'<{res[1:-1]} [{extra}]>'
+
+ async def __aenter__(self):
+ # wait for the barrier reaches the parties number
+ # when start draining release and return index of waited task
+ return await self.wait()
+
+ async def __aexit__(self, *args):
+ pass
+
+ async def wait(self):
+ """Wait for the barrier.
+
+ When the specified number of tasks have started waiting, they are all
+ simultaneously awoken.
+ Returns an unique and individual index number from 0 to 'parties-1'.
+ """
+ async with self._cond:
+ await self._block() # Block while the barrier drains or resets.
+ try:
+ index = self._count
+ self._count += 1
+ if index + 1 == self._parties:
+ # We release the barrier
+ await self._release()
+ else:
+ await self._wait()
+ return index
+ finally:
+ self._count -= 1
+ # Wake up any tasks waiting for barrier to drain.
+ self._exit()
+
+ async def _block(self):
+ # Block until the barrier is ready for us,
+ # or raise an exception if it is broken.
+ #
+ # It is draining or resetting, wait until done
+ # unless a CancelledError occurs
+ await self._cond.wait_for(
+ lambda: self._state not in (
+ _BarrierState.DRAINING, _BarrierState.RESETTING
+ )
+ )
+
+ # see if the barrier is in a broken state
+ if self._state is _BarrierState.BROKEN:
+ raise exceptions.BrokenBarrierError("Barrier aborted")
+
+ async def _release(self):
+ # Release the tasks waiting in the barrier.
+
+ # Enter draining state.
+ # Next waiting tasks will be blocked until the end of draining.
+ self._state = _BarrierState.DRAINING
+ self._cond.notify_all()
+
+ async def _wait(self):
+ # Wait in the barrier until we are released. Raise an exception
+ # if the barrier is reset or broken.
+
+ # wait for end of filling
+ # unless a CancelledError occurs
+ await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
+
+ if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
+ raise exceptions.BrokenBarrierError("Abort or reset of barrier")
+
+ def _exit(self):
+ # If we are the last tasks to exit the barrier, signal any tasks
+ # waiting for the barrier to drain.
+ if self._count == 0:
+ if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
+ self._state = _BarrierState.FILLING
+ self._cond.notify_all()
+
+ async def reset(self):
+ """Reset the barrier to the initial state.
+
+ Any tasks currently waiting will get the BrokenBarrier exception
+ raised.
+ """
+ async with self._cond:
+ if self._count > 0:
+ if self._state is not _BarrierState.RESETTING:
+ #reset the barrier, waking up tasks
+ self._state = _BarrierState.RESETTING
+ else:
+ self._state = _BarrierState.FILLING
+ self._cond.notify_all()
+
+ async def abort(self):
+ """Place the barrier into a 'broken' state.
+
+ Useful in case of error. Any currently waiting tasks and tasks
+ attempting to 'wait()' will have BrokenBarrierError raised.
+ """
+ async with self._cond:
+ self._state = _BarrierState.BROKEN
+ self._cond.notify_all()
+
+ @property
+ def parties(self):
+ """Return the number of tasks required to trip the barrier."""
+ return self._parties
+
+ @property
+ def n_waiting(self):
+ """Return the number of tasks currently waiting at the barrier."""
+ if self._state is _BarrierState.FILLING:
+ return self._count
+ return 0
+
+ @property
+ def broken(self):
+ """Return True if the barrier is in a broken state."""
+ return self._state is _BarrierState.BROKEN
-"""Tests for lock.py"""
+"""Tests for locks.py"""
import unittest
from unittest import mock
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
- r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
+ r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
+ r'(, value:\d)?'
+ r'(, waiters:\d+)?'
+ r'(, waiters:\d+\/\d+)?' # barrier
r')\]>\Z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
)
+class BarrierTests(unittest.IsolatedAsyncioTestCase):
+
+ async def asyncSetUp(self):
+ await super().asyncSetUp()
+ self.N = 5
+
+ def make_tasks(self, n, coro):
+ tasks = [asyncio.create_task(coro()) for _ in range(n)]
+ return tasks
+
+ async def gather_tasks(self, n, coro):
+ tasks = self.make_tasks(n, coro)
+ res = await asyncio.gather(*tasks)
+ return res, tasks
+
+ async def test_barrier(self):
+ barrier = asyncio.Barrier(self.N)
+ self.assertIn("filling", repr(barrier))
+ with self.assertRaisesRegex(
+ TypeError,
+ "object Barrier can't be used in 'await' expression",
+ ):
+ await barrier
+
+ self.assertIn("filling", repr(barrier))
+
+ async def test_repr(self):
+ barrier = asyncio.Barrier(self.N)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("filling", repr(barrier))
+
+ waiters = []
+ async def wait(barrier):
+ await barrier.wait()
+
+ incr = 2
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
+ self.assertIn("filling", repr(barrier))
+
+ # create missing waiters
+ for i in range(barrier.parties - barrier.n_waiting):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("draining", repr(barrier))
+
+ # add a part of waiters
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+ # and reset
+ await barrier.reset()
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("resetting", repr(barrier))
+
+ # add a part of waiters again
+ for i in range(incr):
+ waiters.append(asyncio.create_task(wait(barrier)))
+ await asyncio.sleep(0)
+ # and abort
+ await barrier.abort()
+
+ self.assertTrue(RGX_REPR.match(repr(barrier)))
+ self.assertIn("broken", repr(barrier))
+ self.assertTrue(barrier.broken)
+
+ # suppress unhandled exceptions
+ await asyncio.gather(*waiters, return_exceptions=True)
+
+ async def test_barrier_parties(self):
+ self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
+ self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
+
+ self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
+
+ async def test_context_manager(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier as i:
+ results.append(i)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertListEqual(sorted(results), list(range(self.N)))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_one_task(self):
+ barrier = asyncio.Barrier(1)
+
+ async def f():
+ async with barrier as i:
+ return True
+
+ ret = await f()
+
+ self.assertTrue(ret)
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_one_task_twice(self):
+ barrier = asyncio.Barrier(1)
+
+ t1 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 0)
+
+ t2 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+
+ self.assertEqual(t1.result(), t2.result())
+ self.assertEqual(t1.done(), t2.done())
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_task_by_task(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+
+ t1 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+ self.assertIn("filling", repr(barrier))
+
+ t2 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+ self.assertIn("filling", repr(barrier))
+
+ t3 = asyncio.create_task(barrier.wait())
+ await asyncio.sleep(0)
+
+ await asyncio.wait([t1, t2, t3])
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_tasks_wait_twice(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier:
+ results.append(True)
+
+ async with barrier:
+ results.append(False)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results), self.N*2)
+ self.assertEqual(results.count(True), self.N)
+ self.assertEqual(results.count(False), self.N)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_filling_tasks_check_return_value(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro():
+ async with barrier:
+ results1.append(True)
+
+ async with barrier as i:
+ results2.append(True)
+ return i
+
+ res, _ = await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results1), self.N)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results2), self.N)
+ self.assertTrue(all(results2))
+ self.assertListEqual(sorted(res), list(range(self.N)))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_draining_state(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ async with barrier:
+ # barrier state change to filling for the last task release
+ results.append("draining" in repr(barrier))
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(len(results), self.N)
+ self.assertEqual(results[-1], False)
+ self.assertTrue(all(results[:self.N-1]))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_blocking_tasks_while_draining(self):
+ rewait = 2
+ barrier = asyncio.Barrier(self.N)
+ barrier_nowaiting = asyncio.Barrier(self.N - rewait)
+ results = []
+ rewait_n = rewait
+ counter = 0
+
+ async def coro():
+ nonlocal rewait_n
+
+ # first time waiting
+ await barrier.wait()
+
+ # after wainting once for all tasks
+ if rewait_n > 0:
+ rewait_n -= 1
+ # wait again only for rewait tasks
+ await barrier.wait()
+ else:
+ # wait for end of draining state`
+ await barrier_nowaiting.wait()
+ # wait for other waiting tasks
+ await barrier.wait()
+
+ # a success means that barrier_nowaiting
+ # was waited for exactly N-rewait=3 times
+ await self.gather_tasks(self.N, coro)
+
+ async def test_filling_tasks_cancel_one(self):
+ self.N = 3
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ await barrier.wait()
+ results.append(True)
+
+ t1 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+
+ t2 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+
+ t1.cancel()
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 1)
+ with self.assertRaises(asyncio.CancelledError):
+ await t1
+ self.assertTrue(t1.cancelled())
+
+ t3 = asyncio.create_task(coro())
+ await asyncio.sleep(0)
+ self.assertEqual(barrier.n_waiting, 2)
+
+ t4 = asyncio.create_task(coro())
+ await asyncio.gather(t2, t3, t4)
+
+ self.assertEqual(len(results), self.N)
+ self.assertTrue(all(results))
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier(self):
+ barrier = asyncio.Barrier(1)
+
+ asyncio.create_task(barrier.reset())
+ await asyncio.sleep(0)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_while_tasks_waiting(self):
+ barrier = asyncio.Barrier(self.N)
+ results = []
+
+ async def coro():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ results.append(True)
+
+ async def coro_reset():
+ await barrier.reset()
+
+ # N-1 tasks waiting on barrier with N parties
+ tasks = self.make_tasks(self.N-1, coro)
+ await asyncio.sleep(0)
+
+ # reset the barrier
+ asyncio.create_task(coro_reset())
+ await asyncio.gather(*tasks)
+
+ self.assertEqual(len(results), self.N-1)
+ self.assertTrue(all(results))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_when_tasks_half_draining(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ rest_of_tasks = self.N//2
+
+ async def coro():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # catch here waiting tasks
+ results1.append(True)
+ else:
+ # here drained task ouside the barrier
+ if rest_of_tasks == barrier._count:
+ # tasks outside the barrier
+ await barrier.reset()
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(results1, [True]*rest_of_tasks)
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ blocking_tasks = self.N//2
+ count = 0
+
+ async def coro():
+ nonlocal count
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch still waiting tasks
+ results1.append(True)
+
+ # so now waiting again to reach nb_parties
+ await barrier.wait()
+ else:
+ count += 1
+ if count > blocking_tasks:
+ # reset now: raise asyncio.BrokenBarrierError for waiting tasks
+ await barrier.reset()
+
+ # so now waiting again to reach nb_parties
+ await barrier.wait()
+ else:
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here no catch - blocked tasks go to wait
+ results2.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertEqual(results1, [True]*blocking_tasks)
+ self.assertEqual(results2, [])
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+ self.assertFalse(barrier.broken)
+
+ async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro1():
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ results1.append(True)
+ finally:
+ await barrier.wait()
+ results2.append(True)
+
+ async def coro2():
+ async with barrier:
+ results2.append(True)
+
+ tasks = self.make_tasks(self.N-1, coro1)
+
+ # reset barrier, N-1 waiting tasks raise an BrokenBarrierError
+ asyncio.create_task(barrier.reset())
+ await asyncio.sleep(0)
+
+ # complete waiting tasks in the `finally`
+ asyncio.create_task(coro2())
+
+ await asyncio.gather(*tasks)
+
+ self.assertFalse(barrier.broken)
+ self.assertEqual(len(results1), self.N-1)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results2), self.N)
+ self.assertTrue(all(results2))
+
+ self.assertEqual(barrier.n_waiting, 0)
+
+
+ async def test_reset_barrier_while_tasks_draining(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ results3 = []
+ count = 0
+
+ async def coro():
+ nonlocal count
+
+ i = await barrier.wait()
+ count += 1
+ if count == self.N:
+ # last task exited from barrier
+ await barrier.reset()
+
+ # wit here to reach the `parties`
+ await barrier.wait()
+ else:
+ try:
+ # second waiting
+ await barrier.wait()
+
+ # N-1 tasks here
+ results1.append(True)
+ except Exception as e:
+ # never goes here
+ results2.append(True)
+
+ # Now, pass the barrier again
+ # last wait, must be completed
+ k = await barrier.wait()
+ results3.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertFalse(barrier.broken)
+ self.assertTrue(all(results1))
+ self.assertEqual(len(results1), self.N-1)
+ self.assertEqual(len(results2), 0)
+ self.assertEqual(len(results3), self.N)
+ self.assertTrue(all(results3))
+
+ self.assertEqual(barrier.n_waiting, 0)
+
+ async def test_abort_barrier(self):
+ barrier = asyncio.Barrier(1)
+
+ asyncio.create_task(barrier.abort())
+ await asyncio.sleep(0)
+
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertTrue(barrier.broken)
+
+ async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ blocking_tasks = self.N//2
+ count = 0
+
+ async def coro():
+ nonlocal count
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch tasks waiting to drain
+ results1.append(True)
+ else:
+ count += 1
+ if count > blocking_tasks:
+ # abort now: raise asyncio.BrokenBarrierError for all tasks
+ await barrier.abort()
+ else:
+ try:
+ await barrier.wait()
+ except asyncio.BrokenBarrierError:
+ # here catch blocked tasks (already drained)
+ results2.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertTrue(barrier.broken)
+ self.assertEqual(results1, [True]*blocking_tasks)
+ self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
+ self.assertEqual(barrier.n_waiting, 0)
+ self.assertNotIn("resetting", repr(barrier))
+
+ async def test_abort_barrier_when_exception(self):
+ # test from threading.Barrier: see `lock_tests.test_reset`
+ barrier = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+
+ async def coro():
+ try:
+ async with barrier as i :
+ if i == self.N//2:
+ raise RuntimeError
+ async with barrier:
+ results1.append(True)
+ except asyncio.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ await barrier.abort()
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertTrue(barrier.broken)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(all(results2))
+ self.assertEqual(barrier.n_waiting, 0)
+
+ async def test_abort_barrier_when_exception_then_resetting(self):
+ # test from threading.Barrier: see `lock_tests.test_abort_and_reset``
+ barrier1 = asyncio.Barrier(self.N)
+ barrier2 = asyncio.Barrier(self.N)
+ results1 = []
+ results2 = []
+ results3 = []
+
+ async def coro():
+ try:
+ i = await barrier1.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ await barrier1.wait()
+ results1.append(True)
+ except asyncio.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ await barrier1.abort()
+
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ i = await barrier2.wait()
+ if i == self.N//2:
+ await barrier1.reset()
+ await barrier2.wait()
+ await barrier1.wait()
+ results3.append(True)
+
+ await self.gather_tasks(self.N, coro)
+
+ self.assertFalse(barrier1.broken)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(all(results2))
+ self.assertEqual(len(results3), self.N)
+ self.assertTrue(all(results3))
+
+ self.assertEqual(barrier1.n_waiting, 0)
+
+
if __name__ == '__main__':
unittest.main()