]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add PriorityQueue and LifoQueue. 1394/head
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Tue, 24 Mar 2015 03:35:20 +0000 (23:35 -0400)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Tue, 24 Mar 2015 03:35:20 +0000 (23:35 -0400)
tornado/queues.py
tornado/test/queues_test.py

index 0e7af6ed517f1b3bb804e605a6d2d2946d63814e..128c415ce1c6eb3f6da16719c6dc690bc5b7e1dd 100644 (file)
 
 from __future__ import absolute_import, division, print_function, with_statement
 
-__all__ = ['Queue', 'QueueFull', 'QueueEmpty']
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
 
 import collections
+import heapq
 
 from tornado import gen, ioloop
 from tornado.concurrent import Future
@@ -106,12 +107,12 @@ class Queue(object):
         if self._getters:
             assert self.empty(), "queue non-empty, why are getters waiting?"
             getter = self._getters.popleft()
-            self._put(item)
+            self.__put_internal(item)
             getter.set_result(self._get())
         elif self.full():
             raise QueueFull
         else:
-            self._put(item)
+            self.__put_internal(item)
 
     def get(self, timeout=None):
         """Remove and return an item from the queue.
@@ -137,7 +138,7 @@ class Queue(object):
         if self._putters:
             assert self.full(), "queue not full, why are putters waiting?"
             item, putter = self._putters.popleft()
-            self._put(item)
+            self.__put_internal(item)
             putter.set_result(None)
             return self._get()
         elif self.qsize():
@@ -171,6 +172,7 @@ class Queue(object):
         """
         return self._finished.wait(timeout)
 
+    # These three are overridable in subclasses.
     def _init(self):
         self._queue = collections.deque()
 
@@ -178,9 +180,13 @@ class Queue(object):
         return self._queue.popleft()
 
     def _put(self, item):
+        self._queue.append(item)
+    # End of the overridable methods.
+
+    def __put_internal(self, item):
         self._unfinished_tasks += 1
         self._finished.clear()
-        self._queue.append(item)
+        self._put(item)
 
     def _consume_expired(self):
         # Remove timed-out waiters.
@@ -208,3 +214,30 @@ class Queue(object):
         if self._unfinished_tasks:
             result += ' tasks=%s' % self._unfinished_tasks
         return result
+
+
+class PriorityQueue(Queue):
+    """A `.Queue` that retrieves entries in priority order, lowest first.
+
+    Entries are typically tuples like ``(priority number, data)``.
+    """
+    def _init(self):
+        self._queue = []
+
+    def _put(self, item):
+        heapq.heappush(self._queue, item)
+
+    def _get(self):
+        return heapq.heappop(self._queue)
+
+
+class LifoQueue(Queue):
+    """A `.Queue` that retrieves the most recently put items first."""
+    def _init(self):
+        self._queue = []
+
+    def _put(self, item):
+        self._queue.append(item)
+
+    def _get(self):
+        return self._queue.pop()
index ac2118332d8b93add8fe33bd8049f6e377f5f892..f2ffb646f0c94a192f1dac28fd6084a0aaca0b6f 100644 (file)
@@ -280,13 +280,15 @@ class QueuePutTest(AsyncTestCase):
 
 
 class QueueJoinTest(AsyncTestCase):
+    queue_class = queues.Queue
+    
     def test_task_done_underflow(self):
-        q = queues.Queue()
+        q = self.queue_class()
         self.assertRaises(ValueError, q.task_done)
 
     @gen_test
     def test_task_done(self):
-        q = queues.Queue()
+        q = self.queue_class()
         for i in range(100):
             q.put_nowait(i)
 
@@ -309,7 +311,7 @@ class QueueJoinTest(AsyncTestCase):
     @gen_test
     def test_task_done_delay(self):
         # Verify it is task_done(), not get(), that unblocks join().
-        q = queues.Queue()
+        q = self.queue_class()
         q.put_nowait(0)
         join = q.join()
         self.assertFalse(join.done())
@@ -322,17 +324,55 @@ class QueueJoinTest(AsyncTestCase):
 
     @gen_test
     def test_join_empty_queue(self):
-        q = queues.Queue()
+        q = self.queue_class()
         yield q.join()
         yield q.join()
 
     @gen_test
     def test_join_timeout(self):
-        q = queues.Queue()
+        q = self.queue_class()
         q.put(0)
         with self.assertRaises(TimeoutError):
             yield q.join(timeout=timedelta(seconds=0.01))
 
+
+class PriorityQueueJoinTest(QueueJoinTest):
+    queue_class = queues.PriorityQueue
+    
+    @gen_test
+    def test_order(self):
+        q = self.queue_class(maxsize=2)
+        q.put_nowait((1, 'a'))
+        q.put_nowait((0, 'b'))
+        self.assertTrue(q.full())
+        q.put((3, 'c'))
+        q.put((2, 'd'))
+        self.assertEqual((0, 'b'), q.get_nowait())
+        self.assertEqual((1, 'a'), (yield q.get()))
+        self.assertEqual((2, 'd'), q.get_nowait())
+        self.assertEqual((3, 'c'), (yield q.get()))
+        self.assertTrue(q.empty())
+
+
+class LifoQueueJoinTest(QueueJoinTest):
+    queue_class = queues.LifoQueue
+
+    @gen_test
+    def test_order(self):
+        q = self.queue_class(maxsize=2)
+        q.put_nowait(1)
+        q.put_nowait(0)
+        self.assertTrue(q.full())
+        q.put(3)
+        q.put(2)
+        self.assertEqual(3, q.get_nowait())
+        self.assertEqual(2, (yield q.get()))
+        self.assertEqual(0, q.get_nowait())
+        self.assertEqual(1, (yield q.get()))
+        self.assertTrue(q.empty())
+
+
+class ProducerConsumerTest(AsyncTestCase):
     @gen_test
     def test_producer_consumer(self):
         q = queues.Queue(maxsize=3)