--- /dev/null
+"""
+A minimal scheduler to schedule tasks run in the future.
+
+Inspired to the standard library `sched.scheduler`, but designed for
+multi-thread usage ground up, not as an afterthought. Tasks can be scheduled in
+front of the one currently running and `Scheduler.run()` can be left running
+without any task scheduled.
+
+Tasks are called "Task", not "Event", here, because we actually make use of
+`threading.Event` and the two would be confusing.
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import logging
+import threading
+from time import monotonic
+from heapq import heappush, heappop
+from typing import Any, Callable, List, Optional, NamedTuple
+
+logger = logging.getLogger("psycopg3.sched")
+
+
+class Task(NamedTuple):
+ time: float
+ action: Optional[Callable[[], Any]]
+
+ def __eq__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time == other.time
+
+ def __lt__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time < other.time
+
+ def __le__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time <= other.time
+
+ def __gt__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time > other.time
+
+ def __ge__(self, other: "Task") -> Any: # type: ignore[override]
+ return self.time >= other.time
+
+
+class Scheduler:
+ def __init__(self) -> None:
+ """Initialize a new instance, passing the time and delay functions."""
+ self._queue: List[Task] = []
+ self._lock = threading.RLock()
+ self._event = threading.Event()
+
+ EMPTY_QUEUE_TIMEOUT = 600.0
+
+ def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
+ """Enter a new task in the queue delayed in the future.
+
+ Schedule a `!None` to stop the execution.
+ """
+ time = monotonic() + delay
+ return self.enterabs(time, action)
+
+ def enterabs(
+ self, time: float, action: Optional[Callable[[], Any]]
+ ) -> Task:
+ """Enter a new task in the queue at an absolute time.
+
+ Schedule a `!None` to stop the execution.
+ """
+ task = Task(time, action)
+ with self._lock:
+ heappush(self._queue, task)
+ first = self._queue[0] is task
+
+ if first:
+ self._event.set()
+
+ return task
+
+ def run(self) -> None:
+ """Execute the events scheduled."""
+ q = self._queue
+ while True:
+ with self._lock:
+ now = monotonic()
+ task = q[0] if q else None
+ if task:
+ if task.time <= now:
+ heappop(q)
+ else:
+ delay = task.time - now
+ task = None
+ else:
+ delay = self.EMPTY_QUEUE_TIMEOUT
+ self._event.clear()
+
+ if task:
+ if not task.action:
+ break
+ try:
+ task.action()
+ except Exception as e:
+ logger.warning(
+ "scheduled task run %s failed: %s: %s",
+ task.action,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+ # Block for the expected timeout or until a new task scheduled
+ self._event.wait(timeout=delay)
--- /dev/null
+import logging
+from time import time, sleep
+from functools import partial
+from threading import Thread
+
+import pytest
+
+from psycopg3._sched import Scheduler
+
+
+@pytest.mark.slow
+def test_sched():
+ s = Scheduler()
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, partial(worker, 3))
+ s.enter(0.3, None)
+ s.enter(0.2, partial(worker, 2))
+ s.run()
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.01)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.01)
+
+
+@pytest.mark.slow
+def test_sched_thread():
+ s = Scheduler()
+ t = Thread(target=s.run)
+ t.daemon = True
+ t.start()
+
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, partial(worker, 3))
+ s.enter(0.3, None)
+ s.enter(0.2, partial(worker, 2))
+
+ t.join()
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.3, 0.01)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.01)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.01)
+
+
+@pytest.mark.slow
+def test_sched_error(caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3")
+ s = Scheduler()
+ t = Thread(target=s.run)
+ t.daemon = True
+ t.start()
+
+ results = []
+
+ def worker(i):
+ results.append((i, time()))
+
+ def error():
+ 1 / 0
+
+ t0 = time()
+ s.enter(0.1, partial(worker, 1))
+ s.enter(0.4, None)
+ s.enter(0.3, partial(worker, 2))
+ s.enter(0.2, error)
+
+ t.join()
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.4, 0.01)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.01)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.3, 0.01)
+
+ assert len(caplog.records) == 1
+ assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+def test_empty_queue_timeout():
+ s = Scheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ def wait_logging(timeout=None):
+ rv = wait_orig(timeout)
+ times.append(time() - t0)
+ return rv
+
+ s._event.wait = wait_logging
+ s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+ t = Thread(target=s.run)
+ t.start()
+ sleep(0.5)
+ s.enter(0.5, None)
+ t.join()
+ times.append(time() - t0)
+ for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+ assert got == pytest.approx(want, 0.01)
+
+
+@pytest.mark.slow
+def test_first_task_rescheduling():
+ s = Scheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ def wait_logging(timeout=None):
+ rv = wait_orig(timeout)
+ times.append(time() - t0)
+ return rv
+
+ s._event.wait = wait_logging
+ s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+ s.enter(0.4, lambda: None)
+ t = Thread(target=s.run)
+ t.start()
+ s.enter(0.6, None) # this task doesn't trigger a reschedule
+ sleep(0.1)
+ s.enter(0.1, lambda: None) # this triggers a reschedule
+ t.join()
+ times.append(time() - t0)
+ for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+ assert got == pytest.approx(want, 0.01)