]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
[feature-waitany] Created WaitIterator + tests
authorJordan Bettis <jordanb@hafd.org>
Mon, 27 Oct 2014 11:05:01 +0000 (06:05 -0500)
committerJordan Bettis <jordanb@hafd.org>
Mon, 27 Oct 2014 11:05:01 +0000 (06:05 -0500)
tornado/gen.py
tornado/test/gen_test.py

index 2fc9b0c70538bcb7694872049e5fe4a8e07ab07c..4feba8a9c9cbb496e9a2eb36bdc65e97642164f7 100644 (file)
@@ -240,6 +240,124 @@ class Return(Exception):
         super(Return, self).__init__()
         self.value = value
 
+class WaitIterator(object):
+    """Provides an iterator to yield the result of futures as they finish
+
+    Yielding a set of futures like this:
+
+    ``results = yield [future1, future2]``
+
+    pauses the coroutine until both ``future1`` and ``future2``
+    return, and then restarts the coroutine with the results of both
+    futures. If either future is an exception, the expression will
+    raise that exception and the result of the other future will be
+    lost.
+
+    If you need to get the result of each future as soon as possible,
+    or if you need the result of some futures even if others produce
+    errors, you can use ``WaitIterator``:
+
+    ::
+
+      wait_iterator = gen.WaitIterator(future1, future2)
+      for future in wait_iterator:
+          try:
+              result = yield future
+          except Exception as e:
+              print "Error {} from {}".format(e, wait_iterator.current_future())
+          else:
+              print "Result {} recieved from {} at {}".format(
+                  result, wait_iterator.current_future(), wait_iterator.current_index())
+
+    Because results are returned as soon as they are available the
+    output from the iterator *will not be in the same order as the
+    input arguments*. If you need to know which future produced the
+    current result, you can use ``WaitIterator.current_future()``, or
+    ``WaitIterator.current_index()`` to yield the index of the future
+    from the input list.
+    """
+    def __init__(self, *args, **kwargs):
+        if args and kwargs:
+            raise ValueError(
+                "You must provide a list of futures or key/values, not both")
+
+        if kwargs:
+            self._keys = kwargs.keys()
+            self._futures = kwargs.values()
+        else:
+            self._keys = None
+            self._futures = list(args)
+            
+        self._queue = collections.deque()
+        self._current_future = None
+        
+        for future in self._futures:
+            if future.done():
+                self._queue.append(future)
+            else:
+                future.add_done_callback(self._done_callback)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.next()
+
+    def next(self):
+        """Return a `.Future` that will yield the next avaliable
+        result
+        """
+        if all(x is None for x in self._futures):
+            self._current_future = None
+            raise StopIteration
+
+        self._running_future = TracebackFuture()
+
+        try:
+            done = self._queue.popleft()
+            self._return_result(done)
+        except IndexError:
+            pass
+
+        return self._running_future
+    
+    def current_index(self):
+        """Returns the index of the most recently completed `.Future`
+        from the argument list. If keyword arguments were used, the
+        keyword will be returned.
+        """
+        if self._current_future:
+            return self._current_future[0]
+            
+    def current_future(self):
+        """Returns the most recently completed `.Future` object"""
+        if self._current_future:
+            return self._current_future[1]
+        
+    def _done_callback(self, done):
+        if self._running_future and not self._running_future.done():
+            self._return_result(done)
+        else:
+            self._queue.append(done)
+
+    def _return_result(self, done):
+        """Called set the returned future's state that of the future
+        we yielded, and set the current future for the iterator.
+        """
+        exception = done.exception()
+        if exception:
+            self._running_future.set_exception(exception)
+        else:
+            self._running_future.set_result(done.result())
+
+        index = self._futures.index(done)
+        ## Eliminate the reference for GC
+        self._futures[index] = None
+        
+        if self._keys:
+            index = self._keys[index]
+        
+        self._current_future = (index, done)
 
 class YieldPoint(object):
     """Base class for objects that may be yielded from the generator.
index a15cdf73a152f970986937d233ecc06f42ecd0d1..3fde2e7af7b982c423e736288db5c611e9861a5a 100644 (file)
@@ -1066,6 +1066,113 @@ class WithTimeoutTest(AsyncTestCase):
             yield gen.with_timeout(datetime.timedelta(seconds=3600),
                                    executor.submit(lambda: None))
 
+class WaitIteratorTest(AsyncTestCase):
+    @gen_test
+    def test_empty_iterator(self):
+        g = gen.WaitIterator()
+        for i in g:
+            self.assertTrue(True, 'empty generator iterated')
+
+        try:
+            g = gen.WaitIterator(False, bar=False)
+        except ValueError:
+            pass
+        else:
+            self.assertTrue(True, 'missed incompatible args')
+
+        self.assertEqual(g.current_index(), None, "bad nil current index")
+        self.assertEqual(g.current_future(), None, "bad nil current future")
+
+    @gen_test
+    def test_already_done(self):
+        f1 = Future()
+        f2 = Future()
+        f3 = Future()
+        f1.set_result(24)
+        f2.set_result(42)
+        f3.set_result(84)
+        
+        g = gen.WaitIterator(f1, f2, f3)
+        i = 0
+        for f in g:
+            r = yield f
+            if i == 0:
+                self.assertTrue(
+                    all([g.current_index()==0, g.current_future()==f1, r==24]),
+                    "WaitIterator status incorrect")
+            elif i == 1:
+                self.assertTrue(
+                    all([g.current_index()==1, g.current_future()==f2, r==42]),
+                    "WaitIterator status incorrect")
+            elif i == 2:
+                self.assertTrue(
+                    all([g.current_index()==2, g.current_future()==f3, r==84]),
+                    "WaitIterator status incorrect")
+            i += 1
+
+        self.assertEqual(g.current_index(), None, "bad nil current index")
+        self.assertEqual(g.current_future(), None, "bad nil current future")
+
+        dg = gen.WaitIterator(f1=f1, f2=f2)
+                        
+        for df in dg:
+            dr = yield df
+            if dg.current_index() == "f1":
+                self.assertTrue(dg.current_future()==f1 and dr==24,
+                                "WaitIterator dict status incorrect")
+            elif dg.current_index() == "f2":
+                self.assertTrue(dg.current_future()==f2 and dr==42,
+                                "WaitIterator dict status incorrect")
+            else:
+                self.assertTrue(False, "got bad WaitIterator index {}".format(
+                    dg.current_index()))
+
+            i += 1
+
+        self.assertEqual(dg.current_index(), None, "bad nil current index")
+        self.assertEqual(dg.current_future(), None, "bad nil current future")
+
+    def finish_coroutines(self, iteration, futures):
+        if iteration == 3:
+            futures[2].set_result(24)
+        elif iteration == 5:
+            futures[0].set_exception(ZeroDivisionError)
+        elif iteration == 8:
+            futures[1].set_result(42)
+            futures[3].set_result(84)
+
+        if iteration < 8:
+            self.io_loop.add_callback(self.finish_coroutines, iteration+1, futures)
+
+    @gen_test
+    def test_iterator(self):
+        futures = [Future(), Future(), Future(), Future()]
+
+        class TestException(Exception):
+            pass
+        
+        self.finish_coroutines(0, futures)
+
+        g = gen.WaitIterator(*futures)
+
+        i = 0
+        for f in g:
+            try:
+                r = yield f
+            except ZeroDivisionError:
+                self.assertEqual(g.current_future(), futures[0],
+                                 'exception future invalid')
+            else:
+                if i == 0:
+                    self.assertEqual(r, 24, 'iterator value incorrect')
+                    self.assertEqual(g.current_index(), 2, 'wrong index')
+                elif i == 2:
+                    self.assertEqual(r, 42, 'iterator value incorrect')
+                    self.assertEqual(g.current_index(), 1, 'wrong index')
+                elif i == 3:
+                    self.assertEqual(r, 84, 'iterator value incorrect')
+                    self.assertEqual(g.current_index(), 3, 'wrong index')
+            i += 1
 
 if __name__ == '__main__':
     unittest.main()