]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add magic for yielding futures.
authorBen Darnell <ben@bendarnell.com>
Mon, 3 Sep 2012 19:10:57 +0000 (12:10 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 3 Sep 2012 19:10:57 +0000 (12:10 -0700)
Introduce IOLoop.current() as a thread-local counterpart to IOLoop.instance().

gen.engine now recognizes Futures directly.

tornado/gen.py
tornado/ioloop.py
tornado/test/concurrent_test.py
tornado/testing.py

index bc0f935dcf592500caf73d860ebec523e31701b8..51d64a3aa3d810efb692bb5f38d9ecfb85efe796 100644 (file)
@@ -69,6 +69,7 @@ import operator
 import sys
 import types
 
+from tornado.concurrent import Future
 from tornado.ioloop import IOLoop
 from tornado.stack_context import ExceptionStackContext
 
@@ -251,7 +252,7 @@ class Task(YieldPoint):
 class YieldFuture(YieldPoint):
     def __init__(self, future, io_loop=None):
         self.future = future
-        self.io_loop = io_loop or IOLoop.instance()
+        self.io_loop = io_loop or IOLoop.current()
 
     def start(self, runner):
         self.runner = runner
@@ -379,6 +380,9 @@ class Runner(object):
                     raise
                 if isinstance(yielded, list):
                     yielded = Multi(yielded)
+                if isinstance(yielded, Future):
+                    # TODO: lists of futures
+                    yielded = YieldFuture(yielded)
                 if isinstance(yielded, YieldPoint):
                     self.yield_point = yielded
                     try:
index 745bdf8987d7f17267bab1dc8412584b5aa9c0fc..7244720924590696036d94618ac3d34e1c591a9c 100644 (file)
@@ -114,6 +114,8 @@ class IOLoop(object):
     # Global lock for creating global IOLoop instance
     _instance_lock = threading.Lock()
 
+    _current = threading.local()
+
     def __init__(self, impl=None):
         self._impl = impl or _poll()
         if hasattr(self._impl, 'fileno'):
@@ -173,6 +175,20 @@ class IOLoop(object):
         assert not IOLoop.initialized()
         IOLoop._instance = self
 
+    @staticmethod
+    def current():
+        current = getattr(IOLoop._current, "instance", None)
+        if current is None:
+            raise ValueError("no current IOLoop")
+        return current
+
+    def make_current(self):
+        IOLoop._current.instance = self
+
+    def clear_current(self):
+        assert IOLoop._current.instance is self
+        IOLoop._current.instance = None
+
     def close(self, all_fds=False):
         """Closes the IOLoop, freeing any resources used.
 
@@ -264,6 +280,8 @@ class IOLoop(object):
         if self._stopped:
             self._stopped = False
             return
+        old_current = getattr(IOLoop._current, "instance", None)
+        IOLoop._current.instance = self
         self._thread_ident = thread.get_ident()
         self._running = True
         while True:
@@ -346,6 +364,7 @@ class IOLoop(object):
         self._stopped = False
         if self._blocking_signal_threshold is not None:
             signal.setitimer(signal.ITIMER_REAL, 0, 0)
+        IOLoop._current.instance = old_current
 
     def stop(self):
         """Stop the loop after the current event loop iteration is complete.
index 2cf3718949d20651adc517e7bb53057c64715583..5267358ad1ffd45a0bd19222c7670a3e08c3bc0a 100644 (file)
@@ -161,8 +161,7 @@ class ClientTestMixin(object):
     def test_generator(self):
         @gen.engine
         def f():
-            result = yield gen.YieldFuture(self.client.capitalize("hello"),
-                                           io_loop=self.io_loop)
+            result = yield self.client.capitalize("hello")
             self.assertEqual(result, "HELLO")
             self.stop()
         f()
@@ -172,8 +171,7 @@ class ClientTestMixin(object):
         @gen.engine
         def f():
             with self.assertRaisesRegexp(CapError, "already capitalized"):
-                 yield gen.YieldFuture(self.client.capitalize("HELLO"),
-                                       io_loop=self.io_loop)
+                 yield self.client.capitalize("HELLO")
             self.stop()
         f()
         self.wait()
index c597e3a77f1a9cb6acc832c42c11169fef2603e7..4e2bc2efbf95b29e58dde7df22a1be86db27b632 100644 (file)
@@ -124,8 +124,10 @@ class AsyncTestCase(unittest.TestCase):
     def setUp(self):
         super(AsyncTestCase, self).setUp()
         self.io_loop = self.get_new_ioloop()
+        self.io_loop.make_current()
 
     def tearDown(self):
+        self.io_loop.clear_current()
         if (not IOLoop.initialized() or
             self.io_loop is not IOLoop.instance()):
             # Try to clean up any file descriptors left open in the ioloop.