From 5abd9d457fc3b478d8c6ae4d7a52f2b632028948 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 27 Feb 2021 15:28:18 +0100 Subject: [PATCH] Add an async scheduler --- psycopg3/psycopg3/pool/sched.py | 75 ++++++++++++++++ tests/pool/test_sched.py | 6 +- tests/pool/test_sched_async.py | 154 ++++++++++++++++++++++++++++++++ 3 files changed, 231 insertions(+), 4 deletions(-) create mode 100644 tests/pool/test_sched_async.py diff --git a/psycopg3/psycopg3/pool/sched.py b/psycopg3/psycopg3/pool/sched.py index 7c54b12f7..95af07532 100644 --- a/psycopg3/psycopg3/pool/sched.py +++ b/psycopg3/psycopg3/pool/sched.py @@ -12,6 +12,7 @@ Tasks are called "Task", not "Event", here, because we actually make use of # Copyright (C) 2021 The Psycopg Team +import asyncio import logging import threading from time import monotonic @@ -107,3 +108,77 @@ class Scheduler: else: # Block for the expected timeout or until a new task scheduled self._event.wait(timeout=delay) + + +class AsyncScheduler: + def __init__(self) -> None: + """Initialize a new instance, passing the time and delay functions.""" + self._queue: List[Task] = [] + self._lock = asyncio.Lock() + self._event = asyncio.Event() + + EMPTY_QUEUE_TIMEOUT = 600.0 + + async def enter( + self, delay: float, action: Optional[Callable[[], Any]] + ) -> Task: + """Enter a new task in the queue delayed in the future. + + Schedule a `!None` to stop the execution. + """ + time = monotonic() + delay + return await self.enterabs(time, action) + + async def enterabs( + self, time: float, action: Optional[Callable[[], Any]] + ) -> Task: + """Enter a new task in the queue at an absolute time. + + Schedule a `!None` to stop the execution. + """ + task = Task(time, action) + async with self._lock: + heappush(self._queue, task) + first = self._queue[0] is task + + if first: + self._event.set() + + return task + + async def run(self) -> None: + """Execute the events scheduled.""" + q = self._queue + while True: + async with self._lock: + now = monotonic() + task = q[0] if q else None + if task: + if task.time <= now: + heappop(q) + else: + delay = task.time - now + task = None + else: + delay = self.EMPTY_QUEUE_TIMEOUT + self._event.clear() + + if task: + # logger.info("task %s action %s", task, task.action) + if not task.action: + break + try: + await task.action() + except Exception as e: + logger.warning( + "scheduled task run %s failed: %s: %s", + task.action, + e.__class__.__name__, + e, + ) + else: + # Block for the expected timeout or until a new task scheduled + try: + await asyncio.wait_for(self._event.wait(), delay) + except asyncio.TimeoutError: + pass diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py index 3b227abba..3ca6cb1f0 100644 --- a/tests/pool/test_sched.py +++ b/tests/pool/test_sched.py @@ -32,8 +32,7 @@ def test_sched(): @pytest.mark.slow def test_sched_thread(): s = Scheduler() - t = Thread(target=s.run) - t.daemon = True + t = Thread(target=s.run, daemon=True) t.start() results = [] @@ -62,8 +61,7 @@ def test_sched_thread(): def test_sched_error(caplog): caplog.set_level(logging.WARNING, logger="psycopg3") s = Scheduler() - t = Thread(target=s.run) - t.daemon = True + t = Thread(target=s.run, daemon=True) t.start() results = [] diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py new file mode 100644 index 000000000..de3f6faf6 --- /dev/null +++ b/tests/pool/test_sched_async.py @@ -0,0 +1,154 @@ +import asyncio +import logging +from time import time +from functools import partial + +import pytest + +from psycopg3.pool.sched import AsyncScheduler +from psycopg3.utils.compat import create_task + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.slow +async def test_sched(): + s = AsyncScheduler() + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + await s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +async def test_sched_task(): + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, partial(worker, 3)) + await s.enter(0.3, None) + await s.enter(0.2, partial(worker, 2)) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.1) + + +@pytest.mark.slow +async def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg3") + s = AsyncScheduler() + t = create_task(s.run()) + + results = [] + + async def worker(i): + results.append((i, time())) + + async def error(): + 1 / 0 + + t0 = time() + await s.enter(0.1, partial(worker, 1)) + await s.enter(0.4, None) + await s.enter(0.3, partial(worker, 2)) + await s.enter(0.2, error) + + await asyncio.gather(t) + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.1) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.1) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.1) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +async def test_empty_queue_timeout(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + s._event.wait = wait_logging + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = create_task(s.run()) + await asyncio.sleep(0.5) + await s.enter(0.5, None) + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.1), times + + +@pytest.mark.slow +async def test_first_task_rescheduling(): + s = AsyncScheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + async def wait_logging(): + try: + rv = await wait_orig() + finally: + times.append(time() - t0) + return rv + + s._event.wait = wait_logging + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + async def noop(): + pass + + await s.enter(0.4, noop) + t = create_task(s.run()) + await s.enter(0.6, None) # this task doesn't trigger a reschedule + await asyncio.sleep(0.1) + await s.enter(0.1, noop) # this triggers a reschedule + await asyncio.gather(t) + times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.1), times -- 2.47.2