]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Support async iterator protocol in tornado.queues.
authorBen Darnell <ben@bendarnell.com>
Mon, 3 Aug 2015 01:21:48 +0000 (21:21 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 3 Aug 2015 01:44:46 +0000 (21:44 -0400)
tornado/queues.py
tornado/test/queues_test.py

index 6d694cc4ce8a7480248c79f8090f87baeb0beb92..bc57e29812d72939f54d2cd945f884a8fcf95be9 100644 (file)
@@ -44,6 +44,14 @@ def _set_timeout(future, timeout):
             lambda _: io_loop.remove_timeout(timeout_handle))
 
 
+class _QueueIterator(object):
+    def __init__(self, q):
+        self.q = q
+
+    def __anext__(self):
+        return self.q.get()
+
+
 class Queue(object):
     """Coordinate producer and consumer coroutines.
 
@@ -96,6 +104,18 @@ class Queue(object):
         Doing work on 3
         Doing work on 4
         Done
+
+    In Python 3.5, `Queue` implements the async iterator protocol, so
+    ``consumer()`` could be rewritten as::
+
+        async def consumer():
+            async for item in q:
+                try:
+                    print('Doing work on %s' % item)
+                    yield gen.sleep(0.01)
+                finally:
+                    q.task_done()
+
     """
     def __init__(self, maxsize=0):
         if maxsize is None:
@@ -220,6 +240,10 @@ class Queue(object):
         """
         return self._finished.wait(timeout)
 
+    @gen.coroutine
+    def __aiter__(self):
+        return _QueueIterator(self)
+
     # These three are overridable in subclasses.
     def _init(self):
         self._queue = collections.deque()
index f2ffb646f0c94a192f1dac28fd6084a0aaca0b6f..519dd6ae9120c9afb0e2160982bda447f5b25730 100644 (file)
 
 from datetime import timedelta
 from random import random
+import sys
+import textwrap
 
 from tornado import gen, queues
 from tornado.gen import TimeoutError
 from tornado.testing import gen_test, AsyncTestCase
-from tornado.test.util import unittest
+from tornado.test.util import unittest, skipBefore35, exec_test
 
 
 class QueueBasicTest(AsyncTestCase):
@@ -112,7 +114,7 @@ class QueueGetTest(AsyncTestCase):
         get = q.get()
         with self.assertRaises(TimeoutError):
             yield get_timeout
-        
+
         q.put_nowait(0)
         self.assertEqual(0, (yield get))
 
@@ -154,6 +156,24 @@ class QueueGetTest(AsyncTestCase):
         for getter in getters:
             self.assertRaises(TimeoutError, getter.result)
 
+    @skipBefore35
+    @gen_test
+    def test_async_for(self):
+        q = queues.Queue()
+        for i in range(5):
+            q.put(i)
+
+        namespace = exec_test(globals(), locals(), """
+        async def f():
+            results = []
+            async for i in q:
+                results.append(i)
+                if i == 4:
+                    return results
+        """)
+        results = yield namespace['f']()
+        self.assertEqual(results, list(range(5)))
+
 
 class QueuePutTest(AsyncTestCase):
     @gen_test
@@ -176,7 +196,7 @@ class QueuePutTest(AsyncTestCase):
         self.assertEqual(0, (yield get0))
         yield q.put(1)
         self.assertEqual(1, (yield get1))
-        
+
     @gen_test
     def test_nonblocking_put_with_getters(self):
         q = queues.Queue()
@@ -208,7 +228,7 @@ class QueuePutTest(AsyncTestCase):
         put = q.put(2)
         with self.assertRaises(TimeoutError):
             yield put_timeout
-        
+
         self.assertEqual(0, q.get_nowait())
         # 1 was never put in the queue.
         self.assertEqual(2, (yield q.get()))
@@ -281,7 +301,7 @@ class QueuePutTest(AsyncTestCase):
 
 class QueueJoinTest(AsyncTestCase):
     queue_class = queues.Queue
-    
+
     def test_task_done_underflow(self):
         q = self.queue_class()
         self.assertRaises(ValueError, q.task_done)
@@ -338,7 +358,7 @@ class QueueJoinTest(AsyncTestCase):
 
 class PriorityQueueJoinTest(QueueJoinTest):
     queue_class = queues.PriorityQueue
-    
+
     @gen_test
     def test_order(self):
         q = self.queue_class(maxsize=2)