]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): generate test_sched from async counterpart
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 Sep 2023 17:50:44 +0000 (18:50 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/pool/test_sched.py
tests/pool/test_sched_async.py
tests/utils.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

index b58be2f81873e697b7e585c9518f1ffa84757ad9..2639dc1f8c7825f86d3be329bbd7fc25be50aa2e 100644 (file)
@@ -1,10 +1,15 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_sched_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import logging
-from time import time, sleep
+from time import time
 from functools import partial
-from threading import Thread
+from contextlib import contextmanager
 
 import pytest
 
+from ..utils import spawn, gather, sleep
+
 try:
     from psycopg_pool.sched import Scheduler
 except ImportError:
@@ -15,7 +20,6 @@ pytestmark = [pytest.mark.timing]
 
 
 @pytest.mark.slow
-@pytest.mark.timing
 def test_sched():
     s = Scheduler()
     results = []
@@ -37,11 +41,9 @@ def test_sched():
 
 
 @pytest.mark.slow
-@pytest.mark.timing
-def test_sched_thread():
+def test_sched_task():
     s = Scheduler()
-    t = Thread(target=s.run, daemon=True)
-    t.start()
+    t = spawn(s.run)
 
     results = []
 
@@ -54,7 +56,7 @@ def test_sched_thread():
     s.enter(0.3, None)
     s.enter(0.2, partial(worker, 2))
 
-    t.join()
+    gather(t)
     t1 = time()
     assert t1 - t0 == pytest.approx(0.3, 0.2)
 
@@ -66,12 +68,10 @@ def test_sched_thread():
 
 
 @pytest.mark.slow
-@pytest.mark.timing
 def test_sched_error(caplog):
     caplog.set_level(logging.WARNING, logger="psycopg")
     s = Scheduler()
-    t = Thread(target=s.run, daemon=True)
-    t.start()
+    t = spawn(s.run)
 
     results = []
 
@@ -87,7 +87,7 @@ def test_sched_error(caplog):
     s.enter(0.3, partial(worker, 2))
     s.enter(0.2, error)
 
-    t.join()
+    gather(t)
     t1 = time()
     assert t1 - t0 == pytest.approx(0.4, 0.1)
 
@@ -105,25 +105,14 @@ def test_sched_error(caplog):
 def test_empty_queue_timeout():
     s = Scheduler()
 
-    t0 = time()
-    times = []
+    with timed_wait(s) as times:
+        s.EMPTY_QUEUE_TIMEOUT = 0.2
 
-    wait_orig = s._event.wait
-
-    def wait_logging(timeout=None):
-        rv = wait_orig(timeout)
-        times.append(time() - t0)
-        return rv
+        t = spawn(s.run)
+        sleep(0.5)
+        s.enter(0.5, None)
+        gather(t)
 
-    setattr(s._event, "wait", wait_logging)
-    s.EMPTY_QUEUE_TIMEOUT = 0.2
-
-    t = Thread(target=s.run)
-    t.start()
-    sleep(0.5)
-    s.enter(0.5, None)
-    t.join()
-    times.append(time() - t0)
     for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
         assert got == pytest.approx(want, 0.2), times
 
@@ -132,26 +121,47 @@ def test_empty_queue_timeout():
 def test_first_task_rescheduling():
     s = Scheduler()
 
+    with timed_wait(s) as times:
+        s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+        s.enter(0.4, noop)
+        t = spawn(s.run)
+        s.enter(0.6, None)  # this task doesn't trigger a reschedule
+        sleep(0.1)
+        s.enter(0.1, noop)  # this triggers a reschedule
+        gather(t)
+
+    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+        assert got == pytest.approx(want, 0.2), times
+
+
+@contextmanager
+def timed_wait(s):
+    """
+    Hack the scheduler's Event.wait() function in order to log waited time.
+
+    The context is a list where the times are accumulated.
+    """
     t0 = time()
     times = []
 
     wait_orig = s._event.wait
 
     def wait_logging(timeout=None):
-        rv = wait_orig(timeout)
-        times.append(time() - t0)
+        args = (timeout,)
+
+        try:
+            rv = wait_orig(*args)
+        finally:
+            times.append(time() - t0)
         return rv
 
     setattr(s._event, "wait", wait_logging)
-    s.EMPTY_QUEUE_TIMEOUT = 0.1
-
-    s.enter(0.4, lambda: None)
-    t = Thread(target=s.run)
-    t.start()
-    s.enter(0.6, None)  # this task doesn't trigger a reschedule
-    sleep(0.1)
-    s.enter(0.1, lambda: None)  # this triggers a reschedule
-    t.join()
+
+    yield times
+
     times.append(time() - t0)
-    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
-        assert got == pytest.approx(want, 0.2), times
+
+
+def noop():
+    pass
index 259c66383bfa36911a0e18d8df9bcee4ec367ba1..23c7cefbd914a18b3ea305676455238b9feb8cd7 100644 (file)
@@ -1,22 +1,24 @@
-import asyncio
 import logging
 from time import time
-from asyncio import create_task
 from functools import partial
+from contextlib import asynccontextmanager
 
 import pytest
 
+from ..utils import spawn, gather, asleep
+
 try:
     from psycopg_pool.sched_async import AsyncScheduler
 except ImportError:
     # Tests should have been skipped if the package is not available
     pass
 
-pytestmark = [pytest.mark.anyio, pytest.mark.timing]
+pytestmark = [pytest.mark.timing]
+if True:  # ASYNC:
+    pytestmark.append(pytest.mark.anyio)
 
 
 @pytest.mark.slow
-@pytest.mark.timing
 async def test_sched():
     s = AsyncScheduler()
     results = []
@@ -38,10 +40,9 @@ async def test_sched():
 
 
 @pytest.mark.slow
-@pytest.mark.timing
 async def test_sched_task():
     s = AsyncScheduler()
-    t = create_task(s.run())
+    t = spawn(s.run)
 
     results = []
 
@@ -54,7 +55,7 @@ async def test_sched_task():
     await s.enter(0.3, None)
     await s.enter(0.2, partial(worker, 2))
 
-    await asyncio.gather(t)
+    await gather(t)
     t1 = time()
     assert t1 - t0 == pytest.approx(0.3, 0.2)
 
@@ -66,11 +67,10 @@ async def test_sched_task():
 
 
 @pytest.mark.slow
-@pytest.mark.timing
 async def test_sched_error(caplog):
     caplog.set_level(logging.WARNING, logger="psycopg")
     s = AsyncScheduler()
-    t = create_task(s.run())
+    t = spawn(s.run)
 
     results = []
 
@@ -86,7 +86,7 @@ async def test_sched_error(caplog):
     await s.enter(0.3, partial(worker, 2))
     await s.enter(0.2, error)
 
-    await asyncio.gather(t)
+    await gather(t)
     t1 = time()
     assert t1 - t0 == pytest.approx(0.4, 0.1)
 
@@ -104,26 +104,14 @@ async def test_sched_error(caplog):
 async def test_empty_queue_timeout():
     s = AsyncScheduler()
 
-    t0 = time()
-    times = []
+    async with timed_wait(s) as times:
+        s.EMPTY_QUEUE_TIMEOUT = 0.2
 
-    wait_orig = s._event.wait
+        t = spawn(s.run)
+        await asleep(0.5)
+        await s.enter(0.5, None)
+        await gather(t)
 
-    async def wait_logging():
-        try:
-            rv = await wait_orig()
-        finally:
-            times.append(time() - t0)
-        return rv
-
-    setattr(s._event, "wait", wait_logging)
-    s.EMPTY_QUEUE_TIMEOUT = 0.2
-
-    t = create_task(s.run())
-    await asyncio.sleep(0.5)
-    await s.enter(0.5, None)
-    await asyncio.gather(t)
-    times.append(time() - t0)
     for got, want in zip(times, [0.2, 0.4, 0.5, 1.0]):
         assert got == pytest.approx(want, 0.2), times
 
@@ -132,30 +120,50 @@ async def test_empty_queue_timeout():
 async def test_first_task_rescheduling():
     s = AsyncScheduler()
 
+    async with timed_wait(s) as times:
+        s.EMPTY_QUEUE_TIMEOUT = 0.1
+
+        await s.enter(0.4, noop)
+        t = spawn(s.run)
+        await s.enter(0.6, None)  # this task doesn't trigger a reschedule
+        await asleep(0.1)
+        await s.enter(0.1, noop)  # this triggers a reschedule
+        await gather(t)
+
+    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
+        assert got == pytest.approx(want, 0.2), times
+
+
+@asynccontextmanager
+async def timed_wait(s):
+    """
+    Hack the scheduler's Event.wait() function in order to log waited time.
+
+    The context is a list where the times are accumulated.
+    """
     t0 = time()
     times = []
 
     wait_orig = s._event.wait
 
-    async def wait_logging():
+    async def wait_logging(timeout=None):
+        if True:  # ASYNC
+            args = ()
+        else:
+            args = (timeout,)
+
         try:
-            rv = await wait_orig()
+            rv = await wait_orig(*args)
         finally:
             times.append(time() - t0)
         return rv
 
     setattr(s._event, "wait", wait_logging)
-    s.EMPTY_QUEUE_TIMEOUT = 0.1
 
-    async def noop():
-        pass
+    yield times
 
-    await s.enter(0.4, noop)
-    t = create_task(s.run())
-    await s.enter(0.6, None)  # this task doesn't trigger a reschedule
-    await asyncio.sleep(0.1)
-    await s.enter(0.1, noop)  # this triggers a reschedule
-    await asyncio.gather(t)
     times.append(time() - t0)
-    for got, want in zip(times, [0.1, 0.2, 0.4, 0.6, 0.6]):
-        assert got == pytest.approx(want, 0.2), times
+
+
+async def noop():
+    pass
index 64d72fe7faa21a1cd621fe529742eb9fd740b9a6..5b9b73c1a40d08de8b43899f2180ca0682f0f42e 100644 (file)
@@ -1,10 +1,14 @@
 import gc
 import re
 import sys
+import asyncio
+import inspect
 import operator
+from time import sleep as sleep  # noqa: F401 -- re-export
 from typing import Callable, Optional, Tuple
+from threading import Thread
 from contextlib import contextmanager, asynccontextmanager
-from contextlib import closing as closing  # noqa: F401  - re-export
+from contextlib import closing as closing  # noqa: F401 -- re-export
 
 import pytest
 
@@ -226,3 +230,33 @@ def raiseif(cond, *args, **kwargs):
         with pytest.raises(*args, **kwargs) as ex:
             yield ex
         return
+
+
+def spawn(f):
+    """
+    Equivalent to asyncio.create_task or creating and running a Thread.
+    """
+    if inspect.iscoroutinefunction(f):
+        return asyncio.create_task(f())
+    else:
+        t = Thread(target=f, daemon=True)
+        t.start()
+        return t
+
+
+def gather(*ts):
+    """
+    Equivalent to asyncio.gather or Thread.join()
+    """
+    if ts and inspect.isawaitable(ts[0]):
+        return asyncio.gather(*ts)
+    else:
+        for t in ts:
+            t.join()
+
+
+def asleep(s):
+    """
+    Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync.
+    """
+    return asyncio.sleep(s)
index ab2c4fb8e9cf52046a4ea93a92e96c38c1fc9073..40ed2ccc6b5b308f3ac0f1136d6ec7f9e79d1e14 100755 (executable)
@@ -179,12 +179,14 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "aconn_cls": "conn_cls",
         "alist": "list",
         "anext": "next",
+        "asleep": "sleep",
         "apipeline": "pipeline",
         "asynccontextmanager": "contextmanager",
         "connection_async": "connection",
         "cursor_async": "cursor",
         "ensure_table_async": "ensure_table",
         "find_insert_problem_async": "find_insert_problem",
+        "psycopg_pool.sched_async": "psycopg_pool.sched",
         "wait_async": "wait",
         "wait_conn_async": "wait_conn",
     }
index bea2963a86978f46f30d8f558e0af1a0d57cf92c..983d9c8e40ba9687b13564a35f41315feb9e55f1 100755 (executable)
@@ -21,6 +21,7 @@ for async in \
     psycopg/psycopg/connection_async.py \
     psycopg/psycopg/cursor_async.py \
     psycopg_pool/psycopg_pool/sched_async.py \
+    tests/pool/test_sched_async.py \
     tests/test_client_cursor_async.py \
     tests/test_connection_async.py \
     tests/test_copy_async.py \