From 32acb5c06050ed702dfd6d8c0c854f501ce0bb41 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Sat, 14 Feb 2015 14:30:48 -0500 Subject: [PATCH] Add tornado.locks.Condition. Copied from Toro with small improvements. --- docs/coroutine.rst | 1 + docs/locks.rst | 7 ++ tornado/locks.py | 81 +++++++++++++++++ tornado/test/condition_test.py | 161 +++++++++++++++++++++++++++++++++ tornado/test/runtests.py | 1 + 5 files changed, 251 insertions(+) create mode 100644 docs/locks.rst create mode 100644 tornado/locks.py create mode 100644 tornado/test/condition_test.py diff --git a/docs/coroutine.rst b/docs/coroutine.rst index 8412f734d..41309e021 100644 --- a/docs/coroutine.rst +++ b/docs/coroutine.rst @@ -5,4 +5,5 @@ Coroutines and concurrency gen concurrent + locks process diff --git a/docs/locks.rst b/docs/locks.rst new file mode 100644 index 000000000..11191c226 --- /dev/null +++ b/docs/locks.rst @@ -0,0 +1,7 @@ +``tornado.locks`` -- Synchronization primitives +=============================================== + +.. versionadded:: 4.2 + +.. automodule:: tornado.locks + :members: diff --git a/tornado/locks.py b/tornado/locks.py new file mode 100644 index 000000000..1833b04f4 --- /dev/null +++ b/tornado/locks.py @@ -0,0 +1,81 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, with_statement + +__all__ = ['Condition'] + +import collections + +from tornado import concurrent, gen, ioloop +from tornado.concurrent import Future + + +class Condition(object): + """A condition allows one or more coroutines to wait until notified. + + Like a standard `threading.Condition`, but does not need an underlying lock + that is acquired and released. + """ + + def __init__(self): + self.io_loop = ioloop.IOLoop.current() + self._waiters = collections.deque() # Futures. + self._timeouts = 0 + + def __str__(self): + result = '<%s' % (self.__class__.__name__, ) + if self._waiters: + result += ' waiters[%s]' % len(self._waiters) + return result + '>' + + def wait(self, timeout=None): + """Wait for `.notify`. + + Returns a `.Future` that resolves ``True`` if the condition is notified, + or ``False`` after a timeout. + """ + waiter = Future() + self._waiters.append(waiter) + if timeout: + def on_timeout(): + waiter.set_result(False) + self._garbage_collect() + self.io_loop.add_timeout(timeout, on_timeout) + return waiter + + def notify(self, n=1): + """Wake ``n`` waiters.""" + waiters = [] # Waiters we plan to run right now. + while n and self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): # Might have timed out. + n -= 1 + waiters.append(waiter) + + for waiter in waiters: + waiter.set_result(True) + + def notify_all(self): + """Wake all waiters.""" + self.notify(len(self._waiters)) + + def _garbage_collect(self): + # Occasionally clear timed-out waiters, if many coroutines wait with a + # timeout but notify is called rarely. + self._timeouts += 1 + if self._timeouts > 100: + self._timeouts = 0 + self._waiters = collections.deque( + w for w in self._waiters if not w.done()) diff --git a/tornado/test/condition_test.py b/tornado/test/condition_test.py new file mode 100644 index 000000000..1ad828907 --- /dev/null +++ b/tornado/test/condition_test.py @@ -0,0 +1,161 @@ +from __future__ import absolute_import, division, print_function, with_statement + +from datetime import timedelta + +from tornado import gen, locks +from tornado.testing import AsyncTestCase, gen_test, unittest + + +class ConditionTest(AsyncTestCase): + def setUp(self): + super(ConditionTest, self).setUp() + self.history = [] + + def record_done(self, future, key): + """Record the resolution of a Future returned by Condition.wait.""" + def callback(_): + if not future.result(): + # wait() resolved to False, meaning it timed out. + self.history.append('timeout') + else: + self.history.append(key) + future.add_done_callback(callback) + + def test_str(self): + c = locks.Condition() + self.assertIn('Condition', str(c)) + self.assertNotIn('waiters', str(c)) + c.wait() + self.assertIn('waiters', str(c)) + + @gen_test + def test_notify(self): + c = locks.Condition() + self.io_loop.call_later(0.01, c.notify) + yield c.wait() + + def test_notify_1(self): + c = locks.Condition() + self.record_done(c.wait(), 'wait1') + self.record_done(c.wait(), 'wait2') + c.notify(1) + self.history.append('notify1') + c.notify(1) + self.history.append('notify2') + self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'], + self.history) + + def test_notify_n(self): + c = locks.Condition() + for i in range(6): + self.record_done(c.wait(), i) + + c.notify(3) + + # Callbacks execute in the order they were registered. + self.assertEqual(list(range(3)), self.history) + c.notify(1) + self.assertEqual(list(range(4)), self.history) + c.notify(2) + self.assertEqual(list(range(6)), self.history) + + def test_notify_all(self): + c = locks.Condition() + for i in range(4): + self.record_done(c.wait(), i) + + c.notify_all() + self.history.append('notify_all') + + # Callbacks execute in the order they were registered. + self.assertEqual( + list(range(4)) + ['notify_all'], + self.history) + + @gen_test + def test_wait_timeout(self): + c = locks.Condition() + self.assertFalse((yield c.wait(timedelta(seconds=0.01)))) + + @gen_test + def test_wait_timeout_preempted(self): + c = locks.Condition() + + # This fires before the wait times out. + self.io_loop.call_later(0.01, c.notify) + yield c.wait(timedelta(seconds=1)) + + @gen_test + def test_notify_n_with_timeout(self): + # Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout. + # Wait for that timeout to expire, then do notify(2) and make + # sure everyone runs. Verifies that a timed-out callback does + # not count against the 'n' argument to notify(). + c = locks.Condition() + self.record_done(c.wait(), 0) + self.record_done(c.wait(timedelta(seconds=0.01)), 1) + self.record_done(c.wait(), 2) + self.record_done(c.wait(), 3) + + # Wait for callback 1 to time out. + yield gen.sleep(0.02) + self.assertEqual(['timeout'], self.history) + + c.notify(2) + yield gen.sleep(0.01) + self.assertEqual(['timeout', 0, 2], self.history) + self.assertEqual(['timeout', 0, 2], self.history) + c.notify() + self.assertEqual(['timeout', 0, 2, 3], self.history) + + @gen_test + def test_notify_all_with_timeout(self): + c = locks.Condition() + self.record_done(c.wait(), 0) + self.record_done(c.wait(timedelta(seconds=0.01)), 1) + self.record_done(c.wait(), 2) + + # Wait for callback 1 to time out. + yield gen.sleep(0.02) + self.assertEqual(['timeout'], self.history) + + c.notify_all() + self.assertEqual(['timeout', 0, 2], self.history) + + @gen_test + def test_nested_notify(self): + # Ensure no notifications lost, even if notify() is reentered by a + # waiter calling notify(). + c = locks.Condition() + + # Three waiters. + futures = [c.wait() for _ in range(3)] + + # First and second futures resolved. Second future reenters notify(), + # resolving third future. + futures[1].add_done_callback(lambda _: c.notify()) + c.notify(2) + self.assertTrue(all(f.done() for f in futures)) + + @gen_test + def test_garbage_collection(self): + # Test that timed-out waiters are occasionally cleaned from the queue. + c = locks.Condition() + for _ in range(101): + c.wait(timedelta(seconds=0.01)) + + future = c.wait() + self.assertEqual(102, len(c._waiters)) + + # Let first 101 waiters time out, triggering a collection. + yield gen.sleep(0.02) + self.assertEqual(1, len(c._waiters)) + + # Final waiter is still active. + self.assertFalse(future.done()) + c.notify() + self.assertTrue(future.done()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index acbb5695e..3f4aaa1e6 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -25,6 +25,7 @@ TEST_MODULES = [ 'tornado.test.asyncio_test', 'tornado.test.auth_test', 'tornado.test.concurrent_test', + 'tornado.test.condition_test', 'tornado.test.curl_httpclient_test', 'tornado.test.escape_test', 'tornado.test.gen_test', -- 2.47.2