# Copyright (C) 2021 The Psycopg Team
+import asyncio
import logging
import threading
from time import monotonic
else:
# Block for the expected timeout or until a new task scheduled
self._event.wait(timeout=delay)
+
+
+class AsyncScheduler:
+ def __init__(self) -> None:
+ """Initialize a new instance, passing the time and delay functions."""
+ self._queue: List[Task] = []
+ self._lock = asyncio.Lock()
+ self._event = asyncio.Event()
+
+ EMPTY_QUEUE_TIMEOUT = 600.0
+
+ async def enter(
+ self, delay: float, action: Optional[Callable[[], Any]]
+ ) -> Task:
+ """Enter a new task in the queue delayed in the future.
+
+ Schedule a `!None` to stop the execution.
+ """
+ time = monotonic() + delay
+ return await self.enterabs(time, action)
+
+ async def enterabs(
+ self, time: float, action: Optional[Callable[[], Any]]
+ ) -> Task:
+ """Enter a new task in the queue at an absolute time.
+
+ Schedule a `!None` to stop the execution.
+ """
+ task = Task(time, action)
+ async with self._lock:
+ heappush(self._queue, task)
+ first = self._queue[0] is task
+
+ if first:
+ self._event.set()
+
+ return task
+
+ async def run(self) -> None:
+ """Execute the events scheduled."""
+ q = self._queue
+ while True:
+ async with self._lock:
+ now = monotonic()
+ task = q[0] if q else None
+ if task:
+ if task.time <= now:
+ heappop(q)
+ else:
+ delay = task.time - now
+ task = None
+ else:
+ delay = self.EMPTY_QUEUE_TIMEOUT
+ self._event.clear()
+
+ if task:
+ # logger.info("task %s action %s", task, task.action)
+ if not task.action:
+ break
+ try:
+ await task.action()
+ except Exception as e:
+ logger.warning(
+ "scheduled task run %s failed: %s: %s",
+ task.action,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+ # Block for the expected timeout or until a new task scheduled
+ try:
+ await asyncio.wait_for(self._event.wait(), delay)
+ except asyncio.TimeoutError:
+ pass
--- /dev/null
+import asyncio
+import logging
+from time import time
+from functools import partial
+
+import pytest
+
+from psycopg3.pool.sched import AsyncScheduler
+from psycopg3.utils.compat import create_task
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.slow
+async def test_sched():
+ s = AsyncScheduler()
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, partial(worker, 3))
+ await s.enter(0.3, None)
+ await s.enter(0.2, partial(worker, 2))
+ await s.run()
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+async def test_sched_task():
+ s = AsyncScheduler()
+ t = create_task(s.run())
+
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, partial(worker, 3))
+ await s.enter(0.3, None)
+ await s.enter(0.2, partial(worker, 2))
+
+ await asyncio.gather(t)
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.3, 0.1)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+async def test_sched_error(caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3")
+ s = AsyncScheduler()
+ t = create_task(s.run())
+
+ results = []
+
+ async def worker(i):
+ results.append((i, time()))
+
+ async def error():
+ 1 / 0
+
+ t0 = time()
+ await s.enter(0.1, partial(worker, 1))
+ await s.enter(0.4, None)
+ await s.enter(0.3, partial(worker, 2))
+ await s.enter(0.2, error)
+
+ await asyncio.gather(t)
+ t1 = time()
+ assert t1 - t0 == pytest.approx(0.4, 0.1)
+
+ assert len(results) == 2
+ assert results[0][0] == 1
+ assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+ assert results[1][0] == 2
+ assert results[1][1] - t0 == pytest.approx(0.3, 0.1)
+
+ assert len(caplog.records) == 1
+ assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_empty_queue_timeout():
+ s = AsyncScheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ async def wait_logging():
+ try:
+ rv = await wait_orig()
+ finally:
+ times.append(time() - t0)
+ return rv
+
+ s._event.wait = wait_logging
+ s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+ t = create_task(s.run())
+ await asyncio.sleep(0.5)
+ await s.enter(0.5, None)
+ await asyncio.gather(t)
+ times.append(time() - t0)
+ for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+ assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+async def test_first_task_rescheduling():
+ s = AsyncScheduler()
+
+ t0 = time()
+ times = []
+
+ wait_orig = s._event.wait
+
+ async def wait_logging():
+ try:
+ rv = await wait_orig()
+ finally:
+ times.append(time() - t0)
+ return rv
+
+ s._event.wait = wait_logging
+ s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+ async def noop():
+ pass
+
+ await s.enter(0.4, noop)
+ t = create_task(s.run())
+ await s.enter(0.6, None) # this task doesn't trigger a reschedule
+ await asyncio.sleep(0.1)
+ await s.enter(0.1, noop) # this triggers a reschedule
+ await asyncio.gather(t)
+ times.append(time() - t0)
+ for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+ assert got == pytest.approx(want, 0.1), times