]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add tornado.locks.Lock.
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Sat, 28 Feb 2015 17:08:46 +0000 (12:08 -0500)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Sat, 28 Feb 2015 17:08:46 +0000 (12:08 -0500)
tornado/locks.py
tornado/test/locks_test.py

index 1367e1de89e8a25762a8b4c30067f199592be1d4..c5ab59266a0edbd8c5f89b645259f16a1ad1c3b1 100644 (file)
@@ -14,7 +14,7 @@
 
 from __future__ import absolute_import, division, print_function, with_statement
 
-__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore']
+__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
 
 import collections
 
@@ -241,3 +241,63 @@ class BoundedSemaphore(Semaphore):
         if self._value >= self._initial_value:
             raise ValueError("Semaphore released too many times")
         super(BoundedSemaphore, self).release()
+
+
+class Lock(object):
+    """A lock for coroutines.
+
+    A Lock begins unlocked, and `acquire` locks it immediately. While it is
+    locked, a coroutine that yields `acquire` waits until another coroutine
+    calls `release`.
+
+    Releasing an unlocked lock raises `RuntimeError`.
+
+    `acquire` supports the context manager protocol:
+
+    >>> from tornado import gen, locks
+    >>> lock = locks.Lock()
+    >>>
+    >>> @gen.coroutine
+    ... def f():
+    ...    with (yield lock.acquire()):
+    ...        # Do something holding the lock.
+    ...        pass
+    ...
+    ...    # Now the lock is released.
+
+    Coroutines waiting for `acquire` are granted the lock in first-in, first-out
+    order.
+    """
+    def __init__(self):
+        self._block = BoundedSemaphore(value=1)
+
+    def __repr__(self):
+        return "<%s _block=%s>" % (
+            self.__class__.__name__,
+            self._block)
+
+    def acquire(self, deadline=None):
+        """Attempt to lock. Returns a Future.
+
+        Returns a Future, which raises `tornado.gen.TimeoutError` after a
+        timeout.
+        """
+        return self._block.acquire(deadline)
+
+    def release(self):
+        """Unlock.
+
+        The first coroutine in line waiting for `acquire` gets the lock.
+
+        If not locked, raise a `RuntimeError`.
+        """
+        try:
+            self._block.release()
+        except ValueError:
+            raise RuntimeError('release unlocked lock')
+
+    def __enter__(self):
+        raise RuntimeError(
+            "Use Lock like 'with (yield lock)', not like 'with lock'")
+
+    __exit__ = __enter__
index 29d80b94fd29cbcc7cc2e33bebfff82da34516b0..6da82bc44dbaad5ae48a84e6f5578b5b3e463a66 100644 (file)
@@ -364,5 +364,73 @@ class BoundedSemaphoreTest(AsyncTestCase):
         sem.release()
         self.assertRaises(ValueError, sem.release)
 
+
+class LockTests(AsyncTestCase):
+    def test_repr(self):
+        lock = locks.Lock()
+        # No errors.
+        repr(lock)
+        lock.acquire()
+        repr(lock)
+
+    def test_acquire_release(self):
+        lock = locks.Lock()
+        self.assertTrue(lock.acquire().done())
+        future = lock.acquire()
+        self.assertFalse(future.done())
+        lock.release()
+        self.assertTrue(future.done())
+
+    @gen_test
+    def test_acquire_fifo(self):
+        lock = locks.Lock()
+        self.assertTrue(lock.acquire().done())
+        N = 5
+        history = []
+
+        @gen.coroutine
+        def f(idx):
+            with (yield lock.acquire()):
+                history.append(idx)
+
+        futures = [f(i) for i in range(N)]
+        self.assertFalse(any(future.done() for future in futures))
+        lock.release()
+        yield futures
+        self.assertEqual(range(N), history)
+
+    @gen_test
+    def test_acquire_timeout(self):
+        lock = locks.Lock()
+        lock.acquire()
+        with self.assertRaises(gen.TimeoutError):
+            yield lock.acquire(deadline=timedelta(seconds=0.01))
+
+        # Still locked.
+        self.assertFalse(lock.acquire().done())
+
+    def test_multi_release(self):
+        lock = locks.Lock()
+        self.assertRaises(RuntimeError, lock.release)
+        lock.acquire()
+        lock.release()
+        self.assertRaises(RuntimeError, lock.release)
+
+    @gen_test
+    def test_yield_lock(self):
+        # Ensure we catch a "with (yield lock)", which should be
+        # "with (yield lock.acquire())".
+        with self.assertRaises(gen.BadYieldError):
+            with (yield locks.Lock()):
+                pass
+
+    def test_context_manager_misuse(self):
+        # Ensure we catch a "with lock", which should be
+        # "with (yield lock.acquire())".
+        with self.assertRaises(RuntimeError):
+            with locks.Lock():
+                pass
+
+
 if __name__ == '__main__':
     unittest.main()