From: Daniele Varrazzo Date: Sat, 20 Feb 2021 23:07:10 +0000 (+0100) Subject: Add scheduler class X-Git-Tag: 3.0.dev0~87^2~62 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=5dcc8bfd84e04e6ccd55a0ee8ed3a75a55ce7982;p=thirdparty%2Fpsycopg.git Add scheduler class To be used in the connection pool to implement reconnection with backoff, maybe periodic tasks. --- diff --git a/psycopg3/psycopg3/_sched.py b/psycopg3/psycopg3/_sched.py new file mode 100644 index 000000000..040a70c1f --- /dev/null +++ b/psycopg3/psycopg3/_sched.py @@ -0,0 +1,109 @@ +""" +A minimal scheduler to schedule tasks run in the future. + +Inspired to the standard library `sched.scheduler`, but designed for +multi-thread usage ground up, not as an afterthought. Tasks can be scheduled in +front of the one currently running and `Scheduler.run()` can be left running +without any task scheduled. + +Tasks are called "Task", not "Event", here, because we actually make use of +`threading.Event` and the two would be confusing. +""" + +# Copyright (C) 2021 The Psycopg Team + +import logging +import threading +from time import monotonic +from heapq import heappush, heappop +from typing import Any, Callable, List, Optional, NamedTuple + +logger = logging.getLogger("psycopg3.sched") + + +class Task(NamedTuple): + time: float + action: Optional[Callable[[], Any]] + + def __eq__(self, other: "Task") -> Any: # type: ignore[override] + return self.time == other.time + + def __lt__(self, other: "Task") -> Any: # type: ignore[override] + return self.time < other.time + + def __le__(self, other: "Task") -> Any: # type: ignore[override] + return self.time <= other.time + + def __gt__(self, other: "Task") -> Any: # type: ignore[override] + return self.time > other.time + + def __ge__(self, other: "Task") -> Any: # type: ignore[override] + return self.time >= other.time + + +class Scheduler: + def __init__(self) -> None: + """Initialize a new instance, passing the time and delay functions.""" + self._queue: List[Task] = [] + self._lock = threading.RLock() + self._event = threading.Event() + + EMPTY_QUEUE_TIMEOUT = 600.0 + + def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task: + """Enter a new task in the queue delayed in the future. + + Schedule a `!None` to stop the execution. + """ + time = monotonic() + delay + return self.enterabs(time, action) + + def enterabs( + self, time: float, action: Optional[Callable[[], Any]] + ) -> Task: + """Enter a new task in the queue at an absolute time. + + Schedule a `!None` to stop the execution. + """ + task = Task(time, action) + with self._lock: + heappush(self._queue, task) + first = self._queue[0] is task + + if first: + self._event.set() + + return task + + def run(self) -> None: + """Execute the events scheduled.""" + q = self._queue + while True: + with self._lock: + now = monotonic() + task = q[0] if q else None + if task: + if task.time <= now: + heappop(q) + else: + delay = task.time - now + task = None + else: + delay = self.EMPTY_QUEUE_TIMEOUT + self._event.clear() + + if task: + if not task.action: + break + try: + task.action() + except Exception as e: + logger.warning( + "scheduled task run %s failed: %s: %s", + task.action, + e.__class__.__name__, + e, + ) + else: + # Block for the expected timeout or until a new task scheduled + self._event.wait(timeout=delay) diff --git a/tests/test_sched.py b/tests/test_sched.py new file mode 100644 index 000000000..4ee7cafa6 --- /dev/null +++ b/tests/test_sched.py @@ -0,0 +1,150 @@ +import logging +from time import time, sleep +from functools import partial +from threading import Thread + +import pytest + +from psycopg3._sched import Scheduler + + +@pytest.mark.slow +def test_sched(): + s = Scheduler() + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + s.run() + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.01) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.01) + + +@pytest.mark.slow +def test_sched_thread(): + s = Scheduler() + t = Thread(target=s.run) + t.daemon = True + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, partial(worker, 3)) + s.enter(0.3, None) + s.enter(0.2, partial(worker, 2)) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.3, 0.01) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.01) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.2, 0.01) + + +@pytest.mark.slow +def test_sched_error(caplog): + caplog.set_level(logging.WARNING, logger="psycopg3") + s = Scheduler() + t = Thread(target=s.run) + t.daemon = True + t.start() + + results = [] + + def worker(i): + results.append((i, time())) + + def error(): + 1 / 0 + + t0 = time() + s.enter(0.1, partial(worker, 1)) + s.enter(0.4, None) + s.enter(0.3, partial(worker, 2)) + s.enter(0.2, error) + + t.join() + t1 = time() + assert t1 - t0 == pytest.approx(0.4, 0.01) + + assert len(results) == 2 + assert results[0][0] == 1 + assert results[0][1] - t0 == pytest.approx(0.1, 0.01) + assert results[1][0] == 2 + assert results[1][1] - t0 == pytest.approx(0.3, 0.01) + + assert len(caplog.records) == 1 + assert "ZeroDivisionError" in caplog.records[0].message + + +@pytest.mark.slow +def test_empty_queue_timeout(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + s._event.wait = wait_logging + s.EMPTY_QUEUE_TIMEOUT = 0.2 + + t = Thread(target=s.run) + t.start() + sleep(0.5) + s.enter(0.5, None) + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]): + assert got == pytest.approx(want, 0.01) + + +@pytest.mark.slow +def test_first_task_rescheduling(): + s = Scheduler() + + t0 = time() + times = [] + + wait_orig = s._event.wait + + def wait_logging(timeout=None): + rv = wait_orig(timeout) + times.append(time() - t0) + return rv + + s._event.wait = wait_logging + s.EMPTY_QUEUE_TIMEOUT = 0.1 + + s.enter(0.4, lambda: None) + t = Thread(target=s.run) + t.start() + s.enter(0.6, None) # this task doesn't trigger a reschedule + sleep(0.1) + s.enter(0.1, lambda: None) # this triggers a reschedule + t.join() + times.append(time() - t0) + for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]): + assert got == pytest.approx(want, 0.01)