From: A. Jesse Jiryu Davis Date: Fri, 20 Feb 2015 03:32:36 +0000 (-0500) Subject: Add tornado.locks.Semaphore. X-Git-Tag: v4.2.0b1~93^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=10fd949193f752c687ae52b2cc51ad344da3eca0;p=thirdparty%2Ftornado.git Add tornado.locks.Semaphore. --- diff --git a/tornado/locks.py b/tornado/locks.py index 73cc9e2b8..73a06c85b 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -14,9 +14,10 @@ from __future__ import absolute_import, division, print_function, with_statement -__all__ = ['Condition', 'Event'] +__all__ = ['Condition', 'Event', 'Semaphore'] import collections +import contextlib from tornado import gen, ioloop from tornado.concurrent import Future @@ -123,3 +124,122 @@ class Event(object): return self._future else: return gen.with_timeout(timeout, self._future) + + +class _ContextManagerFuture(Future): + """A Future that can be used with the "with" statement. + + When a coroutine yields this Future, the return value is a context manager + that can be used like: + + with (yield future): + pass + + At the end of the block, the Future's exit callback is run. Used for + Lock.acquire and Semaphore.acquire. + """ + def __init__(self, wrapped, exit_callback): + super(_ContextManagerFuture, self).__init__() + gen.chain_future(wrapped, self) + self.exit_callback = exit_callback + + def result(self, timeout=None): + if self.exception(): + raise self.exception() + + # Otherwise return a context manager that cleans up after the block. + @contextlib.contextmanager + def f(): + try: + yield + finally: + self.exit_callback() + return f() + + +class Semaphore(object): + """A lock that can be acquired a fixed number of times before blocking. + + A Semaphore manages a counter representing the number of `.release` calls + minus the number of `.acquire` calls, plus an initial value. The `.acquire` + method blocks if necessary until it can return without making the counter + negative. + + `.acquire` supports the context manager protocol: + + >>> from tornado import gen, locks + >>> semaphore = locks.Semaphore() + >>> @gen.coroutine + ... def f(): + ... with (yield semaphore.acquire()): + ... assert semaphore.locked() + ... + ... assert not semaphore.locked() + + .. note:: Unlike the standard `threading.Semaphore`, a Tornado `.Semaphore` + can tell you the current value of its `.counter`, because code in a + single-threaded Tornado application can check this value and act upon + it without fear of interruption from another thread. + """ + def __init__(self, value=1): + if value < 0: + raise ValueError('semaphore initial value must be >= 0') + + self.io_loop = ioloop.IOLoop.current() + self._value = value + self._waiters = collections.deque() + + def __repr__(self): + res = super(Semaphore, self).__repr__() + extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format( + self._value) + if self._waiters: + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) + + @property + def counter(self): + """An integer, the current semaphore value.""" + return self._value + + def locked(self): + """True if the semaphore cannot be acquired immediately.""" + return self._value == 0 + + def release(self): + """Increment `.counter` and wake one waiter.""" + self._value += 1 + for waiter in self._waiters: + if not waiter.done(): + self._value -= 1 + waiter.set_result(None) + break + + def acquire(self, timeout=None): + """Decrement `.counter`. Returns a Future. + + Block if the counter is zero and wait for a `.release`. The Future + raises `.TimeoutError` after the deadline. + """ + if self._value > 0: + self._value -= 1 + future = gen._null_future + else: + waiter = Future() + self._waiters.append(waiter) + if timeout: + future = gen.with_timeout(timeout, waiter, self.io_loop, + quiet_exceptions=gen.TimeoutError) + + # Set waiter's exception after the deadline. + gen.chain_future(future, waiter) + else: + future = waiter + return _ContextManagerFuture(future, self.release) + + def __enter__(self): + raise RuntimeError( + "Use Semaphore like 'with (yield semaphore.acquire())', not like" + " 'with semaphore'") + + __exit__ = __enter__ diff --git a/tornado/test/locks_test.py b/tornado/test/locks_test.py index d1c9a61d8..cb0365a29 100644 --- a/tornado/test/locks_test.py +++ b/tornado/test/locks_test.py @@ -216,5 +216,122 @@ class TestEvent(AsyncTestCase): self.assertTrue(f1.done()) +class SemaphoreTest(AsyncTestCase): + def test_negative_value(self): + self.assertRaises(ValueError, locks.Semaphore, value=-1) + + def test_str(self): + sem = locks.Semaphore() + self.assertIn('Semaphore', str(sem)) + self.assertIn('unlocked,value:1', str(sem)) + sem.acquire() + self.assertIn('locked', str(sem)) + self.assertNotIn('waiters', str(sem)) + sem.acquire() + self.assertIn('waiters', str(sem)) + + def test_acquire(self): + sem = locks.Semaphore() + self.assertFalse(sem.locked()) + f0 = sem.acquire() + self.assertTrue(f0.done()) + self.assertTrue(sem.locked()) + + # Wait for release(). + f1 = sem.acquire() + f2 = sem.acquire() + sem.release() + self.assertTrue(f1.done()) + self.assertFalse(f2.done()) + sem.release() + self.assertTrue(f2.done()) + + sem.release() + # Now acquire() is instant. + self.assertTrue(sem.acquire().done()) + + @gen_test + def test_acquire_timeout(self): + sem = locks.Semaphore(2) + yield sem.acquire() + yield sem.acquire() + with self.assertRaises(gen.TimeoutError): + yield sem.acquire(timedelta(seconds=0.01)) + + f = sem.acquire() + sem.release() + self.assertTrue(f.done()) + + def test_release_unacquired(self): + # Unbounded releases are allowed, and increment the semaphore's value. + sem = locks.Semaphore() + sem.release() + sem.release() + self.assertEqual(3, sem.counter) + + +class SemaphoreContextManagerTest(AsyncTestCase): + @gen_test + def test_context_manager(self): + sem = locks.Semaphore() + with (yield sem.acquire()) as yielded: + self.assertTrue(sem.locked()) + self.assertTrue(yielded is None) + + self.assertFalse(sem.locked()) + + @gen_test + def test_context_manager_exception(self): + sem = locks.Semaphore() + with self.assertRaises(ZeroDivisionError): + with (yield sem.acquire()): + 1 / 0 + + # Context manager released semaphore. + self.assertFalse(sem.locked()) + + @gen_test + def test_context_manager_timeout(self): + sem = locks.Semaphore(value=0) + with self.assertRaises(gen.TimeoutError): + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + @gen_test + def test_context_manager_contended(self): + sem = locks.Semaphore() + history = [] + + @gen.coroutine + def f(index): + with (yield sem.acquire()): + history.append('acquired %d' % index) + yield gen.sleep(0.01) + history.append('release %d' % index) + + yield [f(i) for i in range(2)] + + expected_history = [] + for i in range(2): + expected_history.extend(['acquired %d' % i, 'release %d' % i]) + + self.assertEqual(expected_history, history) + + @gen_test + def test_yield_sem(self): + # Ensure we catch a "with (yield sem)", which should be + # "with (yield sem.acquire())". + with self.assertRaises(gen.BadYieldError): + with (yield locks.Semaphore()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with sem", which should be + # "with (yield sem.acquire())". + with self.assertRaises(RuntimeError): + with locks.Semaphore(): + pass + + if __name__ == '__main__': unittest.main()