from __future__ import absolute_import, division, print_function, with_statement
-__all__ = ['Queue', 'QueueFull', 'QueueEmpty']
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
import collections
+import heapq
from tornado import gen, ioloop
from tornado.concurrent import Future
if self._getters:
assert self.empty(), "queue non-empty, why are getters waiting?"
getter = self._getters.popleft()
- self._put(item)
+ self.__put_internal(item)
getter.set_result(self._get())
elif self.full():
raise QueueFull
else:
- self._put(item)
+ self.__put_internal(item)
def get(self, timeout=None):
"""Remove and return an item from the queue.
if self._putters:
assert self.full(), "queue not full, why are putters waiting?"
item, putter = self._putters.popleft()
- self._put(item)
+ self.__put_internal(item)
putter.set_result(None)
return self._get()
elif self.qsize():
"""
return self._finished.wait(timeout)
+ # These three are overridable in subclasses.
def _init(self):
self._queue = collections.deque()
return self._queue.popleft()
def _put(self, item):
+ self._queue.append(item)
+ # End of the overridable methods.
+
+ def __put_internal(self, item):
self._unfinished_tasks += 1
self._finished.clear()
- self._queue.append(item)
+ self._put(item)
def _consume_expired(self):
# Remove timed-out waiters.
if self._unfinished_tasks:
result += ' tasks=%s' % self._unfinished_tasks
return result
+
+
+class PriorityQueue(Queue):
+ """A `.Queue` that retrieves entries in priority order, lowest first.
+
+ Entries are typically tuples like ``(priority number, data)``.
+ """
+ def _init(self):
+ self._queue = []
+
+ def _put(self, item):
+ heapq.heappush(self._queue, item)
+
+ def _get(self):
+ return heapq.heappop(self._queue)
+
+
+class LifoQueue(Queue):
+ """A `.Queue` that retrieves the most recently put items first."""
+ def _init(self):
+ self._queue = []
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def _get(self):
+ return self._queue.pop()
class QueueJoinTest(AsyncTestCase):
+ queue_class = queues.Queue
+
def test_task_done_underflow(self):
- q = queues.Queue()
+ q = self.queue_class()
self.assertRaises(ValueError, q.task_done)
@gen_test
def test_task_done(self):
- q = queues.Queue()
+ q = self.queue_class()
for i in range(100):
q.put_nowait(i)
@gen_test
def test_task_done_delay(self):
# Verify it is task_done(), not get(), that unblocks join().
- q = queues.Queue()
+ q = self.queue_class()
q.put_nowait(0)
join = q.join()
self.assertFalse(join.done())
@gen_test
def test_join_empty_queue(self):
- q = queues.Queue()
+ q = self.queue_class()
yield q.join()
yield q.join()
@gen_test
def test_join_timeout(self):
- q = queues.Queue()
+ q = self.queue_class()
q.put(0)
with self.assertRaises(TimeoutError):
yield q.join(timeout=timedelta(seconds=0.01))
+
+class PriorityQueueJoinTest(QueueJoinTest):
+ queue_class = queues.PriorityQueue
+
+ @gen_test
+ def test_order(self):
+ q = self.queue_class(maxsize=2)
+ q.put_nowait((1, 'a'))
+ q.put_nowait((0, 'b'))
+ self.assertTrue(q.full())
+ q.put((3, 'c'))
+ q.put((2, 'd'))
+ self.assertEqual((0, 'b'), q.get_nowait())
+ self.assertEqual((1, 'a'), (yield q.get()))
+ self.assertEqual((2, 'd'), q.get_nowait())
+ self.assertEqual((3, 'c'), (yield q.get()))
+ self.assertTrue(q.empty())
+
+
+class LifoQueueJoinTest(QueueJoinTest):
+ queue_class = queues.LifoQueue
+
+ @gen_test
+ def test_order(self):
+ q = self.queue_class(maxsize=2)
+ q.put_nowait(1)
+ q.put_nowait(0)
+ self.assertTrue(q.full())
+ q.put(3)
+ q.put(2)
+ self.assertEqual(3, q.get_nowait())
+ self.assertEqual(2, (yield q.get()))
+ self.assertEqual(0, q.get_nowait())
+ self.assertEqual(1, (yield q.get()))
+ self.assertTrue(q.empty())
+
+
+class ProducerConsumerTest(AsyncTestCase):
@gen_test
def test_producer_consumer(self):
q = queues.Queue(maxsize=3)