]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Yielding dict in coroutine
authorAnton Ryzhov <anton@ryzhov.me>
Thu, 17 Oct 2013 15:37:00 +0000 (19:37 +0400)
committerAnton Ryzhov <anton@ryzhov.me>
Thu, 17 Oct 2013 15:37:00 +0000 (19:37 +0400)
tornado/gen.py
tornado/test/gen_test.py

index 92b7458ed959695c61a7d5d215bc03e33da730e6..7eb2c0ca32e1d7273567486a71c9d0407ebf4b86 100644 (file)
@@ -404,6 +404,10 @@ class Multi(YieldPoint):
     a list of ``YieldPoints``.
     """
     def __init__(self, children):
+        self.keys = None
+        if isinstance(children, dict):
+            self.keys = list(children.keys())
+            children = children.values()
         self.children = []
         for i in children:
             if isinstance(i, Future):
@@ -423,7 +427,11 @@ class Multi(YieldPoint):
         return not self.unfinished_children
 
     def get_result(self):
-        return [i.get_result() for i in self.children]
+        result = (i.get_result() for i in self.children)
+        if self.keys:
+            return dict(zip(self.keys, result))
+        else:
+            return list(result)
 
 
 class _NullYieldPoint(YieldPoint):
@@ -523,7 +531,7 @@ class Runner(object):
                     self.finished = True
                     self.yield_point = _null_yield_point
                     raise
-                if isinstance(yielded, list):
+                if isinstance(yielded, (list, dict)):
                     yielded = Multi(yielded)
                 elif isinstance(yielded, Future):
                     yielded = YieldFuture(yielded)
index 52de8da532016c8950c505498b415b7e18f0e771..3c51c23b49b75544ef4b8c9742ef43969de541f5 100644 (file)
@@ -281,6 +281,16 @@ class GenEngineTest(AsyncTestCase):
             self.stop()
         self.run_gen(f)
 
+    def test_multi_dict(self):
+        @gen.engine
+        def f():
+            (yield gen.Callback("k1"))("v1")
+            (yield gen.Callback("k2"))("v2")
+            results = yield dict(foo=gen.Wait("k1"), bar=gen.Wait("k2"))
+            self.assertEqual(results, dict(foo="v1", bar="v2"))
+            self.stop()
+        self.run_gen(f)
+
     def test_multi_delayed(self):
         @gen.engine
         def f():
@@ -293,6 +303,18 @@ class GenEngineTest(AsyncTestCase):
             self.stop()
         self.run_gen(f)
 
+    def test_multi_dict_delayed(self):
+        @gen.engine
+        def f():
+            # callbacks run at different times
+            responses = yield dict(
+                foo=gen.Task(self.delay_callback, 3, arg="v1"),
+                bar=gen.Task(self.delay_callback, 1, arg="v2"),
+            )
+            self.assertEqual(responses, dict(foo="v1", bar="v2"))
+            self.stop()
+        self.run_gen(f)
+
     @skipOnTravis
     @gen_test
     def test_multi_performance(self):
@@ -314,6 +336,11 @@ class GenEngineTest(AsyncTestCase):
         results = yield [self.async_future(1), self.async_future(2)]
         self.assertEqual(results, [1, 2])
 
+    @gen_test
+    def test_multi_dict_future(self):
+        results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
+        self.assertEqual(results, dict(foo=1, bar=2))
+
     def test_arguments(self):
         @gen.engine
         def f():