From: Jordan Bettis Date: Mon, 27 Oct 2014 11:05:01 +0000 (-0500) Subject: [feature-waitany] Created WaitIterator + tests X-Git-Tag: v4.1.0b1~18^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5dd933815175d6b472a6787189babbd8e4250cad;p=thirdparty%2Ftornado.git [feature-waitany] Created WaitIterator + tests --- diff --git a/tornado/gen.py b/tornado/gen.py index 2fc9b0c70..4feba8a9c 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -240,6 +240,124 @@ class Return(Exception): super(Return, self).__init__() self.value = value +class WaitIterator(object): + """Provides an iterator to yield the result of futures as they finish + + Yielding a set of futures like this: + + ``results = yield [future1, future2]`` + + pauses the coroutine until both ``future1`` and ``future2`` + return, and then restarts the coroutine with the results of both + futures. If either future is an exception, the expression will + raise that exception and the result of the other future will be + lost. + + If you need to get the result of each future as soon as possible, + or if you need the result of some futures even if others produce + errors, you can use ``WaitIterator``: + + :: + + wait_iterator = gen.WaitIterator(future1, future2) + for future in wait_iterator: + try: + result = yield future + except Exception as e: + print "Error {} from {}".format(e, wait_iterator.current_future()) + else: + print "Result {} recieved from {} at {}".format( + result, wait_iterator.current_future(), wait_iterator.current_index()) + + Because results are returned as soon as they are available the + output from the iterator *will not be in the same order as the + input arguments*. If you need to know which future produced the + current result, you can use ``WaitIterator.current_future()``, or + ``WaitIterator.current_index()`` to yield the index of the future + from the input list. + """ + def __init__(self, *args, **kwargs): + if args and kwargs: + raise ValueError( + "You must provide a list of futures or key/values, not both") + + if kwargs: + self._keys = kwargs.keys() + self._futures = kwargs.values() + else: + self._keys = None + self._futures = list(args) + + self._queue = collections.deque() + self._current_future = None + + for future in self._futures: + if future.done(): + self._queue.append(future) + else: + future.add_done_callback(self._done_callback) + + def __iter__(self): + return self + + def __next__(self): + return self.next() + + def next(self): + """Return a `.Future` that will yield the next avaliable + result + """ + if all(x is None for x in self._futures): + self._current_future = None + raise StopIteration + + self._running_future = TracebackFuture() + + try: + done = self._queue.popleft() + self._return_result(done) + except IndexError: + pass + + return self._running_future + + def current_index(self): + """Returns the index of the most recently completed `.Future` + from the argument list. If keyword arguments were used, the + keyword will be returned. + """ + if self._current_future: + return self._current_future[0] + + def current_future(self): + """Returns the most recently completed `.Future` object""" + if self._current_future: + return self._current_future[1] + + def _done_callback(self, done): + if self._running_future and not self._running_future.done(): + self._return_result(done) + else: + self._queue.append(done) + + def _return_result(self, done): + """Called set the returned future's state that of the future + we yielded, and set the current future for the iterator. + """ + exception = done.exception() + if exception: + self._running_future.set_exception(exception) + else: + self._running_future.set_result(done.result()) + + index = self._futures.index(done) + ## Eliminate the reference for GC + self._futures[index] = None + + if self._keys: + index = self._keys[index] + + self._current_future = (index, done) class YieldPoint(object): """Base class for objects that may be yielded from the generator. diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index a15cdf73a..3fde2e7af 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -1066,6 +1066,113 @@ class WithTimeoutTest(AsyncTestCase): yield gen.with_timeout(datetime.timedelta(seconds=3600), executor.submit(lambda: None)) +class WaitIteratorTest(AsyncTestCase): + @gen_test + def test_empty_iterator(self): + g = gen.WaitIterator() + for i in g: + self.assertTrue(True, 'empty generator iterated') + + try: + g = gen.WaitIterator(False, bar=False) + except ValueError: + pass + else: + self.assertTrue(True, 'missed incompatible args') + + self.assertEqual(g.current_index(), None, "bad nil current index") + self.assertEqual(g.current_future(), None, "bad nil current future") + + @gen_test + def test_already_done(self): + f1 = Future() + f2 = Future() + f3 = Future() + f1.set_result(24) + f2.set_result(42) + f3.set_result(84) + + g = gen.WaitIterator(f1, f2, f3) + i = 0 + for f in g: + r = yield f + if i == 0: + self.assertTrue( + all([g.current_index()==0, g.current_future()==f1, r==24]), + "WaitIterator status incorrect") + elif i == 1: + self.assertTrue( + all([g.current_index()==1, g.current_future()==f2, r==42]), + "WaitIterator status incorrect") + elif i == 2: + self.assertTrue( + all([g.current_index()==2, g.current_future()==f3, r==84]), + "WaitIterator status incorrect") + i += 1 + + self.assertEqual(g.current_index(), None, "bad nil current index") + self.assertEqual(g.current_future(), None, "bad nil current future") + + dg = gen.WaitIterator(f1=f1, f2=f2) + + for df in dg: + dr = yield df + if dg.current_index() == "f1": + self.assertTrue(dg.current_future()==f1 and dr==24, + "WaitIterator dict status incorrect") + elif dg.current_index() == "f2": + self.assertTrue(dg.current_future()==f2 and dr==42, + "WaitIterator dict status incorrect") + else: + self.assertTrue(False, "got bad WaitIterator index {}".format( + dg.current_index())) + + i += 1 + + self.assertEqual(dg.current_index(), None, "bad nil current index") + self.assertEqual(dg.current_future(), None, "bad nil current future") + + def finish_coroutines(self, iteration, futures): + if iteration == 3: + futures[2].set_result(24) + elif iteration == 5: + futures[0].set_exception(ZeroDivisionError) + elif iteration == 8: + futures[1].set_result(42) + futures[3].set_result(84) + + if iteration < 8: + self.io_loop.add_callback(self.finish_coroutines, iteration+1, futures) + + @gen_test + def test_iterator(self): + futures = [Future(), Future(), Future(), Future()] + + class TestException(Exception): + pass + + self.finish_coroutines(0, futures) + + g = gen.WaitIterator(*futures) + + i = 0 + for f in g: + try: + r = yield f + except ZeroDivisionError: + self.assertEqual(g.current_future(), futures[0], + 'exception future invalid') + else: + if i == 0: + self.assertEqual(r, 24, 'iterator value incorrect') + self.assertEqual(g.current_index(), 2, 'wrong index') + elif i == 2: + self.assertEqual(r, 42, 'iterator value incorrect') + self.assertEqual(g.current_index(), 1, 'wrong index') + elif i == 3: + self.assertEqual(r, 84, 'iterator value incorrect') + self.assertEqual(g.current_index(), 3, 'wrong index') + i += 1 if __name__ == '__main__': unittest.main()