From: Pierce Lopez Date: Sun, 30 May 2021 15:33:14 +0000 (-0400) Subject: PeriodicCallback: support async/coroutine callback (#2924) X-Git-Tag: v6.2.0b1~43 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a9fbbeeec9abfe9ff54365160ce476ba853ba203;p=thirdparty%2Ftornado.git PeriodicCallback: support async/coroutine callback (#2924) ISSUE: https://github.com/tornadoweb/tornado/issues/2828 * ioloop: call_later() and call_at() take any Callable coroutine or plain, returning any type Co-authored-by: agnewee --- diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 2b0ac019f..28d86f504 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -41,6 +41,7 @@ import sys import time import math import random +from inspect import isawaitable from tornado.concurrent import ( Future, @@ -546,7 +547,7 @@ class IOLoop(Configurable): def add_timeout( self, deadline: Union[float, datetime.timedelta], - callback: Callable[..., None], + callback: Callable[..., Optional[Awaitable]], *args: Any, **kwargs: Any ) -> object: @@ -585,7 +586,7 @@ class IOLoop(Configurable): raise TypeError("Unsupported deadline %r" % deadline) def call_later( - self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any + self, delay: float, callback: Callable, *args: Any, **kwargs: Any ) -> object: """Runs the ``callback`` after ``delay`` seconds have passed. @@ -600,7 +601,7 @@ class IOLoop(Configurable): return self.call_at(self.time() + delay, callback, *args, **kwargs) def call_at( - self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any + self, when: float, callback: Callable, *args: Any, **kwargs: Any ) -> object: """Runs the ``callback`` at the absolute time designated by ``when``. @@ -863,11 +864,17 @@ class PeriodicCallback(object): .. versionchanged:: 5.1 The ``jitter`` argument is added. + + .. versionchanged:: 6.2 + If the ``callback`` argument is a coroutine, and a callback runs for + longer than ``callback_time``, subsequent invocations will be skipped. + Previously this was only true for regular functions, not coroutines, + which were "fire-and-forget" for `PeriodicCallback`. """ def __init__( self, - callback: Callable[[], None], + callback: Callable[[], Optional[Awaitable]], callback_time: Union[datetime.timedelta, float], jitter: float = 0, ) -> None: @@ -906,11 +913,13 @@ class PeriodicCallback(object): """ return self._running - def _run(self) -> None: + async def _run(self) -> None: if not self._running: return try: - return self.callback() + val = self.callback() + if val is not None and isawaitable(val): + await val except Exception: app_log.error("Exception in callback %r", self.callback, exc_info=True) finally: diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index 292d9b66a..5e9c776d0 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -204,7 +204,11 @@ class BaseAsyncIOLoop(IOLoop): self.asyncio_loop.stop() def call_at( - self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any + self, + when: float, + callback: Callable[..., Optional[Awaitable]], + *args: Any, + **kwargs: Any ) -> object: # asyncio.call_at supports *args but not **kwargs, so bind them here. # We do not synchronize self.time and asyncio_loop.time, so diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py index daf74f9be..6fd41540f 100644 --- a/tornado/test/ioloop_test.py +++ b/tornado/test/ioloop_test.py @@ -161,7 +161,7 @@ class TestIOLoop(AsyncTestCase): self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ) self.io_loop.add_timeout( - self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf") # type: ignore + self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf") ) self.wait() self.io_loop.remove_handler(client.fileno()) @@ -694,6 +694,60 @@ class TestPeriodicCallbackMath(unittest.TestCase): self.assertEqual(pc.callback_time, expected_callback_time) +class TestPeriodicCallbackAsync(AsyncTestCase): + def test_periodic_plain(self): + count = 0 + + def callback() -> None: + nonlocal count + count += 1 + if count == 3: + self.stop() + + pc = PeriodicCallback(callback, 10) + pc.start() + self.wait() + pc.stop() + self.assertEqual(count, 3) + + def test_periodic_coro(self): + counts = [0, 0] + pc = None + + @gen.coroutine + def callback() -> None: + counts[0] += 1 + yield gen.sleep(0.025) + counts[1] += 1 + if counts[1] == 3: + pc.stop() + self.io_loop.add_callback(self.stop) + + pc = PeriodicCallback(callback, 10) + pc.start() + self.wait() + self.assertEqual(counts[0], 3) + self.assertEqual(counts[1], 3) + + def test_periodic_async(self): + counts = [0, 0] + pc = None + + async def callback() -> None: + counts[0] += 1 + await gen.sleep(0.025) + counts[1] += 1 + if counts[1] == 3: + pc.stop() + self.io_loop.add_callback(self.stop) + + pc = PeriodicCallback(callback, 10) + pc.start() + self.wait() + self.assertEqual(counts[0], 3) + self.assertEqual(counts[1], 3) + + class TestIOLoopConfiguration(unittest.TestCase): def run_python(self, *statements): stmt_list = [