]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement IOLoop.run_in_executor (#2067)
authorMike DePalatis <mike@depalatis.net>
Sat, 21 Oct 2017 18:04:57 +0000 (14:04 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 21 Oct 2017 18:04:57 +0000 (14:04 -0400)
tornado/ioloop.py
tornado/test/ioloop_test.py

index 73d0cbdb787075d806585e4c95df64251ae276ee..5686576cf7dbe310c3e4d00c67f04ef3d5a0ec5d 100644 (file)
@@ -55,6 +55,10 @@ try:
 except ImportError:
     signal = None
 
+try:
+    from concurrent.futures import ThreadPoolExecutor
+except ImportError:
+    ThreadPoolExecutor = None
 
 if PY3:
     import _thread as thread
@@ -635,6 +639,29 @@ class IOLoop(Configurable):
         future.add_done_callback(
             lambda future: self.add_callback(callback, future))
 
+    def run_in_executor(self, executor, func, *args):
+        """Runs a function in a ``concurrent.futures.Executor``. If
+        ``executor`` is ``None``, the IO loop's default executor will be used.
+
+        Use `functools.partial` to pass keyword arguments to `func`.
+
+        """
+        if ThreadPoolExecutor is None:
+            raise RuntimeError(
+                "concurrent.futures is required to use IOLoop.run_in_executor")
+
+        if executor is None:
+            if not hasattr(self, '_executor'):
+                from tornado.process import cpu_count
+                self._executor = ThreadPoolExecutor(max_workers=(cpu_count() * 5))
+            executor = self._executor
+
+        return executor.submit(func, *args)
+
+    def set_default_executor(self, executor):
+        """Sets the default executor to use with :meth:`run_in_executor`."""
+        self._executor = executor
+
     def _run_callback(self, callback):
         """Runs a callback with error handling.
 
@@ -777,6 +804,8 @@ class PollIOLoop(IOLoop):
         self._impl.close()
         self._callbacks = None
         self._timeouts = None
+        if hasattr(self, '_executor'):
+            self._executor.shutdown()
 
     def add_handler(self, fd, handler, events):
         fd, obj = self.split_fd(fd)
index 5b9bd9cc744d5b335e1ee14b3ea136cb201e2d78..f3cd32ae42a5b4d1bc52c5cb760addb844bea0a8 100644 (file)
@@ -18,8 +18,9 @@ from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
 from tornado.log import app_log
 from tornado.platform.select import _Select
 from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
-from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog
+from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
 from tornado.test.util import unittest, skipIfNonUnix, skipOnTravis, skipBefore35, exec_test
+from tornado.concurrent import Future
 
 try:
     from concurrent import futures
@@ -598,6 +599,62 @@ class TestIOLoopFutures(AsyncTestCase):
         self.assertEqual(self.exception.args[0], "callback")
         self.assertEqual(self.future.exception().args[0], "worker")
 
+    @gen_test
+    def test_run_in_executor_gen(self):
+        event1 = threading.Event()
+        event2 = threading.Event()
+
+        def callback(self_event, other_event):
+            self_event.set()
+            time.sleep(0.01)
+            self.assertTrue(other_event.is_set())
+            return self_event
+
+        res = yield [
+            IOLoop.current().run_in_executor(None, callback, event1, event2),
+            IOLoop.current().run_in_executor(None, callback, event2, event1)
+        ]
+
+        self.assertEqual([event1, event2], res)
+
+    @skipBefore35
+    def test_run_in_executor_native(self):
+        event1 = threading.Event()
+        event2 = threading.Event()
+
+        def callback(self_event, other_event):
+            self_event.set()
+            time.sleep(0.01)
+            self.assertTrue(other_event.is_set())
+            other_event.wait()
+            return self_event
+
+        namespace = exec_test(globals(), locals(), """
+            async def main():
+                res = await gen.multi([
+                    IOLoop.current().run_in_executor(None, callback, event1, event2),
+                    IOLoop.current().run_in_executor(None, callback, event2, event1)
+                ])
+                self.assertEqual([event1, event2], res)
+        """)
+        IOLoop.current().run_sync(namespace['main'])
+
+    def test_set_default_executor(self):
+        class MyExecutor(futures.Executor):
+            def submit(self, func, *args):
+                return Future()
+
+        event = threading.Event()
+
+        def future_func():
+            event.set()
+
+        executor = MyExecutor()
+        loop = IOLoop.current()
+        loop.set_default_executor(executor)
+        loop.run_in_executor(None, future_func)
+        loop.add_timeout(0.01, lambda: self.assertFalse(event.is_set()))
+
 
 class TestIOLoopRunSync(unittest.TestCase):
     def setUp(self):