git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add tornado.queues.
author: A. Jesse Jiryu Davis <jesse@mongodb.com>
Sat, 14 Mar 2015 21:52:39 +0000 (17:52 -0400)
committer: A. Jesse Jiryu Davis <jesse@mongodb.com>
Sun, 15 Mar 2015 01:13:25 +0000 (21:13 -0400)
docs/coroutine.rst
docs/queues.rst [new file with mode: 0644]
tornado/queues.py [new file with mode: 0644]
tornado/test/queues_test.py [new file with mode: 0644]
tornado/test/runtests.py

diff --git a/docs/coroutine.rst b/docs/coroutine.rst
index 41309e021990c0dddb133092e10c0148cdc3554a..4db7100951fb354eb8f9152c57636384c5339081 100644 (file)
@@ -6,4 +6,5 @@ Coroutines and concurrency
    gen
    concurrent
    locks
+   queues
    process
diff --git a/docs/queues.rst b/docs/queues.rst
new file mode 100644 (file)
index 0000000..01df494
--- /dev/null
@@ -0,0 +1,7 @@
+``tornado.queues`` -- Queues for coroutines
+===========================================
+
+.. versionadded:: 4.2
+
+.. automodule:: tornado.queues
+    :members:
diff --git a/tornado/queues.py b/tornado/queues.py
new file mode 100644 (file)
index 0000000..e2aec50
--- /dev/null
@@ -0,0 +1,227 @@
+# Copyright 2015 The Tornado Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import, division, print_function, with_statement
+
+__all__ = ['Queue', 'QueueFull', 'QueueEmpty']
+
+import collections
+from functools import partial
+
+from tornado import gen, ioloop
+from tornado.concurrent import Future
+from tornado.locks import Event
+
+
+class QueueEmpty(Exception):
+    """Raised by `.Queue.get_nowait` when the queue has no items."""
+    pass
+
+
+class QueueFull(Exception):
+    """Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
+    pass
+
+
+class Queue(object):
+    """Coordinate producer and consumer coroutines.
+
+    If maxsize is 0 (the default) the queue size is unbounded.
+    """
+    def __init__(self, maxsize=0):
+        if maxsize is None:
+            raise TypeError("maxsize can't be None")
+
+        if maxsize < 0:
+            raise ValueError("maxsize can't be negative")
+
+        self._maxsize = maxsize
+        self._init()
+        self._getters = collections.deque([])  # Futures.
+        self._putters = collections.deque([])  # Pairs of (item, Future).
+        self._unfinished_tasks = 0
+        self._finished = Event()
+        self._finished.set()
+
+    @property
+    def maxsize(self):
+        """Number of items allowed in the queue."""
+        return self._maxsize
+
+    def qsize(self):
+        """Number of items in the queue."""
+        return len(self._queue)
+
+    def empty(self):
+        return not self._queue
+
+    def full(self):
+        if self.maxsize == 0:
+            return False
+        else:
+            return self.qsize() >= self.maxsize
+
+    def put(self, item, timeout=None):
+        """Put an item into the queue, perhaps waiting until there is room.
+
+        Returns a Future that resolves once the item has been accepted, or
+        raises `tornado.gen.TimeoutError` after a timeout. ``timeout`` is
+        forwarded unchanged to `.IOLoop.add_timeout` and only applies when
+        the queue is full.
+        """
+        self._consume_expired()
+        if self._getters:
+            assert not self._queue, "queue non-empty, why are getters waiting?"
+            getter = self._getters.popleft()
+            # Route the item through the queue so _put() still counts it as
+            # an unfinished task, then hand it to the oldest waiting getter.
+            self._put(item)
+            getter.set_result(self._get())
+            return gen._null_future
+        elif self.full():
+            future = Future()
+            self._putters.append((item, future))
+            if timeout:
+                io_loop = ioloop.IOLoop.current()
+
+                def on_timeout():
+                    # Never overwrite the outcome of a putter that a get()
+                    # has already unblocked.
+                    if not future.done():
+                        future.set_exception(gen.TimeoutError())
+
+                handle = io_loop.add_timeout(timeout, on_timeout)
+                # Cancel the pending timeout as soon as the put completes so
+                # the callback cannot fire on an already-resolved Future.
+                future.add_done_callback(
+                    lambda _: io_loop.remove_timeout(handle))
+            return future
+        else:
+            self._put(item)
+            return gen._null_future
+
+    def put_nowait(self, item):
+        """Put an item into the queue without blocking.
+
+        If no free slot is immediately available, raise `QueueFull`.
+        """
+        self._consume_expired()
+        if self._getters:
+            assert self.empty(), "queue non-empty, why are getters waiting?"
+            getter = self._getters.popleft()
+
+            self._put(item)
+            getter.set_result(self._get())
+        elif self.full():
+            raise QueueFull
+        else:
+            self._put(item)
+
+    def get(self, timeout=None):
+        """Remove and return an item from the queue.
+
+        Returns a Future which resolves once an item is available, or raises
+        `tornado.gen.TimeoutError` after a timeout. ``timeout`` is forwarded
+        unchanged to `.IOLoop.add_timeout` and only applies when no item is
+        immediately available.
+        """
+        self._consume_expired()
+        if self._putters:
+            assert self.full(), "queue not full, why are putters waiting?"
+            # Admit the oldest blocked putter's item; the item handed out
+            # below frees the slot it occupies.
+            item, putter = self._putters.popleft()
+            self._put(item)
+            putter.set_result(None)
+
+        future = Future()
+        if self.qsize():
+            # Defer unblocking the getter, which might do task_done() without
+            # yielding first. We want putters to have a chance to run first and
+            # keep join() blocked. See test_producer_consumer().
+            ioloop.IOLoop.current().add_callback(
+                partial(future.set_result, self._get()))
+        else:
+            self._getters.append(future)
+            if timeout:
+                io_loop = ioloop.IOLoop.current()
+
+                def on_timeout():
+                    # Never overwrite the result of a getter that a put()
+                    # has already satisfied.
+                    if not future.done():
+                        future.set_exception(gen.TimeoutError())
+
+                handle = io_loop.add_timeout(timeout, on_timeout)
+                # Cancel the pending timeout as soon as the get completes so
+                # the callback cannot fire on an already-resolved Future.
+                future.add_done_callback(
+                    lambda _: io_loop.remove_timeout(handle))
+        return future
+
+    def get_nowait(self):
+        """Remove and return an item from the queue without blocking.
+
+        Return an item if one is immediately available, else raise
+        `QueueEmpty`.
+        """
+        self._consume_expired()
+        if self._putters:
+            assert self.full(), "queue not full, why are putters waiting?"
+            item, putter = self._putters.popleft()
+            self._put(item)
+            putter.set_result(None)
+            return self._get()
+        elif self.qsize():
+            return self._get()
+        else:
+            raise QueueEmpty
+
+    def task_done(self):
+        """Indicate that a formerly enqueued task is complete.
+
+        Used by queue consumers. For each `.get` used to fetch a task, a
+        subsequent call to `.task_done` tells the queue that the processing
+        on the task is complete.
+
+        If a `.join` is blocking, it resumes when all items have been
+        processed; that is, when every `.put` is matched by a `.task_done`.
+
+        Raises `ValueError` if called more times than `.put`.
+        """
+        if self._unfinished_tasks <= 0:
+            raise ValueError('task_done() called too many times')
+        self._unfinished_tasks -= 1
+        if self._unfinished_tasks == 0:
+            self._finished.set()
+
+    def join(self, timeout=None):
+        """Block until all items in the queue are processed.
+
+        Returns a Future, which raises `tornado.gen.TimeoutError` after a
+        timeout.
+        """
+        return self._finished.wait(timeout)
+
+    def _init(self):
+        self._queue = collections.deque()
+
+    def _get(self):
+        return self._queue.popleft()
+
+    def _put(self, item):
+        self._unfinished_tasks += 1
+        self._finished.clear()
+        self._queue.append(item)
+
+    def _consume_expired(self):
+        # Remove timed-out waiters.
+        while self._putters and self._putters[0][1].done():
+            self._putters.popleft()
+
+        while self._getters and self._getters[0].done():
+            self._getters.popleft()
+
+    def __repr__(self):
+        return '<%s at %s %s>' % (
+            type(self).__name__, hex(id(self)), self._format())
+
+    def __str__(self):
+        return '<%s %s>' % (type(self).__name__, self._format())
+
+    def _format(self):
+        result = 'maxsize=%r' % (self.maxsize, )
+        if getattr(self, '_queue', None):
+            result += ' queue=%r' % self._queue
+        if self._getters:
+            result += ' getters[%s]' % len(self._getters)
+        if self._putters:
+            result += ' putters[%s]' % len(self._putters)
+        if self._unfinished_tasks:
+            result += ' tasks=%s' % self._unfinished_tasks
+        return result
diff --git a/tornado/test/queues_test.py b/tornado/test/queues_test.py
new file mode 100644 (file)
index 0000000..880db82
--- /dev/null
@@ -0,0 +1,333 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from datetime import timedelta
+from random import random
+
+from tornado import gen, queues
+from tornado.gen import TimeoutError
+from tornado.testing import gen_test, AsyncTestCase
+from tornado.test.util import unittest
+
+
+class QueueBasicTest(AsyncTestCase):
+    def test_repr_and_str(self):
+        q = queues.Queue(maxsize=1)
+        self.assertIn(hex(id(q)), repr(q))
+        self.assertNotIn(hex(id(q)), str(q))
+        q.get()
+
+        for q_str in repr(q), str(q):
+            self.assertTrue(q_str.startswith('<Queue'))
+            self.assertIn('maxsize=1', q_str)
+            self.assertIn('getters[1]', q_str)
+            self.assertNotIn('putters', q_str)
+            self.assertNotIn('tasks', q_str)
+
+        q.put(None)
+        q.put(None)
+        # Now the queue is full, this putter blocks.
+        q.put(None)
+
+        for q_str in repr(q), str(q):
+            self.assertNotIn('getters', q_str)
+            self.assertIn('putters[1]', q_str)
+            self.assertIn('tasks=2', q_str)
+
+    def test_order(self):
+        q = queues.Queue()
+        for i in [1, 3, 2]:
+            q.put_nowait(i)
+
+        items = [q.get_nowait() for _ in range(3)]
+        self.assertEqual([1, 3, 2], items)
+
+    @gen_test
+    def test_maxsize(self):
+        self.assertRaises(TypeError, queues.Queue, maxsize=None)
+        self.assertRaises(ValueError, queues.Queue, maxsize=-1)
+
+        q = queues.Queue(maxsize=2)
+        self.assertTrue(q.empty())
+        self.assertFalse(q.full())
+        self.assertEqual(2, q.maxsize)
+        self.assertTrue(q.put(0).done())
+        self.assertTrue(q.put(1).done())
+        self.assertFalse(q.empty())
+        self.assertTrue(q.full())
+        put2 = q.put(2)
+        self.assertFalse(put2.done())
+        self.assertEqual(0, (yield q.get()))  # Make room.
+        self.assertTrue(put2.done())
+        self.assertFalse(q.empty())
+        self.assertTrue(q.full())
+
+
+class QueueGetTest(AsyncTestCase):
+    @gen_test
+    def test_blocking_get(self):
+        q = queues.Queue()
+        q.put_nowait(0)
+        self.assertEqual(0, (yield q.get()))
+
+    def test_nonblocking_get(self):
+        q = queues.Queue()
+        q.put_nowait(0)
+        self.assertEqual(0, q.get_nowait())
+
+    def test_nonblocking_get_exception(self):
+        q = queues.Queue()
+        self.assertRaises(queues.QueueEmpty, q.get_nowait)
+
+    @gen_test
+    def test_get_with_putters(self):
+        q = queues.Queue(1)
+        q.put_nowait(0)
+        put = q.put(1)
+        self.assertEqual(0, (yield q.get()))
+        self.assertIsNone((yield put))
+
+    @gen_test
+    def test_blocking_get_wait(self):
+        q = queues.Queue()
+        q.put(0)
+        self.io_loop.call_later(0.01, q.put, 1)
+        self.io_loop.call_later(0.02, q.put, 2)
+        self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
+        self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
+
+    @gen_test
+    def test_get_timeout(self):
+        q = queues.Queue()
+        get_timeout = q.get(timeout=timedelta(seconds=0.01))
+        get = q.get()
+        with self.assertRaises(TimeoutError):
+            yield get_timeout
+        
+        q.put_nowait(0)
+        self.assertEqual(0, (yield get))
+
+    @gen_test
+    def test_get_clears_timed_out_putters(self):
+        q = queues.Queue(1)
+        # First putter succeeds, remainder block.
+        putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
+        put = q.put(10)
+        self.assertEqual(10, len(q._putters))
+        yield gen.sleep(0.02)
+        self.assertEqual(10, len(q._putters))
+        self.assertFalse(put.done())  # Final waiter is still active.
+        q.put(11)
+        self.assertEqual(0, (yield q.get()))  # get() clears the waiters.
+        self.assertEqual(1, len(q._putters))
+        for putter in putters[1:]:
+            self.assertRaises(TimeoutError, putter.result)
+
+    @gen_test
+    def test_get_clears_timed_out_getters(self):
+        q = queues.Queue()
+        getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
+        get = q.get()
+        self.assertEqual(11, len(q._getters))
+        yield gen.sleep(0.02)
+        self.assertEqual(11, len(q._getters))
+        self.assertFalse(get.done())  # Final waiter is still active.
+        q.get()  # get() clears the waiters.
+        self.assertEqual(2, len(q._getters))
+        for getter in getters:
+            self.assertRaises(TimeoutError, getter.result)
+
+
+class QueuePutTest(AsyncTestCase):
+    @gen_test
+    def test_blocking_put(self):
+        q = queues.Queue()
+        q.put(0)
+        self.assertEqual(0, q.get_nowait())
+
+    def test_nonblocking_put_exception(self):
+        q = queues.Queue(1)
+        q.put(0)
+        self.assertRaises(queues.QueueFull, q.put_nowait, 1)
+
+    @gen_test
+    def test_put_with_getters(self):
+        q = queues.Queue()
+        get0 = q.get()
+        get1 = q.get()
+        yield q.put(0)
+        self.assertEqual(0, (yield get0))
+        yield q.put(1)
+        self.assertEqual(1, (yield get1))
+        
+    @gen_test
+    def test_nonblocking_put_with_getters(self):
+        q = queues.Queue()
+        get0 = q.get()
+        get1 = q.get()
+        q.put_nowait(0)
+        # put_nowait does *not* immediately unblock getters.
+        yield gen.moment
+        self.assertEqual(0, (yield get0))
+        q.put_nowait(1)
+        yield gen.moment
+        self.assertEqual(1, (yield get1))
+
+    @gen_test
+    def test_blocking_put_wait(self):
+        q = queues.Queue(1)
+        q.put_nowait(0)
+        self.io_loop.call_later(0.01, q.get)
+        self.io_loop.call_later(0.02, q.get)
+        futures = [q.put(0), q.put(1)]
+        self.assertFalse(any(f.done() for f in futures))
+        yield futures
+
+    @gen_test
+    def test_put_timeout(self):
+        q = queues.Queue(1)
+        q.put_nowait(0)  # Now it's full.
+        put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
+        put = q.put(2)
+        with self.assertRaises(TimeoutError):
+            yield put_timeout
+        
+        self.assertEqual(0, q.get_nowait())
+        # 1 was never put in the queue.
+        self.assertEqual(2, (yield q.get()))
+
+        # Final get() unblocked this putter.
+        yield put
+
+    @gen_test
+    def test_put_clears_timed_out_putters(self):
+        q = queues.Queue(1)
+        # First putter succeeds, remainder block.
+        putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
+        put = q.put(10)
+        self.assertEqual(10, len(q._putters))
+        yield gen.sleep(0.02)
+        self.assertEqual(10, len(q._putters))
+        self.assertFalse(put.done())  # Final waiter is still active.
+        q.put(11)  # put() clears the waiters.
+        self.assertEqual(2, len(q._putters))
+        for putter in putters[1:]:
+            self.assertRaises(TimeoutError, putter.result)
+
+    @gen_test
+    def test_put_clears_timed_out_getters(self):
+        q = queues.Queue()
+        getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
+        get = q.get()
+        q.get()
+        self.assertEqual(12, len(q._getters))
+        yield gen.sleep(0.02)
+        self.assertEqual(12, len(q._getters))
+        self.assertFalse(get.done())  # Final waiters still active.
+        q.put(0)  # put() clears the waiters.
+        self.assertEqual(1, len(q._getters))
+        self.assertEqual(0, (yield get))
+        for getter in getters:
+            self.assertRaises(TimeoutError, getter.result)
+
+    @gen_test
+    def test_float_maxsize(self):
+        # A non-integer maxsize effectively rounds up: with maxsize=1.3 the
+        # queue holds two items. See http://bugs.python.org/issue21723
+        q = queues.Queue(maxsize=1.3)
+        self.assertTrue(q.empty())
+        self.assertFalse(q.full())
+        q.put_nowait(0)
+        q.put_nowait(1)
+        self.assertFalse(q.empty())
+        self.assertTrue(q.full())
+        self.assertRaises(queues.QueueFull, q.put_nowait, 2)
+        self.assertEqual(0, q.get_nowait())
+        self.assertFalse(q.empty())
+        self.assertFalse(q.full())
+
+        yield q.put(2)
+        put = q.put(3)
+        self.assertFalse(put.done())
+        self.assertEqual(1, (yield q.get()))
+        yield put
+        self.assertTrue(q.full())
+
+
+class QueueJoinTest(AsyncTestCase):
+    def test_task_done_underflow(self):
+        q = queues.Queue()
+        self.assertRaises(ValueError, q.task_done)
+
+    @gen_test
+    def test_task_done(self):
+        q = queues.Queue()
+        for i in range(100):
+            q.put_nowait(i)
+
+        self.accumulator = 0
+
+        @gen.coroutine
+        def worker():
+            while True:
+                item = yield q.get()
+                self.accumulator += item
+                q.task_done()
+                yield gen.sleep(random() * 0.01)
+
+        # Two coroutines share work.
+        worker()
+        worker()
+        yield q.join()
+        self.assertEqual(sum(range(100)), self.accumulator)
+
+    @gen_test
+    def test_join_empty_queue(self):
+        q = queues.Queue()
+        yield q.join()
+        yield q.join()
+
+    @gen_test
+    def test_join_timeout(self):
+        q = queues.Queue()
+        q.put(0)
+        with self.assertRaises(TimeoutError):
+            yield q.join(timeout=timedelta(seconds=0.01))
+
+    @gen_test
+    def test_producer_consumer(self):
+        q = queues.Queue(maxsize=3)
+        history = []
+
+        # We don't yield between get() and task_done(), so get() must wait for
+        # the next tick. Otherwise we'd immediately call task_done and unblock
+        # join() before q.put() resumes, and we'd only process the first four
+        # items. Consumers would normally yield in the course of processing an
+        # item, but it's worthwhile testing the degenerate case.
+        @gen.coroutine
+        def consumer():
+            while True:
+                history.append((yield q.get()))
+                q.task_done()
+
+        @gen.coroutine
+        def producer():
+            for item in range(10):
+                yield q.put(item)
+
+        producer()
+        consumer()
+        yield q.join()
+        self.assertEqual(list(range(10)), history)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py
index cb9969d3c34162eab3e95462478108ec0feab049..ad9b0b8357be9271158502ba42d1dcc464e2763b 100644 (file)
@@ -41,6 +41,7 @@ TEST_MODULES = [
     'tornado.test.log_test',
     'tornado.test.options_test',
     'tornado.test.process_test',
+    'tornado.test.queues_test',
     'tornado.test.simple_httpclient_test',
     'tornado.test.stack_context_test',
     'tornado.test.tcpclient_test',