]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Convert WaitIterator to use while-like iteration instead of for-like.
authorBen Darnell <ben@bendarnell.com>
Sun, 18 Jan 2015 22:06:53 +0000 (17:06 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 18 Jan 2015 22:06:53 +0000 (17:06 -0500)
Make current_future and current_index attributes instead of methods.
Restructure the internals to avoid quadratic performance.

tornado/gen.py
tornado/test/gen_test.py

index 2ec48dfb461bea19ae94a2239374087a6504bfc8..1d1d267bf6998ab282c174b01daf5a6651c96adb 100644 (file)
@@ -243,7 +243,7 @@ class Return(Exception):
         self.value = value
 
 class WaitIterator(object):
-    """Provides an iterator to yield the results of futures as they finish
+    """Provides an iterator to yield the results of futures as they finish.
 
     Yielding a set of futures like this:
 
@@ -261,21 +261,26 @@ class WaitIterator(object):
     ::
 
       wait_iterator = gen.WaitIterator(future1, future2)
-      for future in wait_iterator:
+      while not wait_iterator.done():
           try:
-              result = yield future
+              result = yield wait_iterator.next()
           except Exception as e:
-              print "Error {} from {}".format(e, wait_iterator.current_future())
+              print "Error {} from {}".format(e, wait_iterator.current_future)
           else:
               print "Result {} recieved from {} at {}".format(
-                  result, wait_iterator.current_future(), wait_iterator.current_index())
+                  result, wait_iterator.current_future,
+                  wait_iterator.current_index)
 
     Because results are returned as soon as they are available the
     output from the iterator *will not be in the same order as the
     input arguments*. If you need to know which future produced the
-    current result, you can use ``WaitIterator.current_future()``, or
-    ``WaitIterator.current_index()`` to yield the index of the future
-    from the input list.
+    current result, you can use the attributes
+    ``WaitIterator.current_future``, or ``WaitIterator.current_index``
+    to get the index of the future from the input list. (if keyword
+    arguments were used in the construction of the `WaitIterator`,
+    ``current_index`` will use the corresponding keyword).
+
+    .. versionadded:: 4.1
     """
     def __init__(self, *args, **kwargs):
         if args and kwargs:
@@ -283,62 +288,41 @@ class WaitIterator(object):
                 "You must provide args or kwargs, not both")
 
         if kwargs:
-            self._keys, self._futures = list(), list()
-            for k, v in kwargs.items():
-                self._keys.append(k)
-                self._futures.append(v)
+            self._unfinished = dict((f, k) for (k, f) in kwargs.items())
+            futures = list(kwargs.values())
         else:
-            self._keys = None
-            self._futures = list(args)
-
-        self._queue = collections.deque()
-        self._current_future = None
+            self._unfinished = dict((f, i) for (i, f) in enumerate(args))
+            futures = args
 
-        for future in self._futures:
-            if future.done():
-                self._queue.append(future)
-            else:
-                self_ref = weakref.ref(self)
-                future.add_done_callback(functools.partial(
-                        self._done_callback, self_ref))
+        self._finished = collections.deque()
+        self.current_index = self.current_future = None
+        self._running_future = None
 
-    def __iter__(self):
-        return self
+        self_ref = weakref.ref(self)
+        for future in futures:
+            future.add_done_callback(functools.partial(
+                self._done_callback, self_ref))
 
-    def __next__(self):
-        return self.next()
+    def done(self):
+        if self._finished or self._unfinished:
+            return False
+        # Clear the 'current' values when iteration is done.
+        self.current_index = self.current_future = None
+        return True
 
     def next(self):
-        """Returns a `.Future` that will yield the next available
-        result.
-        """
-        if all(x is None for x in self._futures):
-            self._current_future = None
-            raise StopIteration
+        """Returns a `.Future` that will yield the next available result.
 
+        Note that this `.Future` will not be the same object as any of
+        the inputs.
+        """
         self._running_future = TracebackFuture()
 
-        try:
-            done = self._queue.popleft()
-            self._return_result(done)
-        except IndexError:
-            pass
+        if self._finished:
+            self._return_result(self._finished.popleft())
 
         return self._running_future
 
-    def current_index(self):
-        """Returns the index of the current `.Future` from the
-        argument list. If keyword arguments were used, the keyword
-        will be returned.
-        """
-        if self._current_future:
-            return self._current_future[0]
-
-    def current_future(self):
-        """Returns the current `.Future` object."""
-        if self._current_future:
-            return self._current_future[1]
-
     @staticmethod
     def _done_callback(self_ref, done):
         self = self_ref()
@@ -346,26 +330,17 @@ class WaitIterator(object):
             if self._running_future and not self._running_future.done():
                 self._return_result(done)
             else:
-                self._queue.append(done)
+                self._finished.append(done)
 
     def _return_result(self, done):
         """Called set the returned future's state that of the future
         we yielded, and set the current future for the iterator.
         """
-        exception = done.exception()
-        if exception:
-            self._running_future.set_exception(exception)
-        else:
-            self._running_future.set_result(done.result())
-
-        index = self._futures.index(done)
-        ## Eliminate the reference for GC
-        self._futures[index] = None
+        chain_future(done, self._running_future)
 
-        if self._keys:
-            index = self._keys[index]
+        self.current_future = done
+        self.current_index = self._unfinished.pop(done)
 
-        self._current_future = (index, done)
 
 class YieldPoint(object):
     """Base class for objects that may be yielded from the generator.
index 692552f976c72c44906cb5cb4eea7404b9dd266e..13ee1a2c5936c1457057f3f01e037025adec0a1d 100644 (file)
@@ -1070,18 +1070,13 @@ class WaitIteratorTest(AsyncTestCase):
     @gen_test
     def test_empty_iterator(self):
         g = gen.WaitIterator()
-        for i in g:
-            self.assertTrue(True, 'empty generator iterated')
+        self.assertTrue(g.done(), 'empty generator iterated')
 
-        try:
+        with self.assertRaises(ValueError):
             g = gen.WaitIterator(False, bar=False)
-        except ValueError:
-            pass
-        else:
-            self.assertTrue(True, 'missed incompatible args')
 
-        self.assertEqual(g.current_index(), None, "bad nil current index")
-        self.assertEqual(g.current_future(), None, "bad nil current future")
+        self.assertEqual(g.current_index, None, "bad nil current index")
+        self.assertEqual(g.current_future, None, "bad nil current future")
 
     @gen_test
     def test_already_done(self):
@@ -1094,43 +1089,45 @@ class WaitIteratorTest(AsyncTestCase):
 
         g = gen.WaitIterator(f1, f2, f3)
         i = 0
-        for f in g:
-            r = yield f
+        while not g.done():
+            r = yield g.next()
+            # Order is not guaranteed, but the current implementation
+            # preserves ordering of already-done Futures.
             if i == 0:
-                self.assertTrue(
-                    all([g.current_index()==0, g.current_future()==f1, r==24]),
-                    "WaitIterator status incorrect")
+                self.assertEqual(g.current_index, 0)
+                self.assertIs(g.current_future, f1)
+                self.assertEqual(r, 24)
             elif i == 1:
-                self.assertTrue(
-                    all([g.current_index()==1, g.current_future()==f2, r==42]),
-                    "WaitIterator status incorrect")
+                self.assertEqual(g.current_index, 1)
+                self.assertIs(g.current_future, f2)
+                self.assertEqual(r, 42)
             elif i == 2:
-                self.assertTrue(
-                    all([g.current_index()==2, g.current_future()==f3, r==84]),
-                    "WaitIterator status incorrect")
+                self.assertEqual(g.current_index, 2)
+                self.assertIs(g.current_future, f3)
+                self.assertEqual(r, 84)
             i += 1
 
-        self.assertEqual(g.current_index(), None, "bad nil current index")
-        self.assertEqual(g.current_future(), None, "bad nil current future")
+        self.assertEqual(g.current_index, None, "bad nil current index")
+        self.assertEqual(g.current_future, None, "bad nil current future")
 
         dg = gen.WaitIterator(f1=f1, f2=f2)
 
-        for df in dg:
-            dr = yield df
-            if dg.current_index() == "f1":
-                self.assertTrue(dg.current_future()==f1 and dr==24,
+        while not dg.done():
+            dr = yield dg.next()
+            if dg.current_index == "f1":
+                self.assertTrue(dg.current_future==f1 and dr==24,
                                 "WaitIterator dict status incorrect")
-            elif dg.current_index() == "f2":
-                self.assertTrue(dg.current_future()==f2 and dr==42,
+            elif dg.current_index == "f2":
+                self.assertTrue(dg.current_future==f2 and dr==42,
                                 "WaitIterator dict status incorrect")
             else:
-                self.assertTrue(False, "got bad WaitIterator index {}".format(
-                    dg.current_index()))
+                self.fail("got bad WaitIterator index {}".format(
+                    dg.current_index))
 
             i += 1
 
-        self.assertEqual(dg.current_index(), None, "bad nil current index")
-        self.assertEqual(dg.current_future(), None, "bad nil current future")
+        self.assertEqual(dg.current_index, None, "bad nil current index")
+        self.assertEqual(dg.current_future, None, "bad nil current future")
 
     def finish_coroutines(self, iteration, futures):
         if iteration == 3:
@@ -1153,22 +1150,22 @@ class WaitIteratorTest(AsyncTestCase):
         g = gen.WaitIterator(*futures)
 
         i = 0
-        for f in g:
+        while not g.done():
             try:
-                r = yield f
+                r = yield g.next()
             except ZeroDivisionError:
-                self.assertEqual(g.current_future(), futures[0],
-                                 'exception future invalid')
+                self.assertIs(g.current_future, futures[0],
+                              'exception future invalid')
             else:
                 if i == 0:
                     self.assertEqual(r, 24, 'iterator value incorrect')
-                    self.assertEqual(g.current_index(), 2, 'wrong index')
+                    self.assertEqual(g.current_index, 2, 'wrong index')
                 elif i == 2:
                     self.assertEqual(r, 42, 'iterator value incorrect')
-                    self.assertEqual(g.current_index(), 1, 'wrong index')
+                    self.assertEqual(g.current_index, 1, 'wrong index')
                 elif i == 3:
                     self.assertEqual(r, 84, 'iterator value incorrect')
-                    self.assertEqual(g.current_index(), 3, 'wrong index')
+                    self.assertEqual(g.current_index, 3, 'wrong index')
             i += 1
 
 if __name__ == '__main__':