]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add an async scheduler
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 27 Feb 2021 14:28:18 +0000 (15:28 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/sched.py
tests/pool/test_sched.py
tests/pool/test_sched_async.py [new file with mode: 0644]

index 7c54b12f70a18e7ebf55cfe14a692aa657719f1c..95af07532c4ac41cf16ac45d01c5a5aa16fe3250 100644 (file)
@@ -12,6 +12,7 @@ Tasks are called "Task", not "Event", here, because we actually make use of
 
 # Copyright (C) 2021 The Psycopg Team
 
+import asyncio
 import logging
 import threading
 from time import monotonic
@@ -107,3 +108,77 @@ class Scheduler:
             else:
                 # Block for the expected timeout or until a new task scheduled
                 self._event.wait(timeout=delay)
+
+
+class AsyncScheduler:
+    def __init__(self) -> None:
+        """Initialize a new instance, passing the time and delay functions."""
+        self._queue: List[Task] = []
+        self._lock = asyncio.Lock()
+        self._event = asyncio.Event()
+
+    EMPTY_QUEUE_TIMEOUT = 600.0
+
+    async def enter(
+        self, delay: float, action: Optional[Callable[[], Any]]
+    ) -> Task:
+        """Enter a new task in the queue delayed in the future.
+
+        Schedule a `!None` to stop the execution.
+        """
+        time = monotonic() + delay
+        return await self.enterabs(time, action)
+
+    async def enterabs(
+        self, time: float, action: Optional[Callable[[], Any]]
+    ) -> Task:
+        """Enter a new task in the queue at an absolute time.
+
+        Schedule a `!None` to stop the execution.
+        """
+        task = Task(time, action)
+        async with self._lock:
+            heappush(self._queue, task)
+            first = self._queue[0] is task
+
+        if first:
+            self._event.set()
+
+        return task
+
+    async def run(self) -> None:
+        """Execute the events scheduled."""
+        q = self._queue
+        while True:
+            async with self._lock:
+                now = monotonic()
+                task = q[0] if q else None
+                if task:
+                    if task.time <= now:
+                        heappop(q)
+                    else:
+                        delay = task.time - now
+                        task = None
+                else:
+                    delay = self.EMPTY_QUEUE_TIMEOUT
+                self._event.clear()
+
+            if task:
+                # logger.info("task %s action %s", task, task.action)
+                if not task.action:
+                    break
+                try:
+                    await task.action()
+                except Exception as e:
+                    logger.warning(
+                        "scheduled task run %s failed: %s: %s",
+                        task.action,
+                        e.__class__.__name__,
+                        e,
+                    )
+            else:
+                # Block for the expected timeout or until a new task scheduled
+                try:
+                    await asyncio.wait_for(self._event.wait(), delay)
+                except asyncio.TimeoutError:
+                    pass
index 3b227abba2acb6cb284f0420f9ad01640866192d..3ca6cb1f08a779df62ac9d3a5d2663b0dd6fced5 100644 (file)
@@ -32,8 +32,7 @@ def test_sched():
 @pytest.mark.slow
 def test_sched_thread():
     s = Scheduler()
-    t = Thread(target=s.run)
-    t.daemon = True
+    t = Thread(target=s.run, daemon=True)
     t.start()
 
     results = []
@@ -62,8 +61,7 @@ def test_sched_thread():
 def test_sched_error(caplog):
     caplog.set_level(logging.WARNING, logger="psycopg3")
     s = Scheduler()
-    t = Thread(target=s.run)
-    t.daemon = True
+    t = Thread(target=s.run, daemon=True)
     t.start()
 
     results = []
diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py
new file mode 100644 (file)
index 0000000..de3f6fa
--- /dev/null
@@ -0,0 +1,154 @@
+import asyncio
+import logging
+from time import time
+from functools import partial
+
+import pytest
+
+from psycopg3.pool.sched import AsyncScheduler
+from psycopg3.utils.compat import create_task
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.slow
+async def test_sched():
+    s = AsyncScheduler()
+    results = []
+
+    async def worker(i):
+        results.append((i, time()))
+
+    t0 = time()
+    await s.enter(0.1, partial(worker, 1))
+    await s.enter(0.4, partial(worker, 3))
+    await s.enter(0.3, None)
+    await s.enter(0.2, partial(worker, 2))
+    await s.run()
+    assert len(results) == 2
+    assert results[0][0] == 1
+    assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+    assert results[1][0] == 2
+    assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+async def test_sched_task():
+    s = AsyncScheduler()
+    t = create_task(s.run())
+
+    results = []
+
+    async def worker(i):
+        results.append((i, time()))
+
+    t0 = time()
+    await s.enter(0.1, partial(worker, 1))
+    await s.enter(0.4, partial(worker, 3))
+    await s.enter(0.3, None)
+    await s.enter(0.2, partial(worker, 2))
+
+    await asyncio.gather(t)
+    t1 = time()
+    assert t1 - t0 == pytest.approx(0.3, 0.1)
+
+    assert len(results) == 2
+    assert results[0][0] == 1
+    assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+    assert results[1][0] == 2
+    assert results[1][1] - t0 == pytest.approx(0.2, 0.1)
+
+
+@pytest.mark.slow
+async def test_sched_error(caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3")
+    s = AsyncScheduler()
+    t = create_task(s.run())
+
+    results = []
+
+    async def worker(i):
+        results.append((i, time()))
+
+    async def error():
+        1 / 0
+
+    t0 = time()
+    await s.enter(0.1, partial(worker, 1))
+    await s.enter(0.4, None)
+    await s.enter(0.3, partial(worker, 2))
+    await s.enter(0.2, error)
+
+    await asyncio.gather(t)
+    t1 = time()
+    assert t1 - t0 == pytest.approx(0.4, 0.1)
+
+    assert len(results) == 2
+    assert results[0][0] == 1
+    assert results[0][1] - t0 == pytest.approx(0.1, 0.1)
+    assert results[1][0] == 2
+    assert results[1][1] - t0 == pytest.approx(0.3, 0.1)
+
+    assert len(caplog.records) == 1
+    assert "ZeroDivisionError" in caplog.records[0].message
+
+
+@pytest.mark.slow
+async def test_empty_queue_timeout():
+    s = AsyncScheduler()
+
+    t0 = time()
+    times = []
+
+    wait_orig = s._event.wait
+
+    async def wait_logging():
+        try:
+            rv = await wait_orig()
+        finally:
+            times.append(time() - t0)
+        return rv
+
+    s._event.wait = wait_logging
+    s.EMPTY_QUEUE_TIMEOUT = 0.2
+
+    t = create_task(s.run())
+    await asyncio.sleep(0.5)
+    await s.enter(0.5, None)
+    await asyncio.gather(t)
+    times.append(time() - t0)
+    for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
+        assert got == pytest.approx(want, 0.1), times
+
+
+@pytest.mark.slow
+async def test_first_task_rescheduling():
+    s = AsyncScheduler()
+
+    t0 = time()
+    times = []
+
+    wait_orig = s._event.wait
+
+    async def wait_logging():
+        try:
+            rv = await wait_orig()
+        finally:
+            times.append(time() - t0)
+        return rv
+
+    s._event.wait = wait_logging
+    s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+    async def noop():
+        pass
+
+    await s.enter(0.4, noop)
+    t = create_task(s.run())
+    await s.enter(0.6, None)  # this task doesn't trigger a reschedule
+    await asyncio.sleep(0.1)
+    await s.enter(0.1, noop)  # this triggers a reschedule
+    await asyncio.gather(t)
+    times.append(time() - t0)
+    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+        assert got == pytest.approx(want, 0.1), times