]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Support lists of futures in @gen.engine.
authorBen Darnell <ben@bendarnell.com>
Sun, 27 Jan 2013 23:07:21 +0000 (18:07 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 27 Jan 2013 23:07:21 +0000 (18:07 -0500)
tornado/gen.py
tornado/test/gen_test.py

index 320f7da65df6b7e0fb75185eabc708db9748f771..e309bf3c0a08897712f079fd61e3df933da38c8e 100644 (file)
@@ -276,8 +276,12 @@ class Multi(YieldPoint):
     a list of ``YieldPoints``.
     """
     def __init__(self, children):
-        assert all(isinstance(i, YieldPoint) for i in children)
-        self.children = children
+        self.children = []
+        for i in children:
+            if isinstance(i, Future):
+                i = YieldFuture(i)
+            self.children.append(i)
+        assert all(isinstance(i, YieldPoint) for i in self.children)
 
     def start(self, runner):
         for i in self.children:
@@ -383,8 +387,7 @@ class Runner(object):
                     raise
                 if isinstance(yielded, list):
                     yielded = Multi(yielded)
-                if isinstance(yielded, Future):
-                    # TODO: lists of futures
+                elif isinstance(yielded, Future):
                     yielded = YieldFuture(yielded)
                 if isinstance(yielded, YieldPoint):
                     self.yield_point = yielded
index 7249a3ad4a2ca35098e8da6f640aea232d7e5919..ad5310776a185124af1143dca4dfa9ed1d8dc392 100644 (file)
@@ -1,14 +1,17 @@
 from __future__ import absolute_import, division, print_function, with_statement
+
 import functools
+from tornado.concurrent import return_future
 from tornado.escape import url_escape
 from tornado.httpclient import AsyncHTTPClient
 from tornado.log import app_log
-from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
 from tornado.web import Application, RequestHandler, asynchronous
 
 from tornado import gen
 
 
+
 class GenTest(AsyncTestCase):
     def run_gen(self, f):
         f()
@@ -22,6 +25,10 @@ class GenTest(AsyncTestCase):
             self.io_loop.add_callback(functools.partial(
                 self.delay_callback, iterations - 1, callback, arg))
 
+    @return_future
+    def async_future(self, result, callback):
+        self.io_loop.add_callback(callback, result)
+
     def test_no_yield(self):
         @gen.engine
         def f():
@@ -220,6 +227,16 @@ class GenTest(AsyncTestCase):
             self.stop()
         self.run_gen(f)
 
+    @gen_test
+    def test_future(self):
+        result = yield self.async_future(1)
+        self.assertEqual(result, 1)
+
+    @gen_test
+    def test_multi_future(self):
+        results = yield [self.async_future(1), self.async_future(2)]
+        self.assertEqual(results, [1, 2])
+
     def test_arguments(self):
         @gen.engine
         def f():