]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add IOLoop.run_sync convenience method.
authorBen Darnell <ben@bendarnell.com>
Sat, 2 Mar 2013 23:58:02 +0000 (18:58 -0500)
committerBen Darnell <ben@bendarnell.com>
Sat, 2 Mar 2013 23:58:02 +0000 (18:58 -0500)
maint/scripts/test_resolvers.py
tornado/httpclient.py
tornado/ioloop.py
tornado/test/ioloop_test.py
tornado/testing.py

index a57c382e2c994b8685db833b642aab1064a7c4db..7a1ad358f2765ea22d57228a67174777b961ddbb 100644 (file)
@@ -22,7 +22,7 @@ except ImportError:
 define('family', default='unspec',
        help='Address family to query: unspec, inet, or inet6')
 
-@gen.engine
+@gen.coroutine
 def main():
     args = parse_command_line()
 
@@ -53,8 +53,6 @@ def main():
             print('%s: %s' % (resolver.__class__.__name__,
                               pprint.pformat(addrinfo)))
         print()
-    IOLoop.instance().stop()
 
 if __name__ == '__main__':
-    IOLoop.instance().add_callback(main)
-    IOLoop.instance().start()
+    IOLoop.instance().run_sync(main)
index 9ceeb022522c42cab0bb4c60fb350288d410be3e..c325a63bd0b5ee1f9a62284ca85db10cda3d9f68 100644 (file)
@@ -31,6 +31,7 @@ supported version is 7.18.2, and the recommended version is 7.21.1 or newer.
 
 from __future__ import absolute_import, division, print_function, with_statement
 
+import functools
 import time
 import weakref
 
@@ -60,7 +61,6 @@ class HTTPClient(object):
         if async_client_class is None:
             async_client_class = AsyncHTTPClient
         self._async_client = async_client_class(self._io_loop, **kwargs)
-        self._response = None
         self._closed = False
 
     def __del__(self):
@@ -82,14 +82,8 @@ class HTTPClient(object):
 
         If an error occurs during the fetch, we raise an `HTTPError`.
         """
-        def callback(response):
-            self._response = response
-            self._io_loop.stop()
-        self._io_loop.add_callback(self._async_client.fetch, request,
-                                   callback, **kwargs)
-        self._io_loop.start()
-        response = self._response
-        self._response = None
+        response = self._io_loop.run_sync(functools.partial(
+                self._async_client.fetch, request, **kwargs))
         response.rethrow()
         return response
 
index 08e6d60f7694a1a2d1e72fddb84173970f4c7b8d..4fc43baddf723fa80fabeb5089c94cca3c8e3718 100644 (file)
@@ -58,6 +58,10 @@ except ImportError:
 from tornado.platform.auto import set_close_exec, Waker
 
 
+class TimeoutError(Exception):
+    pass
+
+
 class IOLoop(Configurable):
     """A level-triggered I/O loop.
 
@@ -166,8 +170,8 @@ class IOLoop(Configurable):
     def make_current(self):
         IOLoop._current.instance = self
 
-    def clear_current(self):
-        assert IOLoop._current.instance is self
+    @staticmethod
+    def clear_current():
         IOLoop._current.instance = None
 
     @classmethod
@@ -281,6 +285,52 @@ class IOLoop(Configurable):
         """
         raise NotImplementedError()
 
+    def run_sync(self, func, timeout=None):
+        """Starts the `IOLoop`, runs the given function, and stops the loop.
+
+        If the function returns a `Future`, the `IOLoop` will run until
+        the future is resolved.  If it raises an exception, the `IOLoop`
+        will stop and the exception will be re-raised to the caller.
+
+        The keyword-only argument ``timeout`` may be used to set
+        a maximum duration for the function.  If the timeout expires,
+        a `TimeoutError` is raised.
+
+        This method is useful in conjunction with `tornado.gen.coroutine`
+        to allow asynchronous calls in a `main()` function::
+
+            @gen.coroutine
+            def main():
+                # do stuff...
+
+            if __name__ == '__main__':
+                IOLoop.instance().run_sync(main)
+        """
+        future_cell = [None]
+        def run():
+            try:
+                result = func()
+            except Exception as e:
+                future_cell[0] = Future()
+                future_cell[0].set_exception(e)
+            else:
+                if isinstance(result, Future):
+                    future_cell[0] = result
+                else:
+                    future_cell[0] = Future()
+                    future_cell[0].set_result(result)
+            self.add_future(future_cell[0], lambda future: self.stop())
+        self.add_callback(run)
+        if timeout is not None:
+            timeout_handle = self.add_timeout(self.time() + timeout, self.stop)
+        self.start()
+        if timeout is not None:
+            self.remove_timeout(timeout_handle)
+        if not future_cell[0].done():
+            raise TimeoutError('Operation timed out after %s seconds' % timeout)
+        return future_cell[0].result()
+
+
     def time(self):
         """Returns the current time according to the IOLoop's clock.
 
index 0ac4fcae75976e4ac035667ce6db853aee225f9f..76b7a24da1715879379b0bc46fdf25c395ffd913 100644 (file)
@@ -10,7 +10,8 @@ import sys
 import threading
 import time
 
-from tornado.ioloop import IOLoop
+from tornado import gen
+from tornado.ioloop import IOLoop, TimeoutError
 from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
 from tornado.testing import AsyncTestCase, bind_unused_port
 from tornado.test.util import unittest, skipIfNonUnix
@@ -272,5 +273,45 @@ class TestIOLoopFutures(AsyncTestCase):
         self.assertEqual(self.future.exception().args[0], "worker")
 
 
+class TestIOLoopRunSync(unittest.TestCase):
+    def setUp(self):
+        self.io_loop = IOLoop()
+
+    def tearDown(self):
+        self.io_loop.close()
+
+    def test_sync_result(self):
+        self.assertEqual(self.io_loop.run_sync(lambda: 42), 42)
+
+    def test_sync_exception(self):
+        with self.assertRaises(ZeroDivisionError):
+            self.io_loop.run_sync(lambda: 1 / 0)
+
+    def test_async_result(self):
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            raise gen.Return(42)
+        self.assertEqual(self.io_loop.run_sync(f), 42)
+
+    def test_async_exception(self):
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_callback)
+            1 / 0
+        with self.assertRaises(ZeroDivisionError):
+            self.io_loop.run_sync(f)
+
+    def test_current(self):
+        def f():
+            self.assertIs(IOLoop.current(), self.io_loop)
+        self.io_loop.run_sync(f)
+
+    def test_timeout(self):
+        @gen.coroutine
+        def f():
+            yield gen.Task(self.io_loop.add_timeout, self.io_loop.time() + 1)
+        self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
+
 if __name__ == "__main__":
     unittest.main()
index 51716657f03c129676d8f9142b4f8987f116c7f8..f0fd6504a4ac76cd6a87f90b03d9cdbbf593c1b7 100644 (file)
@@ -365,16 +365,12 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
 
 
 def gen_test(f):
-    """Testing equivalent of ``@gen.engine``, to be applied to test methods.
+    """Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
 
-    ``@gen.engine`` cannot be used on tests because the `IOLoop` is not
+    ``@gen.coroutine`` cannot be used on tests because the `IOLoop` is not
     already running.  ``@gen_test`` should be applied to test methods
     on subclasses of `AsyncTestCase`.
 
-    Note that unlike most uses of ``@gen.engine``, ``@gen_test`` can
-    detect automatically when the function finishes cleanly so there
-    is no need to run a callback to signal completion.
-
     Example::
         class MyTest(AsyncHTTPTestCase):
             @gen_test
@@ -382,15 +378,10 @@ def gen_test(f):
                 response = yield gen.Task(self.fetch('/'))
 
     """
+    f = gen.coroutine(f)
     @functools.wraps(f)
-    def wrapper(self, *args, **kwargs):
-        result = f(self, *args, **kwargs)
-        if result is None:
-            return
-        assert isinstance(result, types.GeneratorType)
-        runner = gen.Runner(result, lambda value: self.stop())
-        runner.run()
-        self.wait()
+    def wrapper(self):
+        return self.io_loop.run_sync(functools.partial(f, self), timeout=5)
     return wrapper