]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): generate test_pool from async counterpart
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 Sep 2023 21:57:44 +0000 (22:57 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/pool/test_module.py [new file with mode: 0644]
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/utils.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

diff --git a/tests/pool/test_module.py b/tests/pool/test_module.py
new file mode 100644 (file)
index 0000000..31f77aa
--- /dev/null
@@ -0,0 +1,8 @@
+def test_version(mypy):
+    cp = mypy.run_on_source(
+        """\
+from psycopg_pool import __version__
+assert __version__
+"""
+    )
+    assert not cp.stdout
index 7234a3c84c7da753e53927e1f2a26a9e72a817f0..2032fab89c9d448324c396d2e5bf08d4fe514429 100644 (file)
@@ -1,7 +1,9 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_pool_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import logging
 import weakref
-from time import sleep, time
-from threading import Thread, Event
+from time import time
 from typing import Any, Dict, List, Tuple
 
 import pytest
@@ -9,7 +11,9 @@ import pytest
 import psycopg
 from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
-from psycopg._compat import Counter, assert_type
+from psycopg._compat import assert_type, Counter
+
+from ..utils import Event, spawn, gather, sleep, is_alive, is_async
 
 try:
     import psycopg_pool as pool
@@ -18,16 +22,6 @@ except ImportError:
     pass
 
 
-def test_package_version(mypy):
-    cp = mypy.run_on_source(
-        """\
-from psycopg_pool import __version__
-assert __version__
-"""
-    )
-    assert not cp.stdout
-
-
 def test_defaults(dsn):
     with pool.ConnectionPool(dsn) as p:
         assert p.min_size == p.max_size == 4
@@ -71,7 +65,7 @@ class MyRow(Dict[str, Any]):
 
 def test_generic_connection_type(dsn):
     def set_autocommit(conn: psycopg.Connection[Any]) -> None:
-        conn.autocommit = True
+        conn.set_autocommit(True)
 
     class MyConnection(psycopg.Connection[Row]):
         pass
@@ -81,8 +75,10 @@ def test_generic_connection_type(dsn):
         connection_class=MyConnection[MyRow],
         kwargs=dict(row_factory=class_row(MyRow)),
         configure=set_autocommit,
-    ) as p1, p1.connection() as conn1:
-        (row1,) = conn1.execute("select 1 as x").fetchall()
+    ) as p1:
+        with p1.connection() as conn1:
+            cur1 = conn1.execute("select 1 as x")
+            (row1,) = cur1.fetchall()
     assert_type(p1, pool.ConnectionPool[MyConnection[MyRow]])
     assert_type(conn1, MyConnection[MyRow])
     assert_type(row1, MyRow)
@@ -91,7 +87,8 @@ def test_generic_connection_type(dsn):
 
     with pool.ConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
         with p2.connection() as conn2:
-            (row2,) = conn2.execute("select 2 as y").fetchall()
+            cur2 = conn2.execute("select 2 as y")
+            (row2,) = cur2.fetchall()
     assert_type(p2, pool.ConnectionPool[MyConnection[TupleRow]])
     assert_type(conn2, MyConnection[TupleRow])
     assert_type(row2, TupleRow)
@@ -100,7 +97,7 @@ def test_generic_connection_type(dsn):
 
 def test_non_generic_connection_type(dsn):
     def set_autocommit(conn: psycopg.Connection[Any]) -> None:
-        conn.autocommit = True
+        conn.set_autocommit(True)
 
     class MyConnection(psycopg.Connection[MyRow]):
         def __init__(self, *args: Any, **kwargs: Any):
@@ -111,7 +108,8 @@ def test_non_generic_connection_type(dsn):
         dsn, connection_class=MyConnection, configure=set_autocommit
     ) as p1:
         with p1.connection() as conn1:
-            (row1,) = conn1.execute("select 1 as x").fetchall()
+            cur1 = conn1.execute("select 1 as x")
+            (row1,) = cur1.fetchall()
     assert_type(p1, pool.ConnectionPool[MyConnection])
     assert_type(conn1, MyConnection)
     assert_type(row1, MyRow)
@@ -222,7 +220,7 @@ def test_configure(dsn):
             conn.execute("set default_transaction_read_only to on")
 
     with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
-        p.wait()
+        p.wait(timeout=1.0)
         with p.connection() as conn:
             assert inits == 1
             res = conn.execute("show default_transaction_read_only")
@@ -293,8 +291,8 @@ def test_reset(dsn):
         assert resets == 1
 
         with p.connection() as conn:
-            with conn.execute("show timezone") as cur:
-                assert cur.fetchone() == ("UTC",)
+            cur = conn.execute("show timezone")
+            assert cur.fetchone() == ("UTC",)
 
         p.wait()
         assert resets == 2
@@ -358,18 +356,15 @@ def test_queue(dsn):
     results: List[Tuple[int, float, int]] = []
     with pool.ConnectionPool(dsn, min_size=2) as p:
         p.wait()
-        ts = [Thread(target=worker, args=(i,)) for i in range(6)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(i,)) for i in range(6)]
+        gather(*ts)
 
     times = [item[1] for item in results]
     want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
     for got, want in zip(times, want_times):
         assert got == pytest.approx(want, 0.1), times
 
-    assert len(set(r[2] for r in results)) == 2, results
+    assert len(set((r[2] for r in results))) == 2, results
 
 
 @pytest.mark.slow
@@ -391,15 +386,11 @@ def test_queue_size(dsn):
     with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p:
         p.wait()
         ev = Event()
-        t = Thread(target=worker, args=(0.3, ev))
-        t.start()
+        spawn(worker, args=(0.3, ev))
         ev.wait()
 
-        ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(0.1,)) for i in range(4)]
+        gather(*ts)
 
     assert len(success) == 4
     assert len(errors) == 1
@@ -430,11 +421,8 @@ def test_queue_timeout(dsn):
     errors: List[Tuple[int, float, Exception]] = []
 
     with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p:
-        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(i,)) for i in range(4)]
+        gather(*ts)
 
     assert len(results) == 2
     assert len(errors) == 2
@@ -454,17 +442,14 @@ def test_dead_client(dsn):
             if timeout > 0.2:
                 raise
 
-    results: List[int] = []
-
     with pool.ConnectionPool(dsn, min_size=2) as p:
+        results: List[int] = []
         ts = [
-            Thread(target=worker, args=(i, timeout))
-            for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+            spawn(worker, args=(i, timeout))
+            for (i, timeout) in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
         ]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        gather(*ts)
+
         sleep(0.2)
         assert set(results) == set([0, 1, 3, 4])
         assert len(p._pool) == 2  # no connection was lost
@@ -492,11 +477,8 @@ def test_queue_timeout_override(dsn):
     errors: List[Tuple[int, float, Exception]] = []
 
     with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p:
-        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(i,)) for i in range(4)]
+        gather(*ts)
 
     assert len(results) == 3
     assert len(errors) == 1
@@ -531,9 +513,10 @@ def test_intrans_rollback(dsn, caplog):
         with p.connection() as conn2:
             assert conn2.info.backend_pid == pid
             assert conn2.info.transaction_status == TransactionStatus.IDLE
-            assert not conn2.execute(
+            cur = conn2.execute(
                 "select 1 from pg_class where relname = 'test_intrans_rollback'"
-            ).fetchone()
+            )
+            assert not cur.fetchone()
 
     assert len(caplog.records) == 1
     assert "INTRANS" in caplog.records[0].message
@@ -611,19 +594,19 @@ def test_fail_rollback_close(dsn, caplog, monkeypatch):
     assert "BAD" in caplog.records[2].message
 
 
-def test_close_no_threads(dsn):
+def test_close_no_tasks(dsn):
     p = pool.ConnectionPool(dsn)
-    assert p._sched_runner and p._sched_runner.is_alive()
+    assert p._sched_runner and is_alive(p._sched_runner)
     workers = p._workers[:]
     assert workers
     for t in workers:
-        assert t.is_alive()
+        assert is_alive(t)
 
     p.close()
     assert p._sched_runner is None
     assert not p._workers
     for t in workers:
-        assert not t.is_alive()
+        assert not is_alive(t)
 
 
 def test_putconn_no_pool(conn_cls, dsn):
@@ -656,14 +639,15 @@ def test_del_no_warning(dsn, recwarn):
 
 
 @pytest.mark.slow
-def test_del_stop_threads(dsn):
+@pytest.mark.skipif(is_async(__name__), reason="sync test only")
+def test_del_stops_threads(dsn):
     p = pool.ConnectionPool(dsn)
     assert p._sched_runner is not None
     ts = [p._sched_runner] + p._workers
     del p
     sleep(0.1)
     for t in ts:
-        assert not t.is_alive()
+        assert not is_alive(t), t
 
 
 def test_closed_getconn(dsn):
@@ -715,22 +699,18 @@ def test_closed_queue(dsn):
     p.wait()
     success: List[str] = []
 
-    t1 = Thread(target=w1)
-    t1.start()
+    t1 = spawn(w1)
     # Wait until w1 has received a connection
     e1.wait()
 
-    t2 = Thread(target=w2)
-    t2.start()
+    t2 = spawn(w2)
     # Wait until w2 is in the queue
     ensure_waiting(p)
-
-    p.close(0)
+    p.close()
 
     # Wait for the workers to finish
     e2.set()
-    t1.join()
-    t2.join()
+    gather(t1, t2)
     assert len(success) == 2
 
 
@@ -740,7 +720,7 @@ def test_open_explicit(dsn):
     with pytest.raises(pool.PoolClosed, match="is not open yet"):
         p.getconn()
 
-    with pytest.raises(pool.PoolClosed):
+    with pytest.raises(pool.PoolClosed, match="is not open yet"):
         with p.connection():
             pass
 
@@ -751,7 +731,6 @@ def test_open_explicit(dsn):
         with p.connection() as conn:
             cur = conn.execute("select 1")
             assert cur.fetchone() == (1,)
-
     finally:
         p.close()
 
@@ -783,7 +762,6 @@ def test_open_no_op(dsn):
         with p.connection() as conn:
             cur = conn.execute("select 1")
             assert cur.fetchone() == (1,)
-
     finally:
         p.close()
 
@@ -835,8 +813,8 @@ def test_reopen(dsn):
 @pytest.mark.parametrize(
     "min_size, want_times",
     [
-        (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]),
-        (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]),
+        (2, [0.25, 0.25, 0.35, 0.45, 0.5, 0.5, 0.6, 0.7]),
+        (0, [0.35, 0.45, 0.55, 0.6, 0.65, 0.7, 0.8, 0.85]),
     ],
 )
 def test_grow(dsn, monkeypatch, min_size, want_times):
@@ -852,12 +830,8 @@ def test_grow(dsn, monkeypatch, min_size, want_times):
     with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p:
         p.wait(1.0)
         results: List[Tuple[int, float]] = []
-
-        ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(i,)) for i in range(len(want_times))]
+        gather(*ts)
 
     times = [item[1] for item in results]
     for got, want in zip(times, want_times):
@@ -888,11 +862,9 @@ def test_shrink(dsn, monkeypatch):
         p.wait(5.0)
         assert p.max_idle == 0.2
 
-        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(i,)) for i in range(4)]
+        gather(*ts)
+
         sleep(1)
 
     assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
@@ -938,15 +910,28 @@ def test_reconnect(proxy, caplog, monkeypatch):
 
 @pytest.mark.slow
 @pytest.mark.timing
-def test_reconnect_failure(proxy):
+@pytest.mark.parametrize("async_cb", [True, False])
+def test_reconnect_failure(proxy, async_cb):
+    if async_cb and (not is_async(__name__)):
+        pytest.skip("async test only")
+
     proxy.start()
 
     t1 = None
 
-    def failed(pool):
-        assert pool.name == "this-one"
-        nonlocal t1
-        t1 = time()
+    if async_cb:
+
+        def failed(pool):
+            assert pool.name == "this-one"
+            nonlocal t1
+            t1 = time()
+
+    else:
+
+        def failed(pool):
+            assert pool.name == "this-one"
+            nonlocal t1
+            t1 = time()
 
     with pool.ConnectionPool(
         proxy.client_dsn,
@@ -990,14 +975,14 @@ def test_reconnect_after_grow_failed(proxy):
     with pool.ConnectionPool(
         proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
     ) as p:
-        assert ev.wait(timeout=2)
+        ev.wait(2.0)
 
         with pytest.raises(pool.PoolTimeout):
             with p.connection(timeout=0.5) as conn:
                 pass
 
         ev.clear()
-        assert ev.wait(timeout=2)
+        ev.wait(2.0)
 
         proxy.start()
 
@@ -1027,7 +1012,7 @@ def test_refill_on_check(proxy):
 
         # Checking the pool will empty it
         p.check()
-        assert ev.wait(timeout=2)
+        ev.wait(2.0)
         assert len(p._pool) == 0
 
         # Allow to connect again
@@ -1070,12 +1055,11 @@ def test_resize(dsn):
     size: List[int] = []
 
     with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
-        s = Thread(target=sampler)
-        s.start()
+        s = spawn(sampler)
 
         sleep(0.3)
-        c = Thread(target=client, args=(0.4,))
-        c.start()
+
+        c = spawn(client, args=(0.4,))
 
         sleep(0.2)
         p.resize(4)
@@ -1089,7 +1073,7 @@ def test_resize(dsn):
 
         sleep(0.6)
 
-    s.join()
+    gather(s, c)
     assert size == [2, 1, 3, 4, 3, 2, 2]
 
 
@@ -1130,7 +1114,7 @@ def test_check(dsn, caplog):
             pid = conn.info.backend_pid
 
         p.wait(1.0)
-        pids = set(conn.info.backend_pid for conn in p._pool)
+        pids = set((conn.info.backend_pid for conn in p._pool))
         assert pid in pids
         conn.close()
 
@@ -1138,7 +1122,7 @@ def test_check(dsn, caplog):
         p.check()
         assert len(caplog.records) == 1
         p.wait(1.0)
-        pids2 = set(conn.info.backend_pid for conn in p._pool)
+        pids2 = set((conn.info.backend_pid for conn in p._pool))
         assert len(pids & pids2) == 3
         assert pid not in pids2
 
@@ -1181,13 +1165,10 @@ def test_stats_measures(dsn):
         assert stats["pool_available"] == 2
         assert stats["requests_waiting"] == 0
 
-        ts = [Thread(target=worker, args=(i,)) for i in range(3)]
-        for t in ts:
-            t.start()
+        ts = [spawn(worker, args=(i,)) for i in range(3)]
         sleep(0.1)
         stats = p.get_stats()
-        for t in ts:
-            t.join()
+        gather(*ts)
         assert stats["pool_min"] == 2
         assert stats["pool_max"] == 4
         assert stats["pool_size"] == 3
@@ -1195,13 +1176,10 @@ def test_stats_measures(dsn):
         assert stats["requests_waiting"] == 0
 
         p.wait(2.0)
-        ts = [Thread(target=worker, args=(i,)) for i in range(7)]
-        for t in ts:
-            t.start()
+        ts = [spawn(worker, args=(i,)) for i in range(7)]
         sleep(0.1)
         stats = p.get_stats()
-        for t in ts:
-            t.join()
+        gather(*ts)
         assert stats["pool_min"] == 2
         assert stats["pool_max"] == 4
         assert stats["pool_size"] == 4
@@ -1222,11 +1200,8 @@ def test_stats_usage(dsn):
     with pool.ConnectionPool(dsn, min_size=3) as p:
         p.wait(2.0)
 
-        ts = [Thread(target=worker, args=(i,)) for i in range(7)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker, args=(i,)) for i in range(7)]
+        gather(*ts)
         stats = p.get_stats()
         assert stats["requests_num"] == 7
         assert stats["requests_queued"] == 4
@@ -1256,7 +1231,7 @@ def test_stats_connect(dsn, proxy, monkeypatch):
         assert stats["connections_num"] == 3
         assert stats.get("connections_errors", 0) == 0
         assert stats.get("connections_lost", 0) == 0
-        assert 600 <= stats["connections_ms"] < 1200
+        assert 580 <= stats["connections_ms"] < 1200
 
         proxy.stop()
         p.check()
@@ -1280,11 +1255,8 @@ def test_spike(dsn, monkeypatch):
     with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p:
         p.wait()
 
-        ts = [Thread(target=worker) for i in range(50)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
+        ts = [spawn(worker) for i in range(50)]
+        gather(*ts)
         p.wait()
 
         assert len(p._pool) < 7
@@ -1300,15 +1272,68 @@ def test_debug_deadlock(dsn):
     logger.addHandler(handler)
     try:
         with pool.ConnectionPool(dsn, min_size=4, open=True) as p:
-            try:
-                p.wait(timeout=2)
-            finally:
-                print(p.get_stats())
+            p.wait(timeout=2)
     finally:
         logger.removeHandler(handler)
         logger.setLevel(old_level)
 
 
+@pytest.mark.skipif(not is_async(__name__), reason="async test only")
+def test_cancellation_in_queue(dsn):
+    # https://github.com/psycopg/psycopg/issues/509
+
+    nconns = 3
+
+    with pool.ConnectionPool(dsn, min_size=nconns, timeout=1) as p:
+        p.wait()
+
+        got_conns = []
+        ev = Event()
+
+        def worker(i):
+            try:
+                logging.info("worker %s started", i)
+                nonlocal got_conns
+
+                with p.connection() as conn:
+                    logging.info("worker %s got conn", i)
+                    cur = conn.execute("select 1")
+                    assert cur.fetchone() == (1,)
+
+                    got_conns.append(conn)
+                    if len(got_conns) >= nconns:
+                        ev.set()
+
+                    sleep(5)
+            except BaseException as ex:
+                logging.info("worker %s stopped: %r", i, ex)
+                raise
+
+        # Start tasks taking up all the connections and getting in the queue
+        tasks = [spawn(worker, (i,)) for i in range(nconns * 3)]
+
+        # wait until the pool has served all the connections and clients are queued.
+        ev.wait(3.0)
+        for i in range(10):
+            if p.get_stats().get("requests_queued", 0):
+                break
+            else:
+                sleep(0.1)
+        else:
+            pytest.fail("no client got in the queue")
+
+        [task.cancel() for task in reversed(tasks)]
+        gather(*tasks, return_exceptions=True, timeout=1.0)
+
+        stats = p.get_stats()
+        assert stats["pool_available"] == 3
+        assert stats.get("requests_waiting", 0) == 0
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds
index 0726f35c727dcd559a9c279d3afeae35e0f8df93..92a8f7bf59a4b61d54c3c3cd4e494516cb1ad531 100644 (file)
@@ -1,8 +1,7 @@
-import asyncio
 import logging
+import weakref
 from time import time
 from typing import Any, Dict, List, Tuple
-from asyncio import create_task
 
 import pytest
 
@@ -11,13 +10,16 @@ from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type, Counter
 
+from ..utils import AEvent, spawn, gather, asleep, is_alive, is_async
+
 try:
     import psycopg_pool as pool
 except ImportError:
     # Tests should have been skipped if the package is not available
     pass
 
-pytestmark = [pytest.mark.anyio]
+if True:  # ASYNC
+    pytestmark = [pytest.mark.anyio]
 
 
 async def test_defaults(dsn):
@@ -210,7 +212,7 @@ async def test_setup_no_timeout(dsn, proxy):
     async with pool.AsyncConnectionPool(
         proxy.client_dsn, min_size=1, num_workers=1
     ) as p:
-        await asyncio.sleep(0.5)
+        await asleep(0.5)
         assert not p._pool
         proxy.start()
 
@@ -364,8 +366,8 @@ async def test_queue(dsn):
     results: List[Tuple[int, float, int]] = []
     async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
         await p.wait()
-        ts = [create_task(worker(i)) for i in range(6)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(i,)) for i in range(6)]
+        await gather(*ts)
 
     times = [item[1] for item in results]
     want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
@@ -382,7 +384,7 @@ async def test_queue_size(dsn):
             async with p.connection():
                 if ev:
                     ev.set()
-                await asyncio.sleep(t)
+                await asleep(t)
         except pool.TooManyRequests as e:
             errors.append(e)
         else:
@@ -393,12 +395,12 @@ async def test_queue_size(dsn):
 
     async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p:
         await p.wait()
-        ev = asyncio.Event()
-        create_task(worker(0.3, ev))
+        ev = AEvent()
+        spawn(worker, args=(0.3, ev))
         await ev.wait()
 
-        ts = [create_task(worker(0.1)) for i in range(4)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(0.1,)) for i in range(4)]
+        await gather(*ts)
 
     assert len(success) == 4
     assert len(errors) == 1
@@ -429,8 +431,8 @@ async def test_queue_timeout(dsn):
     errors: List[Tuple[int, float, Exception]] = []
 
     async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p:
-        ts = [create_task(worker(i)) for i in range(4)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(i,)) for i in range(4)]
+        await gather(*ts)
 
     assert len(results) == 2
     assert len(errors) == 2
@@ -453,12 +455,12 @@ async def test_dead_client(dsn):
     async with pool.AsyncConnectionPool(dsn, min_size=2) as p:
         results: List[int] = []
         ts = [
-            create_task(worker(i, timeout))
+            spawn(worker, args=(i, timeout))
             for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
         ]
-        await asyncio.gather(*ts)
+        await gather(*ts)
 
-        await asyncio.sleep(0.2)
+        await asleep(0.2)
         assert set(results) == set([0, 1, 3, 4])
         assert len(p._pool) == 2  # no connection was lost
 
@@ -485,8 +487,8 @@ async def test_queue_timeout_override(dsn):
     errors: List[Tuple[int, float, Exception]] = []
 
     async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p:
-        ts = [create_task(worker(i)) for i in range(4)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(i,)) for i in range(4)]
+        await gather(*ts)
 
     assert len(results) == 3
     assert len(errors) == 1
@@ -604,17 +606,17 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
 
 async def test_close_no_tasks(dsn):
     p = pool.AsyncConnectionPool(dsn)
-    assert p._sched_runner and not p._sched_runner.done()
-    assert p._workers
+    assert p._sched_runner and is_alive(p._sched_runner)
     workers = p._workers[:]
+    assert workers
     for t in workers:
-        assert not t.done()
+        assert is_alive(t)
 
     await p.close()
     assert p._sched_runner is None
     assert not p._workers
     for t in workers:
-        assert t.done()
+        assert not is_alive(t)
 
 
 async def test_putconn_no_pool(aconn_cls, dsn):
@@ -634,6 +636,30 @@ async def test_putconn_wrong_pool(dsn):
                 await p2.putconn(conn)
 
 
+async def test_del_no_warning(dsn, recwarn):
+    p = pool.AsyncConnectionPool(dsn, min_size=2)
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+
+    await p.wait()
+    ref = weakref.ref(p)
+    del p
+    assert not ref()
+    assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(is_async(__name__), reason="sync test only")
+async def test_del_stops_threads(dsn):
+    p = pool.AsyncConnectionPool(dsn)
+    assert p._sched_runner is not None
+    ts = [p._sched_runner] + p._workers
+    del p
+    await asleep(0.1)
+    for t in ts:
+        assert not is_alive(t), t
+
+
 async def test_closed_getconn(dsn):
     p = pool.AsyncConnectionPool(dsn, min_size=1)
     assert not p.closed
@@ -676,32 +702,32 @@ async def test_closed_queue(dsn):
         except pool.PoolClosed:
             success.append("w2")
 
-    e1 = asyncio.Event()
-    e2 = asyncio.Event()
+    e1 = AEvent()
+    e2 = AEvent()
 
     p = pool.AsyncConnectionPool(dsn, min_size=1)
     await p.wait()
     success: List[str] = []
 
-    t1 = create_task(w1())
+    t1 = spawn(w1)
     # Wait until w1 has received a connection
     await e1.wait()
 
-    t2 = create_task(w2())
+    t2 = spawn(w2)
     # Wait until w2 is in the queue
     await ensure_waiting(p)
     await p.close()
 
     # Wait for the workers to finish
     e2.set()
-    await asyncio.gather(t1, t2)
+    await gather(t1, t2)
     assert len(success) == 2
 
 
 async def test_open_explicit(dsn):
     p = pool.AsyncConnectionPool(dsn, open=False)
     assert p.closed
-    with pytest.raises(pool.PoolClosed):
+    with pytest.raises(pool.PoolClosed, match="is not open yet"):
         await p.getconn()
 
     with pytest.raises(pool.PoolClosed, match="is not open yet"):
@@ -788,6 +814,7 @@ async def test_reopen(dsn):
         await conn.execute("select 1")
     await p.close()
     assert p._sched_runner is None
+    assert not p._workers
 
     with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
         await p.open()
@@ -816,11 +843,9 @@ async def test_grow(dsn, monkeypatch, min_size, want_times):
         dsn, min_size=min_size, max_size=4, num_workers=3
     ) as p:
         await p.wait(1.0)
-        ts = []
         results: List[Tuple[int, float]] = []
-
-        ts = [create_task(worker(i)) for i in range(len(want_times))]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(i,)) for i in range(len(want_times))]
+        await gather(*ts)
 
     times = [item[1] for item in results]
     for got, want in zip(times, want_times):
@@ -851,10 +876,10 @@ async def test_shrink(dsn, monkeypatch):
         await p.wait(5.0)
         assert p.max_idle == 0.2
 
-        ts = [create_task(worker(i)) for i in range(4)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(i,)) for i in range(4)]
+        await gather(*ts)
 
-        await asyncio.sleep(1)
+        await asleep(1)
 
     assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
 
@@ -879,7 +904,7 @@ async def test_reconnect(proxy, caplog, monkeypatch):
             async with p.connection() as conn:
                 await conn.execute("select 1")
 
-        await asyncio.sleep(1.0)
+        await asleep(1.0)
         proxy.start()
         await p.wait()
 
@@ -901,6 +926,9 @@ async def test_reconnect(proxy, caplog, monkeypatch):
 @pytest.mark.timing
 @pytest.mark.parametrize("async_cb", [True, False])
 async def test_reconnect_failure(proxy, async_cb):
+    if async_cb and not is_async(__name__):
+        pytest.skip("async test only")
+
     proxy.start()
 
     t1 = None
@@ -934,7 +962,7 @@ async def test_reconnect_failure(proxy, async_cb):
                 await conn.execute("select 1")
 
         t0 = time()
-        await asyncio.sleep(1.5)
+        await asleep(1.5)
         assert t1
         assert t1 - t0 == pytest.approx(1.0, 0.1)
         assert p._nconns == 0
@@ -953,7 +981,7 @@ async def test_reconnect_after_grow_failed(proxy):
     # in grow mode. See issue #370.
     proxy.stop()
 
-    ev = asyncio.Event()
+    ev = AEvent()
 
     def failed(pool):
         ev.set()
@@ -961,14 +989,14 @@ async def test_reconnect_after_grow_failed(proxy):
     async with pool.AsyncConnectionPool(
         proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
     ) as p:
-        await asyncio.wait_for(ev.wait(), 2.0)
+        await ev.wait_timeout(2.0)
 
         with pytest.raises(pool.PoolTimeout):
             async with p.connection(timeout=0.5) as conn:
                 pass
 
         ev.clear()
-        await asyncio.wait_for(ev.wait(), 2.0)
+        await ev.wait_timeout(2.0)
 
         proxy.start()
 
@@ -982,7 +1010,7 @@ async def test_reconnect_after_grow_failed(proxy):
 @pytest.mark.slow
 async def test_refill_on_check(proxy):
     proxy.start()
-    ev = asyncio.Event()
+    ev = AEvent()
 
     def failed(pool):
         ev.set()
@@ -998,7 +1026,7 @@ async def test_refill_on_check(proxy):
 
         # Checking the pool will empty it
         await p.check()
-        await asyncio.wait_for(ev.wait(), 2.0)
+        await ev.wait_timeout(2.0)
         assert len(p._pool) == 0
 
         # Allow to connect again
@@ -1016,7 +1044,7 @@ async def test_uniform_use(dsn):
         counts = Counter[int]()
         for i in range(8):
             async with p.connection() as conn:
-                await asyncio.sleep(0.1)
+                await asleep(0.1)
                 counts[id(conn)] += 1
 
     assert len(counts) == 4
@@ -1027,9 +1055,9 @@ async def test_uniform_use(dsn):
 @pytest.mark.timing
 async def test_resize(dsn):
     async def sampler():
-        await asyncio.sleep(0.05)  # ensure sampling happens after shrink check
+        await asleep(0.05)  # ensure sampling happens after shrink check
         while True:
-            await asyncio.sleep(0.2)
+            await asleep(0.2)
             if p.closed:
                 break
             size.append(len(p._pool))
@@ -1041,25 +1069,25 @@ async def test_resize(dsn):
     size: List[int] = []
 
     async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p:
-        s = create_task(sampler())
+        s = spawn(sampler)
 
-        await asyncio.sleep(0.3)
+        await asleep(0.3)
 
-        c = create_task(client(0.4))
+        c = spawn(client, args=(0.4,))
 
-        await asyncio.sleep(0.2)
+        await asleep(0.2)
         await p.resize(4)
         assert p.min_size == 4
         assert p.max_size == 4
 
-        await asyncio.sleep(0.4)
+        await asleep(0.4)
         await p.resize(2)
         assert p.min_size == 2
         assert p.max_size == 2
 
-        await asyncio.sleep(0.6)
+        await asleep(0.6)
 
-    await asyncio.gather(s, c)
+    await gather(s, c)
     assert size == [2, 1, 3, 4, 3, 2, 2]
 
 
@@ -1081,12 +1109,12 @@ async def test_jitter():
 @pytest.mark.crdb_skip("backend pid")
 async def test_max_lifetime(dsn):
     async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
-        await asyncio.sleep(0.1)
+        await asleep(0.1)
         pids = []
         for i in range(5):
             async with p.connection() as conn:
                 pids.append(conn.info.backend_pid)
-            await asyncio.sleep(0.2)
+            await asleep(0.2)
 
     assert pids[0] == pids[1] != pids[4], pids
 
@@ -1128,7 +1156,7 @@ async def test_check_max_lifetime(dsn):
             pid = conn.info.backend_pid
         async with p.connection() as conn:
             assert conn.info.backend_pid == pid
-        await asyncio.sleep(0.3)
+        await asleep(0.3)
         await p.check()
         async with p.connection() as conn:
             assert conn.info.backend_pid != pid
@@ -1151,10 +1179,10 @@ async def test_stats_measures(dsn):
         assert stats["pool_available"] == 2
         assert stats["requests_waiting"] == 0
 
-        ts = [create_task(worker(i)) for i in range(3)]
-        await asyncio.sleep(0.1)
+        ts = [spawn(worker, args=(i,)) for i in range(3)]
+        await asleep(0.1)
         stats = p.get_stats()
-        await asyncio.gather(*ts)
+        await gather(*ts)
         assert stats["pool_min"] == 2
         assert stats["pool_max"] == 4
         assert stats["pool_size"] == 3
@@ -1162,10 +1190,10 @@ async def test_stats_measures(dsn):
         assert stats["requests_waiting"] == 0
 
         await p.wait(2.0)
-        ts = [create_task(worker(i)) for i in range(7)]
-        await asyncio.sleep(0.1)
+        ts = [spawn(worker, args=(i,)) for i in range(7)]
+        await asleep(0.1)
         stats = p.get_stats()
-        await asyncio.gather(*ts)
+        await gather(*ts)
         assert stats["pool_min"] == 2
         assert stats["pool_max"] == 4
         assert stats["pool_size"] == 4
@@ -1186,8 +1214,8 @@ async def test_stats_usage(dsn):
     async with pool.AsyncConnectionPool(dsn, min_size=3) as p:
         await p.wait(2.0)
 
-        ts = [create_task(worker(i)) for i in range(7)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker, args=(i,)) for i in range(7)]
+        await gather(*ts)
         stats = p.get_stats()
         assert stats["requests_num"] == 7
         assert stats["requests_queued"] == 4
@@ -1221,7 +1249,7 @@ async def test_stats_connect(dsn, proxy, monkeypatch):
 
         proxy.stop()
         await p.check()
-        await asyncio.sleep(0.1)
+        await asleep(0.1)
         stats = p.get_stats()
         assert stats["connections_num"] > 3
         assert stats["connections_errors"] > 0
@@ -1236,13 +1264,13 @@ async def test_spike(dsn, monkeypatch):
 
     async def worker():
         async with p.connection():
-            await asyncio.sleep(0.002)
+            await asleep(0.002)
 
     async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p:
         await p.wait()
 
-        ts = [create_task(worker()) for i in range(50)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker) for i in range(50)]
+        await gather(*ts)
         await p.wait()
 
         assert len(p._pool) < 7
@@ -1264,6 +1292,7 @@ async def test_debug_deadlock(dsn):
         logger.setLevel(old_level)
 
 
+@pytest.mark.skipif(not is_async(__name__), reason="async test only")
 async def test_cancellation_in_queue(dsn):
     # https://github.com/psycopg/psycopg/issues/509
 
@@ -1273,7 +1302,7 @@ async def test_cancellation_in_queue(dsn):
         await p.wait()
 
         got_conns = []
-        ev = asyncio.Event()
+        ev = AEvent()
 
         async def worker(i):
             try:
@@ -1289,27 +1318,27 @@ async def test_cancellation_in_queue(dsn):
                     if len(got_conns) >= nconns:
                         ev.set()
 
-                    await asyncio.sleep(5)
+                    await asleep(5)
 
             except BaseException as ex:
                 logging.info("worker %s stopped: %r", i, ex)
                 raise
 
         # Start tasks taking up all the connections and getting in the queue
-        tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)]
+        tasks = [spawn(worker, (i,)) for i in range(nconns * 3)]
 
         # wait until the pool has served all the connections and clients are queued.
-        await asyncio.wait_for(ev.wait(), 3.0)
+        await ev.wait_timeout(3.0)
         for i in range(10):
             if p.get_stats().get("requests_queued", 0):
                 break
             else:
-                await asyncio.sleep(0.1)
+                await asleep(0.1)
         else:
             pytest.fail("no client got in the queue")
 
         [task.cancel() for task in reversed(tasks)]
-        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0)
+        await gather(*tasks, return_exceptions=True, timeout=1.0)
 
         stats = p.get_stats()
         assert stats["pool_available"] == 3
@@ -1329,7 +1358,7 @@ def delay_connection(monkeypatch, sec):
         t0 = time()
         rv = await connect_orig(*args, **kwargs)
         t1 = time()
-        await asyncio.sleep(max(0, sec - (t1 - t0)))
+        await asleep(max(0, sec - (t1 - t0)))
         return rv
 
     connect_orig = psycopg.AsyncConnection.connect
@@ -1337,5 +1366,8 @@ def delay_connection(monkeypatch, sec):
 
 
 async def ensure_waiting(p, num=1):
+    """
+    Wait until there are at least *num* clients waiting in the queue.
+    """
     while len(p._waiting) < num:
-        await asyncio.sleep(0)
+        await asleep(0)
index 5b9b73c1a40d08de8b43899f2180ca0682f0f42e..0603886724e8863f516acc313990407faac45842 100644 (file)
@@ -4,11 +4,14 @@ import sys
 import asyncio
 import inspect
 import operator
-from time import sleep as sleep  # noqa: F401 -- re-export
-from typing import Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 from threading import Thread
 from contextlib import contextmanager, asynccontextmanager
-from contextlib import closing as closing  # noqa: F401 -- re-export
+
+# Re-exports
+from time import sleep as sleep  # noqa: F401
+from threading import Event as Event  # noqa: F401
+from contextlib import closing as closing  # noqa: F401
 
 import pytest
 
@@ -232,27 +235,34 @@ def raiseif(cond, *args, **kwargs):
         return
 
 
-def spawn(f):
+def spawn(f, args=None):
     """
     Equivalent to asyncio.create_task or creating and running a Thread.
     """
+    if not args:
+        args = ()
+
     if inspect.iscoroutinefunction(f):
-        return asyncio.create_task(f())
+        return asyncio.create_task(f(*args))
     else:
-        t = Thread(target=f, daemon=True)
+        t = Thread(target=f, args=args, daemon=True)
         t.start()
         return t
 
 
-def gather(*ts):
+def gather(*ts, return_exceptions=False, timeout=None):
     """
     Equivalent to asyncio.gather or Thread.join()
     """
     if ts and inspect.isawaitable(ts[0]):
-        return asyncio.gather(*ts)
+        rv: Any = asyncio.gather(*ts, return_exceptions=return_exceptions)
+        if timeout is None:
+            rv = asyncio.wait_for(rv, timeout)
+        return rv
     else:
         for t in ts:
-            t.join()
+            t.join(timeout)
+            assert not t.is_alive()
 
 
 def asleep(s):
@@ -260,3 +270,21 @@ def asleep(s):
     Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync.
     """
     return asyncio.sleep(s)
+
+
+def is_alive(t):
+    """
+    Return true if an asyncio.Task or threading.Thread is alive.
+    """
+    return t.is_alive() if isinstance(t, Thread) else not t.done()
+
+
+class AEvent(asyncio.Event):
+    """
+    Subclass of asyncio.Event adding a wait with timeout like threading.Event.
+
+    wait_timeout() is converted to wait() by async_to_sync.
+    """
+
+    async def wait_timeout(self, timeout):
+        await asyncio.wait_for(self.wait(), timeout)
index 40ed2ccc6b5b308f3ac0f1136d6ec7f9e79d1e14..4183d722219c7c9fad0adbed86e06d0a06287291 100755 (executable)
@@ -152,8 +152,10 @@ class AsyncToSync(ast.NodeTransformer):
 
 class RenameAsyncToSync(ast.NodeTransformer):
     names_map = {
+        "AEvent": "Event",
         "AsyncClientCursor": "ClientCursor",
         "AsyncConnection": "Connection",
+        "AsyncConnectionPool": "ConnectionPool",
         "AsyncCopy": "Copy",
         "AsyncCopyWriter": "CopyWriter",
         "AsyncCursor": "Cursor",
@@ -179,16 +181,18 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "aconn_cls": "conn_cls",
         "alist": "list",
         "anext": "next",
-        "asleep": "sleep",
         "apipeline": "pipeline",
+        "asleep": "sleep",
         "asynccontextmanager": "contextmanager",
         "connection_async": "connection",
         "cursor_async": "cursor",
         "ensure_table_async": "ensure_table",
         "find_insert_problem_async": "find_insert_problem",
+        "psycopg_pool.pool_async": "psycopg_pool.pool",
         "psycopg_pool.sched_async": "psycopg_pool.sched",
         "wait_async": "wait",
         "wait_conn_async": "wait_conn",
+        "wait_timeout": "wait",
     }
     _skip_imports = {
         "utils": {"alist", "anext"},
index 983d9c8e40ba9687b13564a35f41315feb9e55f1..8784817fa84da65e0c07b3ceb3338569b3de73f4 100755 (executable)
@@ -21,6 +21,7 @@ for async in \
     psycopg/psycopg/connection_async.py \
     psycopg/psycopg/cursor_async.py \
     psycopg_pool/psycopg_pool/sched_async.py \
+    tests/pool/test_pool_async.py \
     tests/pool/test_sched_async.py \
     tests/test_client_cursor_async.py \
     tests/test_connection_async.py \