]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
PeriodicCallback: support async/coroutine callback (#2924)
authorPierce Lopez <pierce.lopez@gmail.com>
Sun, 30 May 2021 15:33:14 +0000 (11:33 -0400)
committerGitHub <noreply@github.com>
Sun, 30 May 2021 15:33:14 +0000 (11:33 -0400)
ISSUE: https://github.com/tornadoweb/tornado/issues/2828

* ioloop: call_later() and call_at() take any Callable coroutine or plain, returning any type

Co-authored-by: agnewee <agnewee@gmail.com>
tornado/ioloop.py
tornado/platform/asyncio.py
tornado/test/ioloop_test.py

index 2b0ac019fa0c930bb9296206a23286197bcb3087..28d86f5049d625a5a5178537c34d8c5b236f25b1 100644 (file)
@@ -41,6 +41,7 @@ import sys
 import time
 import math
 import random
+from inspect import isawaitable
 
 from tornado.concurrent import (
     Future,
@@ -546,7 +547,7 @@ class IOLoop(Configurable):
     def add_timeout(
         self,
         deadline: Union[float, datetime.timedelta],
-        callback: Callable[..., None],
+        callback: Callable[..., Optional[Awaitable]],
         *args: Any,
         **kwargs: Any
     ) -> object:
@@ -585,7 +586,7 @@ class IOLoop(Configurable):
             raise TypeError("Unsupported deadline %r" % deadline)
 
     def call_later(
-        self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+        self, delay: float, callback: Callable, *args: Any, **kwargs: Any
     ) -> object:
         """Runs the ``callback`` after ``delay`` seconds have passed.
 
@@ -600,7 +601,7 @@ class IOLoop(Configurable):
         return self.call_at(self.time() + delay, callback, *args, **kwargs)
 
     def call_at(
-        self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+        self, when: float, callback: Callable, *args: Any, **kwargs: Any
     ) -> object:
         """Runs the ``callback`` at the absolute time designated by ``when``.
 
@@ -863,11 +864,17 @@ class PeriodicCallback(object):
 
     .. versionchanged:: 5.1
        The ``jitter`` argument is added.
+
+    .. versionchanged:: 6.2
+       If the ``callback`` argument is a coroutine, and a callback runs for
+       longer than ``callback_time``, subsequent invocations will be skipped.
+       Previously this was only true for regular functions, not coroutines,
+       which were "fire-and-forget" for `PeriodicCallback`.
     """
 
     def __init__(
         self,
-        callback: Callable[[], None],
+        callback: Callable[[], Optional[Awaitable]],
         callback_time: Union[datetime.timedelta, float],
         jitter: float = 0,
     ) -> None:
@@ -906,11 +913,13 @@ class PeriodicCallback(object):
         """
         return self._running
 
-    def _run(self) -> None:
+    async def _run(self) -> None:
         if not self._running:
             return
         try:
-            return self.callback()
+            val = self.callback()
+            if val is not None and isawaitable(val):
+                await val
         except Exception:
             app_log.error("Exception in callback %r", self.callback, exc_info=True)
         finally:
index 292d9b66a46b1ae9b99c94779983ce1bf61e162a..5e9c776d02ef00b8e9de8d901438e38a00a89716 100644 (file)
@@ -204,7 +204,11 @@ class BaseAsyncIOLoop(IOLoop):
         self.asyncio_loop.stop()
 
     def call_at(
-        self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+        self,
+        when: float,
+        callback: Callable[..., Optional[Awaitable]],
+        *args: Any,
+        **kwargs: Any
     ) -> object:
         # asyncio.call_at supports *args but not **kwargs, so bind them here.
         # We do not synchronize self.time and asyncio_loop.time, so
index daf74f9be5cd43547e67d1a73eb4e9721a802d0c..6fd41540f3cfa899eb4ab10d0dcc1a151515c926 100644 (file)
@@ -161,7 +161,7 @@ class TestIOLoop(AsyncTestCase):
 
             self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
             self.io_loop.add_timeout(
-                self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf")  # type: ignore
+                self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf")
             )
             self.wait()
             self.io_loop.remove_handler(client.fileno())
@@ -694,6 +694,60 @@ class TestPeriodicCallbackMath(unittest.TestCase):
         self.assertEqual(pc.callback_time, expected_callback_time)
 
 
+class TestPeriodicCallbackAsync(AsyncTestCase):
+    def test_periodic_plain(self):
+        count = 0
+
+        def callback() -> None:
+            nonlocal count
+            count += 1
+            if count == 3:
+                self.stop()
+
+        pc = PeriodicCallback(callback, 10)
+        pc.start()
+        self.wait()
+        pc.stop()
+        self.assertEqual(count, 3)
+
+    def test_periodic_coro(self):
+        counts = [0, 0]
+        pc = None
+
+        @gen.coroutine
+        def callback() -> None:
+            counts[0] += 1
+            yield gen.sleep(0.025)
+            counts[1] += 1
+            if counts[1] == 3:
+                pc.stop()
+                self.io_loop.add_callback(self.stop)
+
+        pc = PeriodicCallback(callback, 10)
+        pc.start()
+        self.wait()
+        self.assertEqual(counts[0], 3)
+        self.assertEqual(counts[1], 3)
+
+    def test_periodic_async(self):
+        counts = [0, 0]
+        pc = None
+
+        async def callback() -> None:
+            counts[0] += 1
+            await gen.sleep(0.025)
+            counts[1] += 1
+            if counts[1] == 3:
+                pc.stop()
+                self.io_loop.add_callback(self.stop)
+
+        pc = PeriodicCallback(callback, 10)
+        pc.start()
+        self.wait()
+        self.assertEqual(counts[0], 3)
+        self.assertEqual(counts[1], 3)
+
+
 class TestIOLoopConfiguration(unittest.TestCase):
     def run_python(self, *statements):
         stmt_list = [