]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add tornado.locks.Semaphore.
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Fri, 20 Feb 2015 03:32:36 +0000 (22:32 -0500)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Fri, 20 Feb 2015 03:34:19 +0000 (22:34 -0500)
tornado/locks.py
tornado/test/locks_test.py

index 73cc9e2b8b8cfb6c379ff0e3f5f249c8e3d23592..73a06c85bc6fb25df7b37c48a9e11d2a5a05786d 100644 (file)
 
 from __future__ import absolute_import, division, print_function, with_statement
 
-__all__ = ['Condition', 'Event']
+__all__ = ['Condition', 'Event', 'Semaphore']
 
 import collections
+import contextlib
 
 from tornado import gen, ioloop
 from tornado.concurrent import Future
@@ -123,3 +124,122 @@ class Event(object):
             return self._future
         else:
             return gen.with_timeout(timeout, self._future)
+
+
+class _ContextManagerFuture(Future):
+    """A Future that can be used with the "with" statement.
+
+    When a coroutine yields this Future, the return value is a context manager
+    that can be used like:
+
+        with (yield future):
+            pass
+
+    At the end of the block, the Future's exit callback is run. Used for
+    Lock.acquire and Semaphore.acquire.
+    """
+    def __init__(self, wrapped, exit_callback):
+        super(_ContextManagerFuture, self).__init__()
+        gen.chain_future(wrapped, self)
+        self.exit_callback = exit_callback
+
+    def result(self, timeout=None):
+        if self.exception():
+            raise self.exception()
+
+        # Otherwise return a context manager that cleans up after the block.
+        @contextlib.contextmanager
+        def f():
+            try:
+                yield
+            finally:
+                self.exit_callback()
+        return f()
+
+
+class Semaphore(object):
+    """A lock that can be acquired a fixed number of times before blocking.
+
+    A Semaphore manages a counter representing the number of `.release` calls
+    minus the number of `.acquire` calls, plus an initial value. The `.acquire`
+    method blocks if necessary until it can return without making the counter
+    negative.
+
+    `.acquire` supports the context manager protocol:
+
+    >>> from tornado import gen, locks
+    >>> semaphore = locks.Semaphore()
+    >>> @gen.coroutine
+    ... def f():
+    ...    with (yield semaphore.acquire()):
+    ...        assert semaphore.locked()
+    ...
+    ...    assert not semaphore.locked()
+
+    .. note:: Unlike the standard `threading.Semaphore`, a Tornado `.Semaphore`
+      can tell you the current value of its `.counter`, because code in a
+      single-threaded Tornado application can check this value and act upon
+      it without fear of interruption from another thread.
+    """
+    def __init__(self, value=1):
+        if value < 0:
+            raise ValueError('semaphore initial value must be >= 0')
+
+        self.io_loop = ioloop.IOLoop.current()
+        self._value = value
+        self._waiters = collections.deque()
+
+    def __repr__(self):
+        res = super(Semaphore, self).__repr__()
+        extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format(
+            self._value)
+        if self._waiters:
+            extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
+        return '<{0} [{1}]>'.format(res[1:-1], extra)
+
+    @property
+    def counter(self):
+        """An integer, the current semaphore value."""
+        return self._value
+
+    def locked(self):
+        """True if the semaphore cannot be acquired immediately."""
+        return self._value == 0
+
+    def release(self):
+        """Increment `.counter` and wake one waiter."""
+        self._value += 1
+        for waiter in self._waiters:
+            if not waiter.done():
+                self._value -= 1
+                waiter.set_result(None)
+                break
+
+    def acquire(self, timeout=None):
+        """Decrement `.counter`. Returns a Future.
+
+        Block if the counter is zero and wait for a `.release`. The Future
+        raises `.TimeoutError` after the deadline.
+        """
+        if self._value > 0:
+            self._value -= 1
+            future = gen._null_future
+        else:
+            waiter = Future()
+            self._waiters.append(waiter)
+            if timeout:
+                future = gen.with_timeout(timeout, waiter, self.io_loop,
+                                          quiet_exceptions=gen.TimeoutError)
+
+                # Set waiter's exception after the deadline.
+                gen.chain_future(future, waiter)
+            else:
+                future = waiter
+        return _ContextManagerFuture(future, self.release)
+
+    def __enter__(self):
+        raise RuntimeError(
+            "Use Semaphore like 'with (yield semaphore.acquire())', not like"
+            " 'with semaphore'")
+
+    __exit__ = __enter__
index d1c9a61d8c3b0363cedd8ef7931246ab89be26aa..cb0365a29a479eac94738d8b18b2b5b9ec6f438b 100644 (file)
@@ -216,5 +216,122 @@ class TestEvent(AsyncTestCase):
         self.assertTrue(f1.done())
 
 
+class SemaphoreTest(AsyncTestCase):
+    def test_negative_value(self):
+        self.assertRaises(ValueError, locks.Semaphore, value=-1)
+
+    def test_str(self):
+        sem = locks.Semaphore()
+        self.assertIn('Semaphore', str(sem))
+        self.assertIn('unlocked,value:1', str(sem))
+        sem.acquire()
+        self.assertIn('locked', str(sem))
+        self.assertNotIn('waiters', str(sem))
+        sem.acquire()
+        self.assertIn('waiters', str(sem))
+
+    def test_acquire(self):
+        sem = locks.Semaphore()
+        self.assertFalse(sem.locked())
+        f0 = sem.acquire()
+        self.assertTrue(f0.done())
+        self.assertTrue(sem.locked())
+
+        # Wait for release().
+        f1 = sem.acquire()
+        f2 = sem.acquire()
+        sem.release()
+        self.assertTrue(f1.done())
+        self.assertFalse(f2.done())
+        sem.release()
+        self.assertTrue(f2.done())
+
+        sem.release()
+        # Now acquire() is instant.
+        self.assertTrue(sem.acquire().done())
+
+    @gen_test
+    def test_acquire_timeout(self):
+        sem = locks.Semaphore(2)
+        yield sem.acquire()
+        yield sem.acquire()
+        with self.assertRaises(gen.TimeoutError):
+            yield sem.acquire(timedelta(seconds=0.01))
+
+        f = sem.acquire()
+        sem.release()
+        self.assertTrue(f.done())
+
+    def test_release_unacquired(self):
+        # Unbounded releases are allowed, and increment the semaphore's value.
+        sem = locks.Semaphore()
+        sem.release()
+        sem.release()
+        self.assertEqual(3, sem.counter)
+
+
+class SemaphoreContextManagerTest(AsyncTestCase):
+    @gen_test
+    def test_context_manager(self):
+        sem = locks.Semaphore()
+        with (yield sem.acquire()) as yielded:
+            self.assertTrue(sem.locked())
+            self.assertTrue(yielded is None)
+
+        self.assertFalse(sem.locked())
+
+    @gen_test
+    def test_context_manager_exception(self):
+        sem = locks.Semaphore()
+        with self.assertRaises(ZeroDivisionError):
+            with (yield sem.acquire()):
+                1 / 0
+
+        # Context manager released semaphore.
+        self.assertFalse(sem.locked())
+
+    @gen_test
+    def test_context_manager_timeout(self):
+        sem = locks.Semaphore(value=0)
+        with self.assertRaises(gen.TimeoutError):
+            with (yield sem.acquire(timedelta(seconds=0.01))):
+                pass
+
+    @gen_test
+    def test_context_manager_contended(self):
+        sem = locks.Semaphore()
+        history = []
+
+        @gen.coroutine
+        def f(index):
+            with (yield sem.acquire()):
+                history.append('acquired %d' % index)
+                yield gen.sleep(0.01)
+                history.append('release %d' % index)
+
+        yield [f(i) for i in range(2)]
+
+        expected_history = []
+        for i in range(2):
+            expected_history.extend(['acquired %d' % i, 'release %d' % i])
+
+        self.assertEqual(expected_history, history)
+
+    @gen_test
+    def test_yield_sem(self):
+        # Ensure we catch a "with (yield sem)", which should be
+        # "with (yield sem.acquire())".
+        with self.assertRaises(gen.BadYieldError):
+            with (yield locks.Semaphore()):
+                pass
+
+    def test_context_manager_misuse(self):
+        # Ensure we catch a "with sem", which should be
+        # "with (yield sem.acquire())".
+        with self.assertRaises(RuntimeError):
+            with locks.Semaphore():
+                pass
+
+
 if __name__ == '__main__':
     unittest.main()