]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add scheduler class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 20 Feb 2021 23:07:10 +0000 (00:07 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
To be used in the connection pool to implement reconnection with
backoff, maybe periodic tasks.

psycopg3/psycopg3/_sched.py [new file with mode: 0644]
tests/test_sched.py [new file with mode: 0644]

diff --git a/psycopg3/psycopg3/_sched.py b/psycopg3/psycopg3/_sched.py
new file mode 100644 (file)
index 0000000..040a70c
--- /dev/null
@@ -0,0 +1,109 @@
+"""
+A minimal scheduler to schedule tasks run in the future.
+
+Inspired to the standard library `sched.scheduler`, but designed for
+multi-thread usage ground up, not as an afterthought. Tasks can be scheduled in
+front of the one currently running and `Scheduler.run()` can be left running
+without any task scheduled.
+
+Tasks are called "Task", not "Event", here, because we actually make use of
+`threading.Event` and the two would be confusing.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+import threading
+from time import monotonic
+from heapq import heappush, heappop
+from typing import Any, Callable, List, Optional, NamedTuple
+
+logger = logging.getLogger("psycopg3.sched")
+
+
+class Task(NamedTuple):
+    time: float
+    action: Optional[Callable[[], Any]]
+
+    def __eq__(self, other: "Task") -> Any:  # type: ignore[override]
+        return self.time == other.time
+
+    def __lt__(self, other: "Task") -> Any:  # type: ignore[override]
+        return self.time < other.time
+
+    def __le__(self, other: "Task") -> Any:  # type: ignore[override]
+        return self.time <= other.time
+
+    def __gt__(self, other: "Task") -> Any:  # type: ignore[override]
+        return self.time > other.time
+
+    def __ge__(self, other: "Task") -> Any:  # type: ignore[override]
+        return self.time >= other.time
+
+
+class Scheduler:
+    def __init__(self) -> None:
+        """Initialize a new instance, passing the time and delay functions."""
+        self._queue: List[Task] = []
+        self._lock = threading.RLock()
+        self._event = threading.Event()
+
+    EMPTY_QUEUE_TIMEOUT = 600.0
+
+    def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
+        """Enter a new task in the queue delayed in the future.
+
+        Schedule a `!None` to stop the execution.
+        """
+        time = monotonic() + delay
+        return self.enterabs(time, action)
+
+    def enterabs(
+        self, time: float, action: Optional[Callable[[], Any]]
+    ) -> Task:
+        """Enter a new task in the queue at an absolute time.
+
+        Schedule a `!None` to stop the execution.
+        """
+        task = Task(time, action)
+        with self._lock:
+            heappush(self._queue, task)
+            first = self._queue[0] is task
+
+        if first:
+            self._event.set()
+
+        return task
+
+    def run(self) -> None:
+        """Execute the events scheduled."""
+        q = self._queue
+        while True:
+            with self._lock:
+                now = monotonic()
+                task = q[0] if q else None
+                if task:
+                    if task.time <= now:
+                        heappop(q)
+                    else:
+                        delay = task.time - now
+                        task = None
+                else:
+                    delay = self.EMPTY_QUEUE_TIMEOUT
+                self._event.clear()
+
+            if task:
+                if not task.action:
+                    break
+                try:
+                    task.action()
+                except Exception as e:
+                    logger.warning(
+                        "scheduled task run %s failed: %s: %s",
+                        task.action,
+                        e.__class__.__name__,
+                        e,
+                    )
+            else:
+                # Block for the expected timeout or until a new task scheduled
+                self._event.wait(timeout=delay)
diff --git a/tests/test_sched.py b/tests/test_sched.py
new file mode 100644 (file)
index 0000000..4ee7caf
--- /dev/null
@@ -0,0 +1,150 @@
+import logging
+from time import time, sleep
+from functools import partial
+from threading import Thread
+
+import pytest
+
+from psycopg3._sched import Scheduler
+
+
+@pytest.mark.slow
+def test_sched():
+    s = Scheduler()
+    results = []
+
+    def worker(i):
+        results.append((i, time()))
+
+    t0 = time()
+    s.enter(0.1, partial(worker, 1))
+    s.enter(0.4, partial(worker, 3))
+    s.enter(0.3, None)
+    s.enter(0.2, partial(worker, 2))
+    s.run()
+    assert len(results) == 2
+    assert results[0][0] == 1
+    assert results[0][1] - t0 == pytest.approx(0.1, 0.01)
+    assert results[1][0] == 2
+    assert results[1][1] - t0 == pytest.approx(0.2, 0.01)
+
+
+@pytest.mark.slow
+def test_sched_thread():
+    s = Scheduler()
+    t = Thread(target=s.run)
+    t.daemon = True
+    t.start()
+
+    results = []
+
+    def worker(i):
+        results.append((i, time()))
+
+    t0 = time()
+    s.enter(0.1, partial(worker, 1))
+    s.enter(0.4, partial(worker, 3))
+    s.enter(0.3, None)
+    s.enter(0.2, partial(worker, 2))
+
+    t.join()
+    t1 = time()
+    assert t1 - t0 == pytest.approx(0.3, 0.01)
+
+    assert len(results) == 2
+    assert results[0][0] == 1
+    assert results[0][1] - t0 == pytest.approx(0.1, 0.01)
+    assert results[1][0] == 2
+    assert results[1][1] - t0 == pytest.approx(0.2, 0.01)
+
+
+@pytest.mark.slow
+def test_sched_error(caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3")
+    s = Scheduler()
+    t = Thread(target=s.run)
+    t.daemon = True
+    t.start()
+
+    results = []
+
+    def worker(i):
+        results.append((i, time()))
+
+    def error():
+        1 / 0
+
+    t0 = time()
+    s.enter(0.1, partial(worker, 1))
+    s.enter(0.4, None)
+    s.enter(0.3, partial(worker, 2))
+    s.enter(0.2, error)
+
+    t.join()
+    t1 = time()
+    assert t1 - t0 == pytest.approx(0.4, 0.01)
+
+    assert len(results) == 2
+    assert results[0][0] == 1
+    assert results[0][1] - t0 == pytest.approx(0.1, 0.01)
+    assert results[1][0] == 2
+    assert results[1][1] - t0 == pytest.approx(0.3, 0.01)
+
+    assert len(caplog.records) == 1
+    assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_empty_queue_timeout():
+    s = Scheduler()
+
+    t0 = time()
+    times = []
+
+    wait_orig = s._event.wait
+
+    def wait_logging(timeout=None):
+        rv = wait_orig(timeout)
+        times.append(time() - t0)
+        return rv
+
+    s._event.wait = wait_logging
+    s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+    t = Thread(target=s.run)
+    t.start()
+    sleep(0.5)
+    s.enter(0.5, None)
+    t.join()
+    times.append(time() - t0)
+    for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+        assert got == pytest.approx(want, 0.01)
+
+
+@pytest.mark.slow
+def test_first_task_rescheduling():
+    s = Scheduler()
+
+    t0 = time()
+    times = []
+
+    wait_orig = s._event.wait
+
+    def wait_logging(timeout=None):
+        rv = wait_orig(timeout)
+        times.append(time() - t0)
+        return rv
+
+    s._event.wait = wait_logging
+    s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+    s.enter(0.4, lambda: None)
+    t = Thread(target=s.run)
+    t.start()
+    s.enter(0.6, None)  # this task doesn't trigger a reschedule
+    sleep(0.1)
+    s.enter(0.1, lambda: None)  # this triggers a reschedule
+    t.join()
+    times.append(time() - t0)
+    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+        assert got == pytest.approx(want, 0.01)