From a7454fc616f80b035186ba189ca268a4575eb4fd Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 26 Feb 2015 22:11:00 -0500 Subject: [PATCH] Add tornado.locks.BoundedSemaphore. --- tornado/locks.py | 20 +++++++++++++++++++- tornado/test/locks_test.py | 15 +++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tornado/locks.py b/tornado/locks.py index 49667cead..1367e1de8 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -14,7 +14,7 @@ from __future__ import absolute_import, division, print_function, with_statement -__all__ = ['Condition', 'Event', 'Semaphore'] +__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore'] import collections @@ -223,3 +223,21 @@ class Semaphore(object): " 'with semaphore'") __exit__ = __enter__ + + +class BoundedSemaphore(Semaphore): + """A semaphore that prevents release() being called too many times. + + If `.release` would increment the semaphore's value past the initial + value, it raises `ValueError`. Semaphores are mostly used to guard + resources with limited capacity, so a semaphore released too many times + is a sign of a bug. + """ + def __init__(self, value=1): + super(BoundedSemaphore, self).__init__(value=value) + self._initial_value = value + + def release(self): + if self._value >= self._initial_value: + raise ValueError("Semaphore released too many times") + super(BoundedSemaphore, self).release() diff --git a/tornado/test/locks_test.py b/tornado/test/locks_test.py index 38d2bafb8..29d80b94f 100644 --- a/tornado/test/locks_test.py +++ b/tornado/test/locks_test.py @@ -349,5 +349,20 @@ class SemaphoreContextManagerTest(AsyncTestCase): pass +class BoundedSemaphoreTest(AsyncTestCase): + def test_release_unacquired(self): + sem = locks.BoundedSemaphore() + self.assertRaises(ValueError, sem.release) + # Value is 0. + sem.acquire() + # Block on acquire(). + future = sem.acquire() + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + # Value is 1. + sem.release() + self.assertRaises(ValueError, sem.release) + if __name__ == '__main__': unittest.main() -- 2.47.2