From: A. Jesse Jiryu Davis Date: Tue, 24 Mar 2015 03:35:20 +0000 (-0400) Subject: Add PriorityQueue and LifoQueue. X-Git-Tag: v4.2.0b1~55^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1394%2Fhead;p=thirdparty%2Ftornado.git Add PriorityQueue and LifoQueue. --- diff --git a/tornado/queues.py b/tornado/queues.py index 0e7af6ed5..128c415ce 100644 --- a/tornado/queues.py +++ b/tornado/queues.py @@ -14,9 +14,10 @@ from __future__ import absolute_import, division, print_function, with_statement -__all__ = ['Queue', 'QueueFull', 'QueueEmpty'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] import collections +import heapq from tornado import gen, ioloop from tornado.concurrent import Future @@ -106,12 +107,12 @@ class Queue(object): if self._getters: assert self.empty(), "queue non-empty, why are getters waiting?" getter = self._getters.popleft() - self._put(item) + self.__put_internal(item) getter.set_result(self._get()) elif self.full(): raise QueueFull else: - self._put(item) + self.__put_internal(item) def get(self, timeout=None): """Remove and return an item from the queue. @@ -137,7 +138,7 @@ class Queue(object): if self._putters: assert self.full(), "queue not full, why are putters waiting?" item, putter = self._putters.popleft() - self._put(item) + self.__put_internal(item) putter.set_result(None) return self._get() elif self.qsize(): @@ -171,6 +172,7 @@ class Queue(object): """ return self._finished.wait(timeout) + # These three are overridable in subclasses. def _init(self): self._queue = collections.deque() @@ -178,9 +180,13 @@ class Queue(object): return self._queue.popleft() def _put(self, item): + self._queue.append(item) + # End of the overridable methods. + + def __put_internal(self, item): self._unfinished_tasks += 1 self._finished.clear() - self._queue.append(item) + self._put(item) def _consume_expired(self): # Remove timed-out waiters. @@ -208,3 +214,30 @@ class Queue(object): if self._unfinished_tasks: result += ' tasks=%s' % self._unfinished_tasks return result + + +class PriorityQueue(Queue): + """A `.Queue` that retrieves entries in priority order, lowest first. + + Entries are typically tuples like ``(priority number, data)``. + """ + def _init(self): + self._queue = [] + + def _put(self, item): + heapq.heappush(self._queue, item) + + def _get(self): + return heapq.heappop(self._queue) + + +class LifoQueue(Queue): + """A `.Queue` that retrieves the most recently put items first.""" + def _init(self): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() diff --git a/tornado/test/queues_test.py b/tornado/test/queues_test.py index ac2118332..f2ffb646f 100644 --- a/tornado/test/queues_test.py +++ b/tornado/test/queues_test.py @@ -280,13 +280,15 @@ class QueuePutTest(AsyncTestCase): class QueueJoinTest(AsyncTestCase): + queue_class = queues.Queue + def test_task_done_underflow(self): - q = queues.Queue() + q = self.queue_class() self.assertRaises(ValueError, q.task_done) @gen_test def test_task_done(self): - q = queues.Queue() + q = self.queue_class() for i in range(100): q.put_nowait(i) @@ -309,7 +311,7 @@ class QueueJoinTest(AsyncTestCase): @gen_test def test_task_done_delay(self): # Verify it is task_done(), not get(), that unblocks join(). - q = queues.Queue() + q = self.queue_class() q.put_nowait(0) join = q.join() self.assertFalse(join.done()) @@ -322,17 +324,55 @@ class QueueJoinTest(AsyncTestCase): @gen_test def test_join_empty_queue(self): - q = queues.Queue() + q = self.queue_class() yield q.join() yield q.join() @gen_test def test_join_timeout(self): - q = queues.Queue() + q = self.queue_class() q.put(0) with self.assertRaises(TimeoutError): yield q.join(timeout=timedelta(seconds=0.01)) + +class PriorityQueueJoinTest(QueueJoinTest): + queue_class = queues.PriorityQueue + + @gen_test + def test_order(self): + q = self.queue_class(maxsize=2) + q.put_nowait((1, 'a')) + q.put_nowait((0, 'b')) + self.assertTrue(q.full()) + q.put((3, 'c')) + q.put((2, 'd')) + self.assertEqual((0, 'b'), q.get_nowait()) + self.assertEqual((1, 'a'), (yield q.get())) + self.assertEqual((2, 'd'), q.get_nowait()) + self.assertEqual((3, 'c'), (yield q.get())) + self.assertTrue(q.empty()) + + +class LifoQueueJoinTest(QueueJoinTest): + queue_class = queues.LifoQueue + + @gen_test + def test_order(self): + q = self.queue_class(maxsize=2) + q.put_nowait(1) + q.put_nowait(0) + self.assertTrue(q.full()) + q.put(3) + q.put(2) + self.assertEqual(3, q.get_nowait()) + self.assertEqual(2, (yield q.get())) + self.assertEqual(0, q.get_nowait()) + self.assertEqual(1, (yield q.get())) + self.assertTrue(q.empty()) + + +class ProducerConsumerTest(AsyncTestCase): @gen_test def test_producer_consumer(self): q = queues.Queue(maxsize=3)