from __future__ import absolute_import, division, print_function, with_statement
-__all__ = ['Condition', 'Event']
+__all__ = ['Condition', 'Event', 'Semaphore']
import collections
+import contextlib
from tornado import gen, ioloop
from tornado.concurrent import Future
return self._future
else:
return gen.with_timeout(timeout, self._future)
+
+
+class _ContextManagerFuture(Future):
+ """A Future that can be used with the "with" statement.
+
+ When a coroutine yields this Future, the return value is a context manager
+ that can be used like:
+
+ with (yield future):
+ pass
+
+ At the end of the block, the Future's exit callback is run. Used for
+ Lock.acquire and Semaphore.acquire.
+ """
+ def __init__(self, wrapped, exit_callback):
+ super(_ContextManagerFuture, self).__init__()
+ gen.chain_future(wrapped, self)
+ self.exit_callback = exit_callback
+
+ def result(self, timeout=None):
+ if self.exception():
+ raise self.exception()
+
+ # Otherwise return a context manager that cleans up after the block.
+ @contextlib.contextmanager
+ def f():
+ try:
+ yield
+ finally:
+ self.exit_callback()
+ return f()
+
+
+class Semaphore(object):
+ """A lock that can be acquired a fixed number of times before blocking.
+
+ A Semaphore manages a counter representing the number of `.release` calls
+ minus the number of `.acquire` calls, plus an initial value. The `.acquire`
+ method blocks if necessary until it can return without making the counter
+ negative.
+
+ `.acquire` supports the context manager protocol:
+
+ >>> from tornado import gen, locks
+ >>> semaphore = locks.Semaphore()
+ >>> @gen.coroutine
+ ... def f():
+ ... with (yield semaphore.acquire()):
+ ... assert semaphore.locked()
+ ...
+ ... assert not semaphore.locked()
+
+ .. note:: Unlike the standard `threading.Semaphore`, a Tornado `.Semaphore`
+ can tell you the current value of its `.counter`, because code in a
+ single-threaded Tornado application can check this value and act upon
+ it without fear of interruption from another thread.
+ """
+ def __init__(self, value=1):
+ if value < 0:
+ raise ValueError('semaphore initial value must be >= 0')
+
+ self.io_loop = ioloop.IOLoop.current()
+ self._value = value
+ self._waiters = collections.deque()
+
+ def __repr__(self):
+ res = super(Semaphore, self).__repr__()
+ extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format(
+ self._value)
+ if self._waiters:
+ extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
+ return '<{0} [{1}]>'.format(res[1:-1], extra)
+
+ @property
+ def counter(self):
+ """An integer, the current semaphore value."""
+ return self._value
+
+ def locked(self):
+ """True if the semaphore cannot be acquired immediately."""
+ return self._value == 0
+
+ def release(self):
+ """Increment `.counter` and wake one waiter."""
+ self._value += 1
+ for waiter in self._waiters:
+ if not waiter.done():
+ self._value -= 1
+ waiter.set_result(None)
+ break
+
+ def acquire(self, timeout=None):
+ """Decrement `.counter`. Returns a Future.
+
+ Block if the counter is zero and wait for a `.release`. The Future
+ raises `.TimeoutError` after the deadline.
+ """
+ if self._value > 0:
+ self._value -= 1
+ future = gen._null_future
+ else:
+ waiter = Future()
+ self._waiters.append(waiter)
+ if timeout:
+ future = gen.with_timeout(timeout, waiter, self.io_loop,
+ quiet_exceptions=gen.TimeoutError)
+
+ # Set waiter's exception after the deadline.
+ gen.chain_future(future, waiter)
+ else:
+ future = waiter
+ return _ContextManagerFuture(future, self.release)
+
+ def __enter__(self):
+ raise RuntimeError(
+ "Use Semaphore like 'with (yield semaphore.acquire())', not like"
+ " 'with semaphore'")
+
+ __exit__ = __enter__
self.assertTrue(f1.done())
+class SemaphoreTest(AsyncTestCase):
+ def test_negative_value(self):
+ self.assertRaises(ValueError, locks.Semaphore, value=-1)
+
+ def test_str(self):
+ sem = locks.Semaphore()
+ self.assertIn('Semaphore', str(sem))
+ self.assertIn('unlocked,value:1', str(sem))
+ sem.acquire()
+ self.assertIn('locked', str(sem))
+ self.assertNotIn('waiters', str(sem))
+ sem.acquire()
+ self.assertIn('waiters', str(sem))
+
+ def test_acquire(self):
+ sem = locks.Semaphore()
+ self.assertFalse(sem.locked())
+ f0 = sem.acquire()
+ self.assertTrue(f0.done())
+ self.assertTrue(sem.locked())
+
+ # Wait for release().
+ f1 = sem.acquire()
+ f2 = sem.acquire()
+ sem.release()
+ self.assertTrue(f1.done())
+ self.assertFalse(f2.done())
+ sem.release()
+ self.assertTrue(f2.done())
+
+ sem.release()
+ # Now acquire() is instant.
+ self.assertTrue(sem.acquire().done())
+
+ @gen_test
+ def test_acquire_timeout(self):
+ sem = locks.Semaphore(2)
+ yield sem.acquire()
+ yield sem.acquire()
+ with self.assertRaises(gen.TimeoutError):
+ yield sem.acquire(timedelta(seconds=0.01))
+
+ f = sem.acquire()
+ sem.release()
+ self.assertTrue(f.done())
+
+ def test_release_unacquired(self):
+ # Unbounded releases are allowed, and increment the semaphore's value.
+ sem = locks.Semaphore()
+ sem.release()
+ sem.release()
+ self.assertEqual(3, sem.counter)
+
+
+class SemaphoreContextManagerTest(AsyncTestCase):
+ @gen_test
+ def test_context_manager(self):
+ sem = locks.Semaphore()
+ with (yield sem.acquire()) as yielded:
+ self.assertTrue(sem.locked())
+ self.assertTrue(yielded is None)
+
+ self.assertFalse(sem.locked())
+
+ @gen_test
+ def test_context_manager_exception(self):
+ sem = locks.Semaphore()
+ with self.assertRaises(ZeroDivisionError):
+ with (yield sem.acquire()):
+ 1 / 0
+
+ # Context manager released semaphore.
+ self.assertFalse(sem.locked())
+
+ @gen_test
+ def test_context_manager_timeout(self):
+ sem = locks.Semaphore(value=0)
+ with self.assertRaises(gen.TimeoutError):
+ with (yield sem.acquire(timedelta(seconds=0.01))):
+ pass
+
+ @gen_test
+ def test_context_manager_contended(self):
+ sem = locks.Semaphore()
+ history = []
+
+ @gen.coroutine
+ def f(index):
+ with (yield sem.acquire()):
+ history.append('acquired %d' % index)
+ yield gen.sleep(0.01)
+ history.append('release %d' % index)
+
+ yield [f(i) for i in range(2)]
+
+ expected_history = []
+ for i in range(2):
+ expected_history.extend(['acquired %d' % i, 'release %d' % i])
+
+ self.assertEqual(expected_history, history)
+
+ @gen_test
+ def test_yield_sem(self):
+ # Ensure we catch a "with (yield sem)", which should be
+ # "with (yield sem.acquire())".
+ with self.assertRaises(gen.BadYieldError):
+ with (yield locks.Semaphore()):
+ pass
+
+ def test_context_manager_misuse(self):
+ # Ensure we catch a "with sem", which should be
+ # "with (yield sem.acquire())".
+ with self.assertRaises(RuntimeError):
+ with locks.Semaphore():
+ pass
+
+
if __name__ == '__main__':
unittest.main()