From: Daniele Varrazzo Date: Mon, 11 Sep 2023 21:57:44 +0000 (+0100) Subject: refactor(tests): generate test_pool from async counterpart X-Git-Tag: pool-3.2.0~12^2~35 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=1fed19d9ab90ad8da63144c17f69fe70bc753f94;p=thirdparty%2Fpsycopg.git refactor(tests): generate test_pool from async counterpart --- diff --git a/tests/pool/test_module.py b/tests/pool/test_module.py new file mode 100644 index 000000000..31f77aab3 --- /dev/null +++ b/tests/pool/test_module.py @@ -0,0 +1,8 @@ +def test_version(mypy): + cp = mypy.run_on_source( + """\ +from psycopg_pool import __version__ +assert __version__ +""" + ) + assert not cp.stdout diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 7234a3c84..2032fab89 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -1,7 +1,9 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_pool_async.py' +# DO NOT CHANGE! Change the original file instead. import logging import weakref -from time import sleep, time -from threading import Thread, Event +from time import time from typing import Any, Dict, List, Tuple import pytest @@ -9,7 +11,9 @@ import pytest import psycopg from psycopg.pq import TransactionStatus from psycopg.rows import class_row, Row, TupleRow -from psycopg._compat import Counter, assert_type +from psycopg._compat import assert_type, Counter + +from ..utils import Event, spawn, gather, sleep, is_alive, is_async try: import psycopg_pool as pool @@ -18,16 +22,6 @@ except ImportError: pass -def test_package_version(mypy): - cp = mypy.run_on_source( - """\ -from psycopg_pool import __version__ -assert __version__ -""" - ) - assert not cp.stdout - - def test_defaults(dsn): with pool.ConnectionPool(dsn) as p: assert p.min_size == p.max_size == 4 @@ -71,7 +65,7 @@ class MyRow(Dict[str, Any]): def test_generic_connection_type(dsn): def set_autocommit(conn: psycopg.Connection[Any]) -> None: - conn.autocommit = True + conn.set_autocommit(True) class MyConnection(psycopg.Connection[Row]): pass @@ -81,8 +75,10 @@ def test_generic_connection_type(dsn): connection_class=MyConnection[MyRow], kwargs=dict(row_factory=class_row(MyRow)), configure=set_autocommit, - ) as p1, p1.connection() as conn1: - (row1,) = conn1.execute("select 1 as x").fetchall() + ) as p1: + with p1.connection() as conn1: + cur1 = conn1.execute("select 1 as x") + (row1,) = cur1.fetchall() assert_type(p1, pool.ConnectionPool[MyConnection[MyRow]]) assert_type(conn1, MyConnection[MyRow]) assert_type(row1, MyRow) @@ -91,7 +87,8 @@ def test_generic_connection_type(dsn): with pool.ConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2: with p2.connection() as conn2: - (row2,) = conn2.execute("select 2 as y").fetchall() + cur2 = conn2.execute("select 2 as y") + (row2,) = cur2.fetchall() assert_type(p2, pool.ConnectionPool[MyConnection[TupleRow]]) assert_type(conn2, MyConnection[TupleRow]) assert_type(row2, TupleRow) @@ -100,7 +97,7 @@ def test_generic_connection_type(dsn): def test_non_generic_connection_type(dsn): def set_autocommit(conn: psycopg.Connection[Any]) -> None: - conn.autocommit = True + conn.set_autocommit(True) class MyConnection(psycopg.Connection[MyRow]): def __init__(self, *args: Any, **kwargs: Any): @@ -111,7 +108,8 @@ def test_non_generic_connection_type(dsn): dsn, connection_class=MyConnection, configure=set_autocommit ) as p1: with p1.connection() as conn1: - (row1,) = conn1.execute("select 1 as x").fetchall() + cur1 = conn1.execute("select 1 as x") + (row1,) = cur1.fetchall() assert_type(p1, pool.ConnectionPool[MyConnection]) assert_type(conn1, MyConnection) assert_type(row1, MyRow) @@ -222,7 +220,7 @@ def test_configure(dsn): conn.execute("set default_transaction_read_only to on") with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p: - p.wait() + p.wait(timeout=1.0) with p.connection() as conn: assert inits == 1 res = conn.execute("show default_transaction_read_only") @@ -293,8 +291,8 @@ def test_reset(dsn): assert resets == 1 with p.connection() as conn: - with conn.execute("show timezone") as cur: - assert cur.fetchone() == ("UTC",) + cur = conn.execute("show timezone") + assert cur.fetchone() == ("UTC",) p.wait() assert resets == 2 @@ -358,18 +356,15 @@ def test_queue(dsn): results: List[Tuple[int, float, int]] = [] with pool.ConnectionPool(dsn, min_size=2) as p: p.wait() - ts = [Thread(target=worker, args=(i,)) for i in range(6)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(i,)) for i in range(6)] + gather(*ts) times = [item[1] for item in results] want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] for got, want in zip(times, want_times): assert got == pytest.approx(want, 0.1), times - assert len(set(r[2] for r in results)) == 2, results + assert len(set((r[2] for r in results))) == 2, results @pytest.mark.slow @@ -391,15 +386,11 @@ def test_queue_size(dsn): with pool.ConnectionPool(dsn, min_size=1, max_waiting=3) as p: p.wait() ev = Event() - t = Thread(target=worker, args=(0.3, ev)) - t.start() + spawn(worker, args=(0.3, ev)) ev.wait() - ts = [Thread(target=worker, args=(0.1,)) for i in range(4)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(0.1,)) for i in range(4)] + gather(*ts) assert len(success) == 4 assert len(errors) == 1 @@ -430,11 +421,8 @@ def test_queue_timeout(dsn): errors: List[Tuple[int, float, Exception]] = [] with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: - ts = [Thread(target=worker, args=(i,)) for i in range(4)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(i,)) for i in range(4)] + gather(*ts) assert len(results) == 2 assert len(errors) == 2 @@ -454,17 +442,14 @@ def test_dead_client(dsn): if timeout > 0.2: raise - results: List[int] = [] - with pool.ConnectionPool(dsn, min_size=2) as p: + results: List[int] = [] ts = [ - Thread(target=worker, args=(i, timeout)) - for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + spawn(worker, args=(i, timeout)) + for (i, timeout) in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) ] - for t in ts: - t.start() - for t in ts: - t.join() + gather(*ts) + sleep(0.2) assert set(results) == set([0, 1, 3, 4]) assert len(p._pool) == 2 # no connection was lost @@ -492,11 +477,8 @@ def test_queue_timeout_override(dsn): errors: List[Tuple[int, float, Exception]] = [] with pool.ConnectionPool(dsn, min_size=2, timeout=0.1) as p: - ts = [Thread(target=worker, args=(i,)) for i in range(4)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(i,)) for i in range(4)] + gather(*ts) assert len(results) == 3 assert len(errors) == 1 @@ -531,9 +513,10 @@ def test_intrans_rollback(dsn, caplog): with p.connection() as conn2: assert conn2.info.backend_pid == pid assert conn2.info.transaction_status == TransactionStatus.IDLE - assert not conn2.execute( + cur = conn2.execute( "select 1 from pg_class where relname = 'test_intrans_rollback'" - ).fetchone() + ) + assert not cur.fetchone() assert len(caplog.records) == 1 assert "INTRANS" in caplog.records[0].message @@ -611,19 +594,19 @@ def test_fail_rollback_close(dsn, caplog, monkeypatch): assert "BAD" in caplog.records[2].message -def test_close_no_threads(dsn): +def test_close_no_tasks(dsn): p = pool.ConnectionPool(dsn) - assert p._sched_runner and p._sched_runner.is_alive() + assert p._sched_runner and is_alive(p._sched_runner) workers = p._workers[:] assert workers for t in workers: - assert t.is_alive() + assert is_alive(t) p.close() assert p._sched_runner is None assert not p._workers for t in workers: - assert not t.is_alive() + assert not is_alive(t) def test_putconn_no_pool(conn_cls, dsn): @@ -656,14 +639,15 @@ def test_del_no_warning(dsn, recwarn): @pytest.mark.slow -def test_del_stop_threads(dsn): +@pytest.mark.skipif(is_async(__name__), reason="sync test only") +def test_del_stops_threads(dsn): p = pool.ConnectionPool(dsn) assert p._sched_runner is not None ts = [p._sched_runner] + p._workers del p sleep(0.1) for t in ts: - assert not t.is_alive() + assert not is_alive(t), t def test_closed_getconn(dsn): @@ -715,22 +699,18 @@ def test_closed_queue(dsn): p.wait() success: List[str] = [] - t1 = Thread(target=w1) - t1.start() + t1 = spawn(w1) # Wait until w1 has received a connection e1.wait() - t2 = Thread(target=w2) - t2.start() + t2 = spawn(w2) # Wait until w2 is in the queue ensure_waiting(p) - - p.close(0) + p.close() # Wait for the workers to finish e2.set() - t1.join() - t2.join() + gather(t1, t2) assert len(success) == 2 @@ -740,7 +720,7 @@ def test_open_explicit(dsn): with pytest.raises(pool.PoolClosed, match="is not open yet"): p.getconn() - with pytest.raises(pool.PoolClosed): + with pytest.raises(pool.PoolClosed, match="is not open yet"): with p.connection(): pass @@ -751,7 +731,6 @@ def test_open_explicit(dsn): with p.connection() as conn: cur = conn.execute("select 1") assert cur.fetchone() == (1,) - finally: p.close() @@ -783,7 +762,6 @@ def test_open_no_op(dsn): with p.connection() as conn: cur = conn.execute("select 1") assert cur.fetchone() == (1,) - finally: p.close() @@ -835,8 +813,8 @@ def test_reopen(dsn): @pytest.mark.parametrize( "min_size, want_times", [ - (2, [0.25, 0.25, 0.35, 0.45, 0.50, 0.50, 0.60, 0.70]), - (0, [0.35, 0.45, 0.55, 0.60, 0.65, 0.70, 0.80, 0.85]), + (2, [0.25, 0.25, 0.35, 0.45, 0.5, 0.5, 0.6, 0.7]), + (0, [0.35, 0.45, 0.55, 0.6, 0.65, 0.7, 0.8, 0.85]), ], ) def test_grow(dsn, monkeypatch, min_size, want_times): @@ -852,12 +830,8 @@ def test_grow(dsn, monkeypatch, min_size, want_times): with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p: p.wait(1.0) results: List[Tuple[int, float]] = [] - - ts = [Thread(target=worker, args=(i,)) for i in range(len(want_times))] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(i,)) for i in range(len(want_times))] + gather(*ts) times = [item[1] for item in results] for got, want in zip(times, want_times): @@ -888,11 +862,9 @@ def test_shrink(dsn, monkeypatch): p.wait(5.0) assert p.max_idle == 0.2 - ts = [Thread(target=worker, args=(i,)) for i in range(4)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(i,)) for i in range(4)] + gather(*ts) + sleep(1) assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] @@ -938,15 +910,28 @@ def test_reconnect(proxy, caplog, monkeypatch): @pytest.mark.slow @pytest.mark.timing -def test_reconnect_failure(proxy): +@pytest.mark.parametrize("async_cb", [True, False]) +def test_reconnect_failure(proxy, async_cb): + if async_cb and (not is_async(__name__)): + pytest.skip("async test only") + proxy.start() t1 = None - def failed(pool): - assert pool.name == "this-one" - nonlocal t1 - t1 = time() + if async_cb: + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() + + else: + + def failed(pool): + assert pool.name == "this-one" + nonlocal t1 + t1 = time() with pool.ConnectionPool( proxy.client_dsn, @@ -990,14 +975,14 @@ def test_reconnect_after_grow_failed(proxy): with pool.ConnectionPool( proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed ) as p: - assert ev.wait(timeout=2) + ev.wait(2.0) with pytest.raises(pool.PoolTimeout): with p.connection(timeout=0.5) as conn: pass ev.clear() - assert ev.wait(timeout=2) + ev.wait(2.0) proxy.start() @@ -1027,7 +1012,7 @@ def test_refill_on_check(proxy): # Checking the pool will empty it p.check() - assert ev.wait(timeout=2) + ev.wait(2.0) assert len(p._pool) == 0 # Allow to connect again @@ -1070,12 +1055,11 @@ def test_resize(dsn): size: List[int] = [] with pool.ConnectionPool(dsn, min_size=2, max_idle=0.2) as p: - s = Thread(target=sampler) - s.start() + s = spawn(sampler) sleep(0.3) - c = Thread(target=client, args=(0.4,)) - c.start() + + c = spawn(client, args=(0.4,)) sleep(0.2) p.resize(4) @@ -1089,7 +1073,7 @@ def test_resize(dsn): sleep(0.6) - s.join() + gather(s, c) assert size == [2, 1, 3, 4, 3, 2, 2] @@ -1130,7 +1114,7 @@ def test_check(dsn, caplog): pid = conn.info.backend_pid p.wait(1.0) - pids = set(conn.info.backend_pid for conn in p._pool) + pids = set((conn.info.backend_pid for conn in p._pool)) assert pid in pids conn.close() @@ -1138,7 +1122,7 @@ def test_check(dsn, caplog): p.check() assert len(caplog.records) == 1 p.wait(1.0) - pids2 = set(conn.info.backend_pid for conn in p._pool) + pids2 = set((conn.info.backend_pid for conn in p._pool)) assert len(pids & pids2) == 3 assert pid not in pids2 @@ -1181,13 +1165,10 @@ def test_stats_measures(dsn): assert stats["pool_available"] == 2 assert stats["requests_waiting"] == 0 - ts = [Thread(target=worker, args=(i,)) for i in range(3)] - for t in ts: - t.start() + ts = [spawn(worker, args=(i,)) for i in range(3)] sleep(0.1) stats = p.get_stats() - for t in ts: - t.join() + gather(*ts) assert stats["pool_min"] == 2 assert stats["pool_max"] == 4 assert stats["pool_size"] == 3 @@ -1195,13 +1176,10 @@ def test_stats_measures(dsn): assert stats["requests_waiting"] == 0 p.wait(2.0) - ts = [Thread(target=worker, args=(i,)) for i in range(7)] - for t in ts: - t.start() + ts = [spawn(worker, args=(i,)) for i in range(7)] sleep(0.1) stats = p.get_stats() - for t in ts: - t.join() + gather(*ts) assert stats["pool_min"] == 2 assert stats["pool_max"] == 4 assert stats["pool_size"] == 4 @@ -1222,11 +1200,8 @@ def test_stats_usage(dsn): with pool.ConnectionPool(dsn, min_size=3) as p: p.wait(2.0) - ts = [Thread(target=worker, args=(i,)) for i in range(7)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker, args=(i,)) for i in range(7)] + gather(*ts) stats = p.get_stats() assert stats["requests_num"] == 7 assert stats["requests_queued"] == 4 @@ -1256,7 +1231,7 @@ def test_stats_connect(dsn, proxy, monkeypatch): assert stats["connections_num"] == 3 assert stats.get("connections_errors", 0) == 0 assert stats.get("connections_lost", 0) == 0 - assert 600 <= stats["connections_ms"] < 1200 + assert 580 <= stats["connections_ms"] < 1200 proxy.stop() p.check() @@ -1280,11 +1255,8 @@ def test_spike(dsn, monkeypatch): with pool.ConnectionPool(dsn, min_size=5, max_size=10) as p: p.wait() - ts = [Thread(target=worker) for i in range(50)] - for t in ts: - t.start() - for t in ts: - t.join() + ts = [spawn(worker) for i in range(50)] + gather(*ts) p.wait() assert len(p._pool) < 7 @@ -1300,15 +1272,68 @@ def test_debug_deadlock(dsn): logger.addHandler(handler) try: with pool.ConnectionPool(dsn, min_size=4, open=True) as p: - try: - p.wait(timeout=2) - finally: - print(p.get_stats()) + p.wait(timeout=2) finally: logger.removeHandler(handler) logger.setLevel(old_level) +@pytest.mark.skipif(not is_async(__name__), reason="async test only") +def test_cancellation_in_queue(dsn): + # https://github.com/psycopg/psycopg/issues/509 + + nconns = 3 + + with pool.ConnectionPool(dsn, min_size=nconns, timeout=1) as p: + p.wait() + + got_conns = [] + ev = Event() + + def worker(i): + try: + logging.info("worker %s started", i) + nonlocal got_conns + + with p.connection() as conn: + logging.info("worker %s got conn", i) + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + got_conns.append(conn) + if len(got_conns) >= nconns: + ev.set() + + sleep(5) + except BaseException as ex: + logging.info("worker %s stopped: %r", i, ex) + raise + + # Start tasks taking up all the connections and getting in the queue + tasks = [spawn(worker, (i,)) for i in range(nconns * 3)] + + # wait until the pool has served all the connections and clients are queued. + ev.wait(3.0) + for i in range(10): + if p.get_stats().get("requests_queued", 0): + break + else: + sleep(0.1) + else: + pytest.fail("no client got in the queue") + + [task.cancel() for task in reversed(tasks)] + gather(*tasks, return_exceptions=True, timeout=1.0) + + stats = p.get_stats() + assert stats["pool_available"] == 3 + assert stats.get("requests_waiting", 0) == 0 + + with p.connection() as conn: + cur = conn.execute("select 1") + assert cur.fetchone() == (1,) + + def delay_connection(monkeypatch, sec): """ Return a _connect_gen function delayed by the amount of seconds diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 0726f35c7..92a8f7bf5 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -1,8 +1,7 @@ -import asyncio import logging +import weakref from time import time from typing import Any, Dict, List, Tuple -from asyncio import create_task import pytest @@ -11,13 +10,16 @@ from psycopg.pq import TransactionStatus from psycopg.rows import class_row, Row, TupleRow from psycopg._compat import assert_type, Counter +from ..utils import AEvent, spawn, gather, asleep, is_alive, is_async + try: import psycopg_pool as pool except ImportError: # Tests should have been skipped if the package is not available pass -pytestmark = [pytest.mark.anyio] +if True: # ASYNC + pytestmark = [pytest.mark.anyio] async def test_defaults(dsn): @@ -210,7 +212,7 @@ async def test_setup_no_timeout(dsn, proxy): async with pool.AsyncConnectionPool( proxy.client_dsn, min_size=1, num_workers=1 ) as p: - await asyncio.sleep(0.5) + await asleep(0.5) assert not p._pool proxy.start() @@ -364,8 +366,8 @@ async def test_queue(dsn): results: List[Tuple[int, float, int]] = [] async with pool.AsyncConnectionPool(dsn, min_size=2) as p: await p.wait() - ts = [create_task(worker(i)) for i in range(6)] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(i,)) for i in range(6)] + await gather(*ts) times = [item[1] for item in results] want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] @@ -382,7 +384,7 @@ async def test_queue_size(dsn): async with p.connection(): if ev: ev.set() - await asyncio.sleep(t) + await asleep(t) except pool.TooManyRequests as e: errors.append(e) else: @@ -393,12 +395,12 @@ async def test_queue_size(dsn): async with pool.AsyncConnectionPool(dsn, min_size=1, max_waiting=3) as p: await p.wait() - ev = asyncio.Event() - create_task(worker(0.3, ev)) + ev = AEvent() + spawn(worker, args=(0.3, ev)) await ev.wait() - ts = [create_task(worker(0.1)) for i in range(4)] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(0.1,)) for i in range(4)] + await gather(*ts) assert len(success) == 4 assert len(errors) == 1 @@ -429,8 +431,8 @@ async def test_queue_timeout(dsn): errors: List[Tuple[int, float, Exception]] = [] async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: - ts = [create_task(worker(i)) for i in range(4)] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(i,)) for i in range(4)] + await gather(*ts) assert len(results) == 2 assert len(errors) == 2 @@ -453,12 +455,12 @@ async def test_dead_client(dsn): async with pool.AsyncConnectionPool(dsn, min_size=2) as p: results: List[int] = [] ts = [ - create_task(worker(i, timeout)) + spawn(worker, args=(i, timeout)) for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) ] - await asyncio.gather(*ts) + await gather(*ts) - await asyncio.sleep(0.2) + await asleep(0.2) assert set(results) == set([0, 1, 3, 4]) assert len(p._pool) == 2 # no connection was lost @@ -485,8 +487,8 @@ async def test_queue_timeout_override(dsn): errors: List[Tuple[int, float, Exception]] = [] async with pool.AsyncConnectionPool(dsn, min_size=2, timeout=0.1) as p: - ts = [create_task(worker(i)) for i in range(4)] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(i,)) for i in range(4)] + await gather(*ts) assert len(results) == 3 assert len(errors) == 1 @@ -604,17 +606,17 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch): async def test_close_no_tasks(dsn): p = pool.AsyncConnectionPool(dsn) - assert p._sched_runner and not p._sched_runner.done() - assert p._workers + assert p._sched_runner and is_alive(p._sched_runner) workers = p._workers[:] + assert workers for t in workers: - assert not t.done() + assert is_alive(t) await p.close() assert p._sched_runner is None assert not p._workers for t in workers: - assert t.done() + assert not is_alive(t) async def test_putconn_no_pool(aconn_cls, dsn): @@ -634,6 +636,30 @@ async def test_putconn_wrong_pool(dsn): await p2.putconn(conn) +async def test_del_no_warning(dsn, recwarn): + p = pool.AsyncConnectionPool(dsn, min_size=2) + async with p.connection() as conn: + await conn.execute("select 1") + + await p.wait() + ref = weakref.ref(p) + del p + assert not ref() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +@pytest.mark.slow +@pytest.mark.skipif(is_async(__name__), reason="sync test only") +async def test_del_stops_threads(dsn): + p = pool.AsyncConnectionPool(dsn) + assert p._sched_runner is not None + ts = [p._sched_runner] + p._workers + del p + await asleep(0.1) + for t in ts: + assert not is_alive(t), t + + async def test_closed_getconn(dsn): p = pool.AsyncConnectionPool(dsn, min_size=1) assert not p.closed @@ -676,32 +702,32 @@ async def test_closed_queue(dsn): except pool.PoolClosed: success.append("w2") - e1 = asyncio.Event() - e2 = asyncio.Event() + e1 = AEvent() + e2 = AEvent() p = pool.AsyncConnectionPool(dsn, min_size=1) await p.wait() success: List[str] = [] - t1 = create_task(w1()) + t1 = spawn(w1) # Wait until w1 has received a connection await e1.wait() - t2 = create_task(w2()) + t2 = spawn(w2) # Wait until w2 is in the queue await ensure_waiting(p) await p.close() # Wait for the workers to finish e2.set() - await asyncio.gather(t1, t2) + await gather(t1, t2) assert len(success) == 2 async def test_open_explicit(dsn): p = pool.AsyncConnectionPool(dsn, open=False) assert p.closed - with pytest.raises(pool.PoolClosed): + with pytest.raises(pool.PoolClosed, match="is not open yet"): await p.getconn() with pytest.raises(pool.PoolClosed, match="is not open yet"): @@ -788,6 +814,7 @@ async def test_reopen(dsn): await conn.execute("select 1") await p.close() assert p._sched_runner is None + assert not p._workers with pytest.raises(psycopg.OperationalError, match="cannot be reused"): await p.open() @@ -816,11 +843,9 @@ async def test_grow(dsn, monkeypatch, min_size, want_times): dsn, min_size=min_size, max_size=4, num_workers=3 ) as p: await p.wait(1.0) - ts = [] results: List[Tuple[int, float]] = [] - - ts = [create_task(worker(i)) for i in range(len(want_times))] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(i,)) for i in range(len(want_times))] + await gather(*ts) times = [item[1] for item in results] for got, want in zip(times, want_times): @@ -851,10 +876,10 @@ async def test_shrink(dsn, monkeypatch): await p.wait(5.0) assert p.max_idle == 0.2 - ts = [create_task(worker(i)) for i in range(4)] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(i,)) for i in range(4)] + await gather(*ts) - await asyncio.sleep(1) + await asleep(1) assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] @@ -879,7 +904,7 @@ async def test_reconnect(proxy, caplog, monkeypatch): async with p.connection() as conn: await conn.execute("select 1") - await asyncio.sleep(1.0) + await asleep(1.0) proxy.start() await p.wait() @@ -901,6 +926,9 @@ async def test_reconnect(proxy, caplog, monkeypatch): @pytest.mark.timing @pytest.mark.parametrize("async_cb", [True, False]) async def test_reconnect_failure(proxy, async_cb): + if async_cb and not is_async(__name__): + pytest.skip("async test only") + proxy.start() t1 = None @@ -934,7 +962,7 @@ async def test_reconnect_failure(proxy, async_cb): await conn.execute("select 1") t0 = time() - await asyncio.sleep(1.5) + await asleep(1.5) assert t1 assert t1 - t0 == pytest.approx(1.0, 0.1) assert p._nconns == 0 @@ -953,7 +981,7 @@ async def test_reconnect_after_grow_failed(proxy): # in grow mode. See issue #370. proxy.stop() - ev = asyncio.Event() + ev = AEvent() def failed(pool): ev.set() @@ -961,14 +989,14 @@ async def test_reconnect_after_grow_failed(proxy): async with pool.AsyncConnectionPool( proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed ) as p: - await asyncio.wait_for(ev.wait(), 2.0) + await ev.wait_timeout(2.0) with pytest.raises(pool.PoolTimeout): async with p.connection(timeout=0.5) as conn: pass ev.clear() - await asyncio.wait_for(ev.wait(), 2.0) + await ev.wait_timeout(2.0) proxy.start() @@ -982,7 +1010,7 @@ async def test_reconnect_after_grow_failed(proxy): @pytest.mark.slow async def test_refill_on_check(proxy): proxy.start() - ev = asyncio.Event() + ev = AEvent() def failed(pool): ev.set() @@ -998,7 +1026,7 @@ async def test_refill_on_check(proxy): # Checking the pool will empty it await p.check() - await asyncio.wait_for(ev.wait(), 2.0) + await ev.wait_timeout(2.0) assert len(p._pool) == 0 # Allow to connect again @@ -1016,7 +1044,7 @@ async def test_uniform_use(dsn): counts = Counter[int]() for i in range(8): async with p.connection() as conn: - await asyncio.sleep(0.1) + await asleep(0.1) counts[id(conn)] += 1 assert len(counts) == 4 @@ -1027,9 +1055,9 @@ async def test_uniform_use(dsn): @pytest.mark.timing async def test_resize(dsn): async def sampler(): - await asyncio.sleep(0.05) # ensure sampling happens after shrink check + await asleep(0.05) # ensure sampling happens after shrink check while True: - await asyncio.sleep(0.2) + await asleep(0.2) if p.closed: break size.append(len(p._pool)) @@ -1041,25 +1069,25 @@ async def test_resize(dsn): size: List[int] = [] async with pool.AsyncConnectionPool(dsn, min_size=2, max_idle=0.2) as p: - s = create_task(sampler()) + s = spawn(sampler) - await asyncio.sleep(0.3) + await asleep(0.3) - c = create_task(client(0.4)) + c = spawn(client, args=(0.4,)) - await asyncio.sleep(0.2) + await asleep(0.2) await p.resize(4) assert p.min_size == 4 assert p.max_size == 4 - await asyncio.sleep(0.4) + await asleep(0.4) await p.resize(2) assert p.min_size == 2 assert p.max_size == 2 - await asyncio.sleep(0.6) + await asleep(0.6) - await asyncio.gather(s, c) + await gather(s, c) assert size == [2, 1, 3, 4, 3, 2, 2] @@ -1081,12 +1109,12 @@ async def test_jitter(): @pytest.mark.crdb_skip("backend pid") async def test_max_lifetime(dsn): async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p: - await asyncio.sleep(0.1) + await asleep(0.1) pids = [] for i in range(5): async with p.connection() as conn: pids.append(conn.info.backend_pid) - await asyncio.sleep(0.2) + await asleep(0.2) assert pids[0] == pids[1] != pids[4], pids @@ -1128,7 +1156,7 @@ async def test_check_max_lifetime(dsn): pid = conn.info.backend_pid async with p.connection() as conn: assert conn.info.backend_pid == pid - await asyncio.sleep(0.3) + await asleep(0.3) await p.check() async with p.connection() as conn: assert conn.info.backend_pid != pid @@ -1151,10 +1179,10 @@ async def test_stats_measures(dsn): assert stats["pool_available"] == 2 assert stats["requests_waiting"] == 0 - ts = [create_task(worker(i)) for i in range(3)] - await asyncio.sleep(0.1) + ts = [spawn(worker, args=(i,)) for i in range(3)] + await asleep(0.1) stats = p.get_stats() - await asyncio.gather(*ts) + await gather(*ts) assert stats["pool_min"] == 2 assert stats["pool_max"] == 4 assert stats["pool_size"] == 3 @@ -1162,10 +1190,10 @@ async def test_stats_measures(dsn): assert stats["requests_waiting"] == 0 await p.wait(2.0) - ts = [create_task(worker(i)) for i in range(7)] - await asyncio.sleep(0.1) + ts = [spawn(worker, args=(i,)) for i in range(7)] + await asleep(0.1) stats = p.get_stats() - await asyncio.gather(*ts) + await gather(*ts) assert stats["pool_min"] == 2 assert stats["pool_max"] == 4 assert stats["pool_size"] == 4 @@ -1186,8 +1214,8 @@ async def test_stats_usage(dsn): async with pool.AsyncConnectionPool(dsn, min_size=3) as p: await p.wait(2.0) - ts = [create_task(worker(i)) for i in range(7)] - await asyncio.gather(*ts) + ts = [spawn(worker, args=(i,)) for i in range(7)] + await gather(*ts) stats = p.get_stats() assert stats["requests_num"] == 7 assert stats["requests_queued"] == 4 @@ -1221,7 +1249,7 @@ async def test_stats_connect(dsn, proxy, monkeypatch): proxy.stop() await p.check() - await asyncio.sleep(0.1) + await asleep(0.1) stats = p.get_stats() assert stats["connections_num"] > 3 assert stats["connections_errors"] > 0 @@ -1236,13 +1264,13 @@ async def test_spike(dsn, monkeypatch): async def worker(): async with p.connection(): - await asyncio.sleep(0.002) + await asleep(0.002) async with pool.AsyncConnectionPool(dsn, min_size=5, max_size=10) as p: await p.wait() - ts = [create_task(worker()) for i in range(50)] - await asyncio.gather(*ts) + ts = [spawn(worker) for i in range(50)] + await gather(*ts) await p.wait() assert len(p._pool) < 7 @@ -1264,6 +1292,7 @@ async def test_debug_deadlock(dsn): logger.setLevel(old_level) +@pytest.mark.skipif(not is_async(__name__), reason="async test only") async def test_cancellation_in_queue(dsn): # https://github.com/psycopg/psycopg/issues/509 @@ -1273,7 +1302,7 @@ async def test_cancellation_in_queue(dsn): await p.wait() got_conns = [] - ev = asyncio.Event() + ev = AEvent() async def worker(i): try: @@ -1289,27 +1318,27 @@ async def test_cancellation_in_queue(dsn): if len(got_conns) >= nconns: ev.set() - await asyncio.sleep(5) + await asleep(5) except BaseException as ex: logging.info("worker %s stopped: %r", i, ex) raise # Start tasks taking up all the connections and getting in the queue - tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)] + tasks = [spawn(worker, (i,)) for i in range(nconns * 3)] # wait until the pool has served all the connections and clients are queued. - await asyncio.wait_for(ev.wait(), 3.0) + await ev.wait_timeout(3.0) for i in range(10): if p.get_stats().get("requests_queued", 0): break else: - await asyncio.sleep(0.1) + await asleep(0.1) else: pytest.fail("no client got in the queue") [task.cancel() for task in reversed(tasks)] - await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0) + await gather(*tasks, return_exceptions=True, timeout=1.0) stats = p.get_stats() assert stats["pool_available"] == 3 @@ -1329,7 +1358,7 @@ def delay_connection(monkeypatch, sec): t0 = time() rv = await connect_orig(*args, **kwargs) t1 = time() - await asyncio.sleep(max(0, sec - (t1 - t0))) + await asleep(max(0, sec - (t1 - t0))) return rv connect_orig = psycopg.AsyncConnection.connect @@ -1337,5 +1366,8 @@ def delay_connection(monkeypatch, sec): async def ensure_waiting(p, num=1): + """ + Wait until there are at least *num* clients waiting in the queue. + """ while len(p._waiting) < num: - await asyncio.sleep(0) + await asleep(0) diff --git a/tests/utils.py b/tests/utils.py index 5b9b73c1a..060388672 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,11 +4,14 @@ import sys import asyncio import inspect import operator -from time import sleep as sleep # noqa: F401 -- re-export -from typing import Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple from threading import Thread from contextlib import contextmanager, asynccontextmanager -from contextlib import closing as closing # noqa: F401 -- re-export + +# Re-exports +from time import sleep as sleep # noqa: F401 +from threading import Event as Event # noqa: F401 +from contextlib import closing as closing # noqa: F401 import pytest @@ -232,27 +235,34 @@ def raiseif(cond, *args, **kwargs): return -def spawn(f): +def spawn(f, args=None): """ Equivalent to asyncio.create_task or creating and running a Thread. """ + if not args: + args = () + if inspect.iscoroutinefunction(f): - return asyncio.create_task(f()) + return asyncio.create_task(f(*args)) else: - t = Thread(target=f, daemon=True) + t = Thread(target=f, args=args, daemon=True) t.start() return t -def gather(*ts): +def gather(*ts, return_exceptions=False, timeout=None): """ Equivalent to asyncio.gather or Thread.join() """ if ts and inspect.isawaitable(ts[0]): - return asyncio.gather(*ts) + rv: Any = asyncio.gather(*ts, return_exceptions=return_exceptions) + if timeout is None: + rv = asyncio.wait_for(rv, timeout) + return rv else: for t in ts: - t.join() + t.join(timeout) + assert not t.is_alive() def asleep(s): @@ -260,3 +270,21 @@ def asleep(s): Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync. """ return asyncio.sleep(s) + + +def is_alive(t): + """ + Return true if an asyncio.Task or threading.Thread is alive. + """ + return t.is_alive() if isinstance(t, Thread) else not t.done() + + +class AEvent(asyncio.Event): + """ + Subclass of asyncio.Event adding a wait with timeout like threading.Event. + + wait_timeout() is converted to wait() by async_to_sync. + """ + + async def wait_timeout(self, timeout): + await asyncio.wait_for(self.wait(), timeout) diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 40ed2ccc6..4183d7222 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -152,8 +152,10 @@ class AsyncToSync(ast.NodeTransformer): class RenameAsyncToSync(ast.NodeTransformer): names_map = { + "AEvent": "Event", "AsyncClientCursor": "ClientCursor", "AsyncConnection": "Connection", + "AsyncConnectionPool": "ConnectionPool", "AsyncCopy": "Copy", "AsyncCopyWriter": "CopyWriter", "AsyncCursor": "Cursor", @@ -179,16 +181,18 @@ class RenameAsyncToSync(ast.NodeTransformer): "aconn_cls": "conn_cls", "alist": "list", "anext": "next", - "asleep": "sleep", "apipeline": "pipeline", + "asleep": "sleep", "asynccontextmanager": "contextmanager", "connection_async": "connection", "cursor_async": "cursor", "ensure_table_async": "ensure_table", "find_insert_problem_async": "find_insert_problem", + "psycopg_pool.pool_async": "psycopg_pool.pool", "psycopg_pool.sched_async": "psycopg_pool.sched", "wait_async": "wait", "wait_conn_async": "wait_conn", + "wait_timeout": "wait", } _skip_imports = { "utils": {"alist", "anext"}, diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh index 983d9c8e4..8784817fa 100755 --- a/tools/convert_async_to_sync.sh +++ b/tools/convert_async_to_sync.sh @@ -21,6 +21,7 @@ for async in \ psycopg/psycopg/connection_async.py \ psycopg/psycopg/cursor_async.py \ psycopg_pool/psycopg_pool/sched_async.py \ + tests/pool/test_pool_async.py \ tests/pool/test_sched_async.py \ tests/test_client_cursor_async.py \ tests/test_connection_async.py \