]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests): generate null pool tests from async counterpart
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Sep 2023 05:08:21 +0000 (07:08 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
tests/pool/test_null_pool.py [deleted file]
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pool/test_pool_null.py [new file with mode: 0644]
tests/pool/test_pool_null_async.py [moved from tests/pool/test_null_pool_async.py with 91% similarity]
tools/async_to_sync.py
tools/convert_async_to_sync.sh

diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py
deleted file mode 100644 (file)
index a1b2715..0000000
+++ /dev/null
@@ -1,952 +0,0 @@
-import logging
-from time import sleep, time
-from threading import Thread, Event
-from typing import Any, Dict, List, Tuple
-
-import pytest
-from packaging.version import parse as ver  # noqa: F401  # used in skipif
-
-import psycopg
-from psycopg.pq import TransactionStatus
-from psycopg.rows import class_row, Row, TupleRow
-from psycopg._compat import assert_type
-
-from .test_pool import delay_connection, ensure_waiting
-
-try:
-    from psycopg_pool import NullConnectionPool
-    from psycopg_pool import PoolClosed, PoolTimeout, TooManyRequests
-except ImportError:
-    pass
-
-
-def test_defaults(dsn):
-    with NullConnectionPool(dsn) as p:
-        assert p.min_size == p.max_size == 0
-        assert p.timeout == 30
-        assert p.max_idle == 10 * 60
-        assert p.max_lifetime == 60 * 60
-        assert p.num_workers == 3
-
-
-def test_min_size_max_size(dsn):
-    with NullConnectionPool(dsn, min_size=0, max_size=2) as p:
-        assert p.min_size == 0
-        assert p.max_size == 2
-
-
-@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
-def test_bad_size(dsn, min_size, max_size):
-    with pytest.raises(ValueError):
-        NullConnectionPool(min_size=min_size, max_size=max_size)
-
-
-def test_connection_class(dsn):
-    class MyConn(psycopg.Connection[Any]):
-        pass
-
-    with NullConnectionPool(dsn, connection_class=MyConn) as p:
-        with p.connection() as conn:
-            assert isinstance(conn, MyConn)
-
-
-def test_kwargs(dsn):
-    with NullConnectionPool(dsn, kwargs={"autocommit": True}) as p:
-        with p.connection() as conn:
-            assert conn.autocommit
-
-
-class MyRow(Dict[str, Any]):
-    ...
-
-
-def test_generic_connection_type(dsn):
-    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
-        conn.autocommit = True
-
-    class MyConnection(psycopg.Connection[Row]):
-        pass
-
-    with NullConnectionPool(
-        dsn,
-        connection_class=MyConnection[MyRow],
-        kwargs={"row_factory": class_row(MyRow)},
-        configure=set_autocommit,
-    ) as p1:
-        with p1.connection() as conn1:
-            cur1 = conn1.execute("select 1 as x")
-            (row1,) = cur1.fetchall()
-
-    assert_type(p1, NullConnectionPool[MyConnection[MyRow]])
-    assert_type(conn1, MyConnection[MyRow])
-    assert_type(row1, MyRow)
-    assert conn1.autocommit
-    assert row1 == {"x": 1}
-
-    with NullConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
-        with p2.connection() as conn2:
-            (row2,) = conn2.execute("select 2 as y").fetchall()
-    assert_type(p2, NullConnectionPool[MyConnection[TupleRow]])
-    assert_type(conn2, MyConnection[TupleRow])
-    assert_type(row2, TupleRow)
-    assert row2 == (2,)
-
-
-def test_non_generic_connection_type(dsn):
-    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
-        conn.autocommit = True
-
-    class MyConnection(psycopg.Connection[MyRow]):
-        def __init__(self, *args: Any, **kwargs: Any):
-            kwargs["row_factory"] = class_row(MyRow)
-            super().__init__(*args, **kwargs)
-
-    with NullConnectionPool(
-        dsn, connection_class=MyConnection, configure=set_autocommit
-    ) as p1:
-        with p1.connection() as conn1:
-            (row1,) = conn1.execute("select 1 as x").fetchall()
-    assert_type(p1, NullConnectionPool[MyConnection])
-    assert_type(conn1, MyConnection)
-    assert_type(row1, MyRow)
-    assert conn1.autocommit
-    assert row1 == {"x": 1}
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_its_no_pool_at_all(dsn):
-    with NullConnectionPool(dsn, max_size=2) as p:
-        with p.connection() as conn:
-            pid1 = conn.info.backend_pid
-
-            with p.connection() as conn2:
-                pid2 = conn2.info.backend_pid
-
-        with p.connection() as conn:
-            assert conn.info.backend_pid not in (pid1, pid2)
-
-
-def test_context(dsn):
-    with NullConnectionPool(dsn) as p:
-        assert not p.closed
-    assert p.closed
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-def test_wait_ready(dsn, monkeypatch):
-    delay_connection(monkeypatch, 0.2)
-    with pytest.raises(PoolTimeout):
-        with NullConnectionPool(dsn, num_workers=1) as p:
-            p.wait(0.1)
-
-    with NullConnectionPool(dsn, num_workers=1) as p:
-        p.wait(0.4)
-
-
-def test_wait_closed(dsn):
-    with NullConnectionPool(dsn) as p:
-        pass
-
-    with pytest.raises(PoolClosed):
-        p.wait()
-
-
-@pytest.mark.slow
-def test_setup_no_timeout(dsn, proxy):
-    with pytest.raises(PoolTimeout):
-        with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
-            p.wait(0.2)
-
-    with NullConnectionPool(proxy.client_dsn, num_workers=1) as p:
-        sleep(0.5)
-        assert not p._pool
-        proxy.start()
-
-        with p.connection() as conn:
-            conn.execute("select 1")
-
-
-def test_configure(dsn):
-    inits = 0
-
-    def configure(conn):
-        nonlocal inits
-        inits += 1
-        with conn.transaction():
-            conn.execute("set default_transaction_read_only to on")
-
-    with NullConnectionPool(dsn, configure=configure) as p:
-        with p.connection() as conn:
-            assert inits == 1
-            res = conn.execute("show default_transaction_read_only")
-            assert res.fetchone()[0] == "on"  # type: ignore[index]
-
-        with p.connection() as conn:
-            assert inits == 2
-            res = conn.execute("show default_transaction_read_only")
-            assert res.fetchone()[0] == "on"  # type: ignore[index]
-            conn.close()
-
-        with p.connection() as conn:
-            assert inits == 3
-            res = conn.execute("show default_transaction_read_only")
-            assert res.fetchone()[0] == "on"  # type: ignore[index]
-
-
-@pytest.mark.slow
-def test_configure_badstate(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-
-    def configure(conn):
-        conn.execute("select 1")
-
-    with NullConnectionPool(dsn, configure=configure) as p:
-        with pytest.raises(PoolTimeout):
-            p.wait(timeout=0.5)
-
-    assert caplog.records
-    assert "INTRANS" in caplog.records[0].message
-
-
-@pytest.mark.slow
-def test_configure_broken(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-
-    def configure(conn):
-        with conn.transaction():
-            conn.execute("WAT")
-
-    with NullConnectionPool(dsn, configure=configure) as p:
-        with pytest.raises(PoolTimeout):
-            p.wait(timeout=0.5)
-
-    assert caplog.records
-    assert "WAT" in caplog.records[0].message
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_reset(dsn):
-    resets = 0
-
-    def setup(conn):
-        with conn.transaction():
-            conn.execute("set timezone to '+1:00'")
-
-    def reset(conn):
-        nonlocal resets
-        resets += 1
-        with conn.transaction():
-            conn.execute("set timezone to utc")
-
-    pids = []
-
-    def worker():
-        with p.connection() as conn:
-            assert resets == 1
-            with conn.execute("show timezone") as cur:
-                assert cur.fetchone() == ("UTC",)
-            pids.append(conn.info.backend_pid)
-
-    with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
-        with p.connection() as conn:
-            # Queue the worker so it will take the same connection a second time
-            # instead of making a new one.
-            t = Thread(target=worker)
-            t.start()
-            ensure_waiting(p)
-
-            assert resets == 0
-            conn.execute("set timezone to '+2:00'")
-            pids.append(conn.info.backend_pid)
-
-        t.join()
-        p.wait()
-
-    assert resets == 1
-    assert pids[0] == pids[1]
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_reset_badstate(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-
-    def reset(conn):
-        conn.execute("reset all")
-
-    pids = []
-
-    def worker():
-        with p.connection() as conn:
-            conn.execute("select 1")
-            pids.append(conn.info.backend_pid)
-
-    with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
-        with p.connection() as conn:
-            t = Thread(target=worker)
-            t.start()
-            ensure_waiting(p)
-
-            conn.execute("select 1")
-            pids.append(conn.info.backend_pid)
-
-        t.join()
-
-    assert pids[0] != pids[1]
-    assert caplog.records
-    assert "INTRANS" in caplog.records[0].message
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_reset_broken(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-
-    def reset(conn):
-        with conn.transaction():
-            conn.execute("WAT")
-
-    pids = []
-
-    def worker():
-        with p.connection() as conn:
-            conn.execute("select 1")
-            pids.append(conn.info.backend_pid)
-
-    with NullConnectionPool(dsn, max_size=1, reset=reset) as p:
-        with p.connection() as conn:
-            t = Thread(target=worker)
-            t.start()
-            ensure_waiting(p)
-
-            conn.execute("select 1")
-            pids.append(conn.info.backend_pid)
-
-        t.join()
-
-    assert pids[0] != pids[1]
-    assert caplog.records
-    assert "WAT" in caplog.records[0].message
-
-
-@pytest.mark.slow
-@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
-def test_no_queue_timeout(deaf_port):
-    with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p:
-        with pytest.raises(PoolTimeout):
-            with p.connection(timeout=1):
-                pass
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("backend pid")
-def test_queue(dsn):
-    def worker(n):
-        t0 = time()
-        with p.connection() as conn:
-            conn.execute("select pg_sleep(0.2)")
-            pid = conn.info.backend_pid
-
-        t1 = time()
-        results.append((n, t1 - t0, pid))
-
-    results: List[Tuple[int, float, int]] = []
-    with NullConnectionPool(dsn, max_size=2) as p:
-        p.wait()
-        ts = [Thread(target=worker, args=(i,)) for i in range(6)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-
-    times = [item[1] for item in results]
-    want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
-    for got, want in zip(times, want_times):
-        assert got == pytest.approx(want, 0.2), times
-
-    assert len(set(r[2] for r in results)) == 2, results
-
-
-@pytest.mark.slow
-def test_queue_size(dsn):
-    def worker(t, ev=None):
-        try:
-            with p.connection():
-                if ev:
-                    ev.set()
-                sleep(t)
-        except TooManyRequests as e:
-            errors.append(e)
-        else:
-            success.append(True)
-
-    errors: List[Exception] = []
-    success: List[bool] = []
-
-    with NullConnectionPool(dsn, max_size=1, max_waiting=3) as p:
-        p.wait()
-        ev = Event()
-        t = Thread(target=worker, args=(0.3, ev))
-        t.start()
-        ev.wait()
-
-        ts = [Thread(target=worker, args=(0.1,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-
-    assert len(success) == 4
-    assert len(errors) == 1
-    assert isinstance(errors[0], TooManyRequests)
-    assert p.name in str(errors[0])
-    assert str(p.max_waiting) in str(errors[0])
-    assert p.get_stats()["requests_errors"] == 1
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("backend pid")
-def test_queue_timeout(dsn):
-    def worker(n):
-        t0 = time()
-        try:
-            with p.connection() as conn:
-                conn.execute("select pg_sleep(0.2)")
-                pid = conn.info.backend_pid
-        except PoolTimeout as e:
-            t1 = time()
-            errors.append((n, t1 - t0, e))
-        else:
-            t1 = time()
-            results.append((n, t1 - t0, pid))
-
-    results: List[Tuple[int, float, int]] = []
-    errors: List[Tuple[int, float, Exception]] = []
-
-    with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
-        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-
-    assert len(results) == 2
-    assert len(errors) == 2
-    for e in errors:
-        assert 0.1 < e[1] < 0.15
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-def test_dead_client(dsn):
-    def worker(i, timeout):
-        try:
-            with p.connection(timeout=timeout) as conn:
-                conn.execute("select pg_sleep(0.3)")
-                results.append(i)
-        except PoolTimeout:
-            if timeout > 0.2:
-                raise
-
-    results: List[int] = []
-
-    with NullConnectionPool(dsn, max_size=2) as p:
-        ts = [
-            Thread(target=worker, args=(i, timeout))
-            for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
-        ]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-        sleep(0.2)
-        assert set(results) == set([0, 1, 3, 4])
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("backend pid")
-def test_queue_timeout_override(dsn):
-    def worker(n):
-        t0 = time()
-        timeout = 0.25 if n == 3 else None
-        try:
-            with p.connection(timeout=timeout) as conn:
-                conn.execute("select pg_sleep(0.2)")
-                pid = conn.info.backend_pid
-        except PoolTimeout as e:
-            t1 = time()
-            errors.append((n, t1 - t0, e))
-        else:
-            t1 = time()
-            results.append((n, t1 - t0, pid))
-
-    results: List[Tuple[int, float, int]] = []
-    errors: List[Tuple[int, float, Exception]] = []
-
-    with NullConnectionPool(dsn, max_size=2, timeout=0.1) as p:
-        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-
-    assert len(results) == 3
-    assert len(errors) == 1
-    for e in errors:
-        assert 0.1 < e[1] < 0.15
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_broken_reconnect(dsn):
-    with NullConnectionPool(dsn, max_size=1) as p:
-        with p.connection() as conn:
-            pid1 = conn.info.backend_pid
-            conn.close()
-
-        with p.connection() as conn2:
-            pid2 = conn2.info.backend_pid
-
-    assert pid1 != pid2
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_intrans_rollback(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-    pids = []
-
-    def worker():
-        with p.connection() as conn:
-            pids.append(conn.info.backend_pid)
-            assert conn.info.transaction_status == TransactionStatus.IDLE
-            assert not conn.execute(
-                "select 1 from pg_class where relname = 'test_intrans_rollback'"
-            ).fetchone()
-
-    with NullConnectionPool(dsn, max_size=1) as p:
-        conn = p.getconn()
-
-        # Queue the worker so it will take the connection a second time instead
-        # of making a new one.
-        t = Thread(target=worker)
-        t.start()
-        ensure_waiting(p)
-
-        pids.append(conn.info.backend_pid)
-        conn.execute("create table test_intrans_rollback ()")
-        assert conn.info.transaction_status == TransactionStatus.INTRANS
-        p.putconn(conn)
-        t.join()
-
-    assert pids[0] == pids[1]
-    assert len(caplog.records) == 1
-    assert "INTRANS" in caplog.records[0].message
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_inerror_rollback(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-    pids = []
-
-    def worker():
-        with p.connection() as conn:
-            pids.append(conn.info.backend_pid)
-            assert conn.info.transaction_status == TransactionStatus.IDLE
-
-    with NullConnectionPool(dsn, max_size=1) as p:
-        conn = p.getconn()
-
-        # Queue the worker so it will take the connection a second time instead
-        # of making a new one.
-        t = Thread(target=worker)
-        t.start()
-        ensure_waiting(p)
-
-        pids.append(conn.info.backend_pid)
-        with pytest.raises(psycopg.ProgrammingError):
-            conn.execute("wat")
-        assert conn.info.transaction_status == TransactionStatus.INERROR
-        p.putconn(conn)
-        t.join()
-
-    assert pids[0] == pids[1]
-    assert len(caplog.records) == 1
-    assert "INERROR" in caplog.records[0].message
-
-
-@pytest.mark.crdb_skip("backend pid")
-@pytest.mark.crdb_skip("copy")
-def test_active_close(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-    pids = []
-
-    def worker():
-        with p.connection() as conn:
-            pids.append(conn.info.backend_pid)
-            assert conn.info.transaction_status == TransactionStatus.IDLE
-
-    with NullConnectionPool(dsn, max_size=1) as p:
-        conn = p.getconn()
-
-        t = Thread(target=worker)
-        t.start()
-        ensure_waiting(p)
-
-        pids.append(conn.info.backend_pid)
-        conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
-        assert conn.info.transaction_status == TransactionStatus.ACTIVE
-        p.putconn(conn)
-        t.join()
-
-    assert pids[0] != pids[1]
-    assert len(caplog.records) == 2
-    assert "ACTIVE" in caplog.records[0].message
-    assert "BAD" in caplog.records[1].message
-
-
-@pytest.mark.crdb_skip("backend pid")
-def test_fail_rollback_close(dsn, caplog, monkeypatch):
-    caplog.set_level(logging.WARNING, logger="psycopg.pool")
-    pids = []
-
-    def worker(p):
-        with p.connection() as conn:
-            pids.append(conn.info.backend_pid)
-            assert conn.info.transaction_status == TransactionStatus.IDLE
-
-    with NullConnectionPool(dsn, max_size=1) as p:
-        conn = p.getconn()
-
-        def bad_rollback():
-            conn.pgconn.finish()
-            orig_rollback()
-
-        # Make the rollback fail
-        orig_rollback = conn.rollback
-        monkeypatch.setattr(conn, "rollback", bad_rollback)
-
-        t = Thread(target=worker, args=(p,))
-        t.start()
-        ensure_waiting(p)
-
-        pids.append(conn.info.backend_pid)
-        with pytest.raises(psycopg.ProgrammingError):
-            conn.execute("wat")
-        assert conn.info.transaction_status == TransactionStatus.INERROR
-        p.putconn(conn)
-        t.join()
-
-    assert pids[0] != pids[1]
-    assert len(caplog.records) == 3
-    assert "INERROR" in caplog.records[0].message
-    assert "OperationalError" in caplog.records[1].message
-    assert "BAD" in caplog.records[2].message
-
-
-def test_close_no_threads(dsn):
-    p = NullConnectionPool(dsn)
-    assert p._sched_runner and p._sched_runner.is_alive()
-    workers = p._workers[:]
-    assert workers
-    for t in workers:
-        assert t.is_alive()
-
-    p.close()
-    assert p._sched_runner is None
-    assert not p._workers
-    for t in workers:
-        assert not t.is_alive()
-
-
-def test_putconn_no_pool(conn_cls, dsn):
-    with NullConnectionPool(dsn) as p:
-        conn = conn_cls.connect(dsn)
-        with pytest.raises(ValueError):
-            p.putconn(conn)
-
-    conn.close()
-
-
-def test_putconn_wrong_pool(dsn):
-    with NullConnectionPool(dsn) as p1:
-        with NullConnectionPool(dsn) as p2:
-            conn = p1.getconn()
-            with pytest.raises(ValueError):
-                p2.putconn(conn)
-
-
-@pytest.mark.slow
-def test_del_stop_threads(dsn):
-    p = NullConnectionPool(dsn)
-    assert p._sched_runner is not None
-    ts = [p._sched_runner] + p._workers
-    del p
-    sleep(0.1)
-    for t in ts:
-        assert not t.is_alive()
-
-
-def test_closed_getconn(dsn):
-    p = NullConnectionPool(dsn)
-    assert not p.closed
-    with p.connection():
-        pass
-
-    p.close()
-    assert p.closed
-
-    with pytest.raises(PoolClosed):
-        with p.connection():
-            pass
-
-
-def test_closed_putconn(dsn):
-    p = NullConnectionPool(dsn)
-
-    with p.connection() as conn:
-        pass
-    assert conn.closed
-
-    with p.connection() as conn:
-        p.close()
-    assert conn.closed
-
-
-def test_closed_queue(dsn):
-    def w1():
-        with p.connection() as conn:
-            e1.set()  # Tell w0 that w1 got a connection
-            cur = conn.execute("select 1")
-            assert cur.fetchone() == (1,)
-            e2.wait()  # Wait until w0 has tested w2
-        success.append("w1")
-
-    def w2():
-        try:
-            with p.connection():
-                pass  # unexpected
-        except PoolClosed:
-            success.append("w2")
-
-    e1 = Event()
-    e2 = Event()
-
-    p = NullConnectionPool(dsn, max_size=1)
-    p.wait()
-    success: List[str] = []
-
-    t1 = Thread(target=w1)
-    t1.start()
-    # Wait until w1 has received a connection
-    e1.wait()
-
-    t2 = Thread(target=w2)
-    t2.start()
-    # Wait until w2 is in the queue
-    ensure_waiting(p)
-
-    p.close(0)
-
-    # Wait for the workers to finish
-    e2.set()
-    t1.join()
-    t2.join()
-    assert len(success) == 2
-
-
-def test_open_explicit(dsn):
-    p = NullConnectionPool(dsn, open=False)
-    assert p.closed
-    with pytest.raises(PoolClosed, match="is not open yet"):
-        p.getconn()
-
-    with pytest.raises(PoolClosed):
-        with p.connection():
-            pass
-
-    p.open()
-    try:
-        assert not p.closed
-
-        with p.connection() as conn:
-            cur = conn.execute("select 1")
-            assert cur.fetchone() == (1,)
-
-    finally:
-        p.close()
-
-    with pytest.raises(PoolClosed, match="is already closed"):
-        p.getconn()
-
-
-def test_open_context(dsn):
-    p = NullConnectionPool(dsn, open=False)
-    assert p.closed
-
-    with p:
-        assert not p.closed
-
-        with p.connection() as conn:
-            cur = conn.execute("select 1")
-            assert cur.fetchone() == (1,)
-
-    assert p.closed
-
-
-def test_open_no_op(dsn):
-    p = NullConnectionPool(dsn)
-    try:
-        assert not p.closed
-        p.open()
-        assert not p.closed
-
-        with p.connection() as conn:
-            cur = conn.execute("select 1")
-            assert cur.fetchone() == (1,)
-
-    finally:
-        p.close()
-
-
-def test_reopen(dsn):
-    p = NullConnectionPool(dsn)
-    with p.connection() as conn:
-        conn.execute("select 1")
-    p.close()
-    assert p._sched_runner is None
-    assert not p._workers
-
-    with pytest.raises(psycopg.OperationalError, match="cannot be reused"):
-        p.open()
-
-
-@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
-def test_bad_resize(dsn, min_size, max_size):
-    with NullConnectionPool() as p:
-        with pytest.raises(ValueError):
-            p.resize(min_size=min_size, max_size=max_size)
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("backend pid")
-def test_max_lifetime(dsn):
-    pids = []
-
-    def worker(p):
-        with p.connection() as conn:
-            pids.append(conn.info.backend_pid)
-            sleep(0.1)
-
-    ts = []
-    with NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
-        for i in range(5):
-            ts.append(Thread(target=worker, args=(p,)))
-            ts[-1].start()
-
-        for t in ts:
-            t.join()
-
-    assert pids[0] == pids[1] != pids[4], pids
-
-
-def test_check(dsn):
-    with NullConnectionPool(dsn) as p:
-        # No-op
-        p.check()
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-def test_stats_measures(dsn):
-    def worker(n):
-        with p.connection() as conn:
-            conn.execute("select pg_sleep(0.2)")
-
-    with NullConnectionPool(dsn, max_size=4) as p:
-        p.wait(2.0)
-
-        stats = p.get_stats()
-        assert stats["pool_min"] == 0
-        assert stats["pool_max"] == 4
-        assert stats["pool_size"] == 0
-        assert stats["pool_available"] == 0
-        assert stats["requests_waiting"] == 0
-
-        ts = [Thread(target=worker, args=(i,)) for i in range(3)]
-        for t in ts:
-            t.start()
-        sleep(0.1)
-        stats = p.get_stats()
-        for t in ts:
-            t.join()
-        assert stats["pool_min"] == 0
-        assert stats["pool_max"] == 4
-        assert stats["pool_size"] == 3
-        assert stats["pool_available"] == 0
-        assert stats["requests_waiting"] == 0
-
-        p.wait(2.0)
-        ts = [Thread(target=worker, args=(i,)) for i in range(7)]
-        for t in ts:
-            t.start()
-        sleep(0.1)
-        stats = p.get_stats()
-        for t in ts:
-            t.join()
-        assert stats["pool_min"] == 0
-        assert stats["pool_max"] == 4
-        assert stats["pool_size"] == 4
-        assert stats["pool_available"] == 0
-        assert stats["requests_waiting"] == 3
-
-
-@pytest.mark.slow
-@pytest.mark.timing
-def test_stats_usage(dsn):
-    def worker(n):
-        try:
-            with p.connection(timeout=0.3) as conn:
-                conn.execute("select pg_sleep(0.2)")
-        except PoolTimeout:
-            pass
-
-    with NullConnectionPool(dsn, max_size=3) as p:
-        p.wait(2.0)
-
-        ts = [Thread(target=worker, args=(i,)) for i in range(7)]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-        stats = p.get_stats()
-        assert stats["requests_num"] == 7
-        assert stats["requests_queued"] == 4
-        assert 850 <= stats["requests_wait_ms"] <= 950
-        assert stats["requests_errors"] == 1
-        assert 1150 <= stats["usage_ms"] <= 1250
-        assert stats.get("returns_bad", 0) == 0
-
-        with p.connection() as conn:
-            conn.close()
-        p.wait()
-        stats = p.pop_stats()
-        assert stats["requests_num"] == 8
-        assert stats["returns_bad"] == 1
-        with p.connection():
-            pass
-        assert p.get_stats()["requests_num"] == 1
-
-
-@pytest.mark.slow
-def test_stats_connect(dsn, proxy, monkeypatch):
-    proxy.start()
-    delay_connection(monkeypatch, 0.2)
-    with NullConnectionPool(proxy.client_dsn, max_size=3) as p:
-        p.wait()
-        stats = p.get_stats()
-        assert stats["connections_num"] == 1
-        assert stats.get("connections_errors", 0) == 0
-        assert stats.get("connections_lost", 0) == 0
-        assert 200 <= stats["connections_ms"] < 300
index fc353069ec394ef44b4edf2567fbb750c58ac9e7..10b9616b701944e8a93fd7f7a4dd5fc9ae5ef0ac 100644 (file)
@@ -14,6 +14,7 @@ from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type, Counter
 
 from ..utils import Event, spawn, gather, sleep, is_async
+from .test_pool_common import delay_connection
 
 try:
     import psycopg_pool as pool
@@ -851,27 +852,3 @@ def test_cancellation_in_queue(dsn):
         with p.connection() as conn:
             cur = conn.execute("select 1")
             assert cur.fetchone() == (1,)
-
-
-def delay_connection(monkeypatch, sec):
-    """
-    Return a _connect_gen function delayed by the amount of seconds
-    """
-
-    def connect_delay(*args, **kwargs):
-        t0 = time()
-        rv = connect_orig(*args, **kwargs)
-        t1 = time()
-        sleep(max(0, sec - (t1 - t0)))
-        return rv
-
-    connect_orig = psycopg.Connection.connect
-    monkeypatch.setattr(psycopg.Connection, "connect", connect_delay)
-
-
-def ensure_waiting(p, num=1):
-    """
-    Wait until there are at least *num* clients waiting in the queue.
-    """
-    while len(p._waiting) < num:
-        sleep(0)
index ef6db8d2286143fc27f196b8e3a7321baf73df9c..734a0ad1f8f642236e37f8f1579bc68587ede3ca 100644 (file)
@@ -11,6 +11,7 @@ from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type, Counter
 
 from ..utils import AEvent, spawn, gather, asleep, is_async
+from .test_pool_common_async import delay_connection
 
 try:
     import psycopg_pool as pool
@@ -858,27 +859,3 @@ async def test_cancellation_in_queue(dsn):
         async with p.connection() as conn:
             cur = await conn.execute("select 1")
             assert await cur.fetchone() == (1,)
-
-
-def delay_connection(monkeypatch, sec):
-    """
-    Return a _connect_gen function delayed by the amount of seconds
-    """
-
-    async def connect_delay(*args, **kwargs):
-        t0 = time()
-        rv = await connect_orig(*args, **kwargs)
-        t1 = time()
-        await asleep(max(0, sec - (t1 - t0)))
-        return rv
-
-    connect_orig = psycopg.AsyncConnection.connect
-    monkeypatch.setattr(psycopg.AsyncConnection, "connect", connect_delay)
-
-
-async def ensure_waiting(p, num=1):
-    """
-    Wait until there are at least *num* clients waiting in the queue.
-    """
-    while len(p._waiting) < num:
-        await asleep(0)
diff --git a/tests/pool/test_pool_null.py b/tests/pool/test_pool_null.py
new file mode 100644 (file)
index 0000000..7fdd055
--- /dev/null
@@ -0,0 +1,492 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_pool_null_async.py'
+# DO NOT CHANGE! Change the original file instead.
+import logging
+from typing import Any, Dict, List
+
+import pytest
+from packaging.version import parse as ver  # noqa: F401  # used in skipif
+
+import psycopg
+from psycopg.pq import TransactionStatus
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type
+
+from ..utils import Event, sleep, spawn, gather, is_async
+from .test_pool_common import delay_connection, ensure_waiting
+
+try:
+    import psycopg_pool as pool
+except ImportError:
+    # Tests should have been skipped if the package is not available
+    pass
+
+
+def test_default_sizes(dsn):
+    with pool.NullConnectionPool(dsn) as p:
+        assert p.min_size == p.max_size == 0
+
+
+def test_min_size_max_size(dsn):
+    with pool.NullConnectionPool(dsn, min_size=0, max_size=2) as p:
+        assert p.min_size == 0
+        assert p.max_size == 2
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+def test_bad_size(dsn, min_size, max_size):
+    with pytest.raises(ValueError):
+        pool.NullConnectionPool(min_size=min_size, max_size=max_size)
+
+
+class MyRow(Dict[str, Any]):
+    ...
+
+
+def test_generic_connection_type(dsn):
+    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+        conn.set_autocommit(True)
+
+    class MyConnection(psycopg.Connection[Row]):
+        pass
+
+    with pool.NullConnectionPool(
+        dsn,
+        connection_class=MyConnection[MyRow],
+        kwargs={"row_factory": class_row(MyRow)},
+        configure=set_autocommit,
+    ) as p1:
+        with p1.connection() as conn1:
+            cur1 = conn1.execute("select 1 as x")
+            (row1,) = cur1.fetchall()
+    assert_type(p1, pool.NullConnectionPool[MyConnection[MyRow]])
+    assert_type(conn1, MyConnection[MyRow])
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+    with pool.NullConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
+        with p2.connection() as conn2:
+            cur2 = conn2.execute("select 2 as y")
+            (row2,) = cur2.fetchall()
+    assert_type(p2, pool.NullConnectionPool[MyConnection[TupleRow]])
+    assert_type(conn2, MyConnection[TupleRow])
+    assert_type(row2, TupleRow)
+    assert row2 == (2,)
+
+
+def test_non_generic_connection_type(dsn):
+    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+        conn.set_autocommit(True)
+
+    class MyConnection(psycopg.Connection[MyRow]):
+        def __init__(self, *args: Any, **kwargs: Any):
+            kwargs["row_factory"] = class_row(MyRow)
+            super().__init__(*args, **kwargs)
+
+    with pool.NullConnectionPool(
+        dsn, connection_class=MyConnection, configure=set_autocommit
+    ) as p1:
+        with p1.connection() as conn1:
+            (row1,) = conn1.execute("select 1 as x").fetchall()
+    assert_type(p1, pool.NullConnectionPool[MyConnection])
+    assert_type(conn1, MyConnection)
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_its_no_pool_at_all(dsn):
+    with pool.NullConnectionPool(dsn, max_size=2) as p:
+        with p.connection() as conn:
+            pid1 = conn.info.backend_pid
+
+            with p.connection() as conn2:
+                pid2 = conn2.info.backend_pid
+
+        with p.connection() as conn:
+            assert conn.info.backend_pid not in (pid1, pid2)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_wait_ready(dsn, monkeypatch):
+    delay_connection(monkeypatch, 0.2)
+    with pytest.raises(pool.PoolTimeout):
+        with pool.NullConnectionPool(dsn, num_workers=1) as p:
+            p.wait(0.1)
+
+    with pool.NullConnectionPool(dsn, num_workers=1) as p:
+        p.wait(0.4)
+
+
+def test_configure(dsn):
+    inits = 0
+
+    def configure(conn):
+        nonlocal inits
+        inits += 1
+        with conn.transaction():
+            conn.execute("set default_transaction_read_only to on")
+
+    with pool.NullConnectionPool(dsn, configure=configure) as p:
+        with p.connection() as conn:
+            assert inits == 1
+            res = conn.execute("show default_transaction_read_only")
+            assert res.fetchone() == ("on",)
+
+        with p.connection() as conn:
+            assert inits == 2
+            res = conn.execute("show default_transaction_read_only")
+            assert res.fetchone() == ("on",)
+            conn.close()
+
+        with p.connection() as conn:
+            assert inits == 3
+            res = conn.execute("show default_transaction_read_only")
+            assert res.fetchone() == ("on",)
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset(dsn):
+    resets = 0
+
+    def setup(conn):
+        with conn.transaction():
+            conn.execute("set timezone to '+1:00'")
+
+    def reset(conn):
+        nonlocal resets
+        resets += 1
+        with conn.transaction():
+            conn.execute("set timezone to utc")
+
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            assert resets == 1
+            cur = conn.execute("show timezone")
+            assert cur.fetchone() == ("UTC",)
+            pids.append(conn.info.backend_pid)
+
+    with pool.NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        with p.connection() as conn:
+            # Queue the worker so it will take the same connection a second time
+            # instead of making a new one.
+            t = spawn(worker)
+            ensure_waiting(p)
+
+            assert resets == 0
+            conn.execute("set timezone to '+2:00'")
+            pids.append(conn.info.backend_pid)
+
+        gather(t)
+        p.wait()
+
+    assert resets == 1
+    assert pids[0] == pids[1]
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    def reset(conn):
+        conn.execute("reset all")
+
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pids.append(conn.info.backend_pid)
+
+    with pool.NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        with p.connection() as conn:
+            t = spawn(worker)
+            ensure_waiting(p)
+
+            conn.execute("select 1")
+            pids.append(conn.info.backend_pid)
+
+        gather(t)
+
+    assert pids[0] != pids[1]
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_reset_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+
+    def reset(conn):
+        with conn.transaction():
+            conn.execute("WAT")
+
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pids.append(conn.info.backend_pid)
+
+    with pool.NullConnectionPool(dsn, max_size=1, reset=reset) as p:
+        with p.connection() as conn:
+            t = spawn(worker)
+            ensure_waiting(p)
+
+            conn.execute("select 1")
+            pids.append(conn.info.backend_pid)
+
+        gather(t)
+
+    assert pids[0] != pids[1]
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
+@pytest.mark.slow
+@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
+def test_no_queue_timeout(deaf_port):
+    with pool.NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p:
+        with pytest.raises(pool.PoolTimeout):
+            with p.connection(timeout=1):
+                pass
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_intrans_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.info.backend_pid)
+            assert conn.info.transaction_status == TransactionStatus.IDLE
+            cur = conn.execute(
+                "select 1 from pg_class where relname = 'test_intrans_rollback'"
+            )
+            assert not cur.fetchone()
+
+    with pool.NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        # Queue the worker so it will take the connection a second time instead
+        # of making a new one.
+        t = spawn(worker)
+        ensure_waiting(p)
+
+        pids.append(conn.info.backend_pid)
+        conn.execute("create table test_intrans_rollback ()")
+        assert conn.info.transaction_status == TransactionStatus.INTRANS
+        p.putconn(conn)
+        gather(t)
+
+    assert pids[0] == pids[1]
+    assert len(caplog.records) == 1
+    assert "INTRANS" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_inerror_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.info.backend_pid)
+            assert conn.info.transaction_status == TransactionStatus.IDLE
+
+    with pool.NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        # Queue the worker so it will take the connection a second time instead
+        # of making a new one.
+        t = spawn(worker)
+        ensure_waiting(p)
+
+        pids.append(conn.info.backend_pid)
+        with pytest.raises(psycopg.ProgrammingError):
+            conn.execute("wat")
+        assert conn.info.transaction_status == TransactionStatus.INERROR
+        p.putconn(conn)
+        gather(t)
+
+    assert pids[0] == pids[1]
+    assert len(caplog.records) == 1
+    assert "INERROR" in caplog.records[0].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+@pytest.mark.crdb_skip("copy")
+def test_active_close(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.info.backend_pid)
+            assert conn.info.transaction_status == TransactionStatus.IDLE
+
+    with pool.NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+
+        t = spawn(worker)
+        ensure_waiting(p)
+
+        pids.append(conn.info.backend_pid)
+        conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
+        assert conn.info.transaction_status == TransactionStatus.ACTIVE
+        p.putconn(conn)
+        gather(t)
+
+    assert pids[0] != pids[1]
+    assert len(caplog.records) == 2
+    assert "ACTIVE" in caplog.records[0].message
+    assert "BAD" in caplog.records[1].message
+
+
+@pytest.mark.crdb_skip("backend pid")
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+    caplog.set_level(logging.WARNING, logger="psycopg.pool")
+    pids = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.info.backend_pid)
+            assert conn.info.transaction_status == TransactionStatus.IDLE
+
+    with pool.NullConnectionPool(dsn, max_size=1) as p:
+        conn = p.getconn()
+        t = spawn(worker)
+        ensure_waiting(p)
+
+        def bad_rollback():
+            conn.pgconn.finish()
+            orig_rollback()
+
+        # Make the rollback fail
+        orig_rollback = conn.rollback
+        monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+        pids.append(conn.info.backend_pid)
+        with pytest.raises(psycopg.ProgrammingError):
+            conn.execute("wat")
+        assert conn.info.transaction_status == TransactionStatus.INERROR
+        p.putconn(conn)
+        gather(t)
+
+    assert pids[0] != pids[1]
+    assert len(caplog.records) == 3
+    assert "INERROR" in caplog.records[0].message
+    assert "OperationalError" in caplog.records[1].message
+    assert "BAD" in caplog.records[2].message
+
+
+def test_closed_putconn(dsn):
+    with pool.NullConnectionPool(dsn) as p:
+        with p.connection() as conn:
+            pass
+        assert conn.closed
+
+
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
+def test_bad_resize(dsn, min_size, max_size):
+    with pool.NullConnectionPool() as p:
+        with pytest.raises(ValueError):
+            p.resize(min_size=min_size, max_size=max_size)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.crdb_skip("backend pid")
+def test_max_lifetime(dsn):
+    pids: List[int] = []
+
+    def worker():
+        with p.connection() as conn:
+            pids.append(conn.info.backend_pid)
+            sleep(0.1)
+
+    with pool.NullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
+        ts = [spawn(worker) for i in range(5)]
+        gather(*ts)
+
+    assert pids[0] == pids[1] != pids[4], pids
+
+
+def test_check(dsn):
+    # no.op
+    with pool.NullConnectionPool(dsn) as p:
+        p.check()
+
+
+@pytest.mark.slow
+def test_stats_connect(dsn, proxy, monkeypatch):
+    proxy.start()
+    delay_connection(monkeypatch, 0.2)
+    with pool.NullConnectionPool(proxy.client_dsn, max_size=3) as p:
+        p.wait()
+        stats = p.get_stats()
+        assert stats["connections_num"] == 1
+        assert stats.get("connections_errors", 0) == 0
+        assert stats.get("connections_lost", 0) == 0
+        assert 200 <= stats["connections_ms"] < 300
+
+
+@pytest.mark.skipif(not is_async(__name__), reason="async test only")
+def test_cancellation_in_queue(dsn):
+    # https://github.com/psycopg/psycopg/issues/509
+
+    nconns = 3
+
+    with pool.NullConnectionPool(dsn, min_size=0, max_size=nconns, timeout=1) as p:
+        p.wait()
+
+        got_conns = []
+        ev = Event()
+
+        def worker(i):
+            try:
+                logging.info("worker %s started", i)
+                nonlocal got_conns
+
+                with p.connection() as conn:
+                    logging.info("worker %s got conn", i)
+                    cur = conn.execute("select 1")
+                    assert cur.fetchone() == (1,)
+
+                    got_conns.append(conn)
+                    if len(got_conns) >= nconns:
+                        ev.set()
+
+                    sleep(5)
+            except BaseException as ex:
+                logging.info("worker %s stopped: %r", i, ex)
+                raise
+
+        # Start tasks taking up all the connections and getting in the queue
+        tasks = [spawn(worker, (i,)) for i in range(nconns * 3)]
+
+        # wait until the pool has served all the connections and clients are queued.
+        ev.wait(3.0)
+        for i in range(10):
+            if p.get_stats().get("requests_queued", 0):
+                break
+            else:
+                sleep(0.1)
+        else:
+            pytest.fail("no client got in the queue")
+
+        [task.cancel() for task in reversed(tasks)]
+        gather(*tasks, return_exceptions=True, timeout=1.0)
+
+        stats = p.get_stats()
+        assert stats.get("requests_waiting", 0) == 0
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
similarity index 91%
rename from tests/pool/test_null_pool_async.py
rename to tests/pool/test_pool_null_async.py
index 56a016397d4b456760418d1bbcc9f30f9450b538..8525c356bbd303514a61387f183a5613a119aaa9 100644 (file)
@@ -1,8 +1,5 @@
-import asyncio
 import logging
-from time import time
-from typing import Any, Dict, List, Tuple
-from asyncio import create_task
+from typing import Any, Dict, List
 
 import pytest
 from packaging.version import parse as ver  # noqa: F401  # used in skipif
@@ -11,15 +8,19 @@ import psycopg
 from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type
-from .test_pool_async import delay_connection, ensure_waiting
 
-pytestmark = [pytest.mark.anyio]
+from ..utils import AEvent, asleep, spawn, gather, is_async
+from .test_pool_common_async import delay_connection, ensure_waiting
 
 try:
     import psycopg_pool as pool
 except ImportError:
+    # Tests should have been skipped if the package is not available
     pass
 
+if True:  # ASYNC
+    pytestmark = [pytest.mark.anyio]
+
 
 async def test_default_sizes(dsn):
     async with pool.AsyncNullConnectionPool(dsn) as p:
@@ -135,18 +136,18 @@ async def test_configure(dsn):
         async with p.connection() as conn:
             assert inits == 1
             res = await conn.execute("show default_transaction_read_only")
-            assert (await res.fetchone())[0] == "on"  # type: ignore[index]
+            assert (await res.fetchone()) == ("on",)
 
         async with p.connection() as conn:
             assert inits == 2
             res = await conn.execute("show default_transaction_read_only")
-            assert (await res.fetchone())[0] == "on"  # type: ignore[index]
+            assert (await res.fetchone()) == ("on",)
             await conn.close()
 
         async with p.connection() as conn:
             assert inits == 3
             res = await conn.execute("show default_transaction_read_only")
-            assert (await res.fetchone())[0] == "on"  # type: ignore[index]
+            assert (await res.fetchone()) == ("on",)
 
 
 @pytest.mark.crdb_skip("backend pid")
@@ -176,14 +177,14 @@ async def test_reset(dsn):
         async with p.connection() as conn:
             # Queue the worker so it will take the same connection a second time
             # instead of making a new one.
-            t = create_task(worker())
+            t = spawn(worker)
             await ensure_waiting(p)
 
             assert resets == 0
             await conn.execute("set timezone to '+2:00'")
             pids.append(conn.info.backend_pid)
 
-        await asyncio.gather(t)
+        await gather(t)
         await p.wait()
 
     assert resets == 1
@@ -206,13 +207,13 @@ async def test_reset_badstate(dsn, caplog):
 
     async with pool.AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
         async with p.connection() as conn:
-            t = create_task(worker())
+            t = spawn(worker)
             await ensure_waiting(p)
 
             await conn.execute("select 1")
             pids.append(conn.info.backend_pid)
 
-        await asyncio.gather(t)
+        await gather(t)
 
     assert pids[0] != pids[1]
     assert caplog.records
@@ -236,13 +237,13 @@ async def test_reset_broken(dsn, caplog):
 
     async with pool.AsyncNullConnectionPool(dsn, max_size=1, reset=reset) as p:
         async with p.connection() as conn:
-            t = create_task(worker())
+            t = spawn(worker)
             await ensure_waiting(p)
 
             await conn.execute("select 1")
             pids.append(conn.info.backend_pid)
 
-        await asyncio.gather(t)
+        await gather(t)
 
     assert pids[0] != pids[1]
     assert caplog.records
@@ -279,14 +280,14 @@ async def test_intrans_rollback(dsn, caplog):
 
         # Queue the worker so it will take the connection a second time instead
         # of making a new one.
-        t = create_task(worker())
+        t = spawn(worker)
         await ensure_waiting(p)
 
         pids.append(conn.info.backend_pid)
         await conn.execute("create table test_intrans_rollback ()")
         assert conn.info.transaction_status == TransactionStatus.INTRANS
         await p.putconn(conn)
-        await asyncio.gather(t)
+        await gather(t)
 
     assert pids[0] == pids[1]
     assert len(caplog.records) == 1
@@ -306,7 +307,9 @@ async def test_inerror_rollback(dsn, caplog):
     async with pool.AsyncNullConnectionPool(dsn, max_size=1) as p:
         conn = await p.getconn()
 
-        t = create_task(worker())
+        # Queue the worker so it will take the connection a second time instead
+        # of making a new one.
+        t = spawn(worker)
         await ensure_waiting(p)
 
         pids.append(conn.info.backend_pid)
@@ -314,7 +317,7 @@ async def test_inerror_rollback(dsn, caplog):
             await conn.execute("wat")
         assert conn.info.transaction_status == TransactionStatus.INERROR
         await p.putconn(conn)
-        await asyncio.gather(t)
+        await gather(t)
 
     assert pids[0] == pids[1]
     assert len(caplog.records) == 1
@@ -335,14 +338,14 @@ async def test_active_close(dsn, caplog):
     async with pool.AsyncNullConnectionPool(dsn, max_size=1) as p:
         conn = await p.getconn()
 
-        t = create_task(worker())
+        t = spawn(worker)
         await ensure_waiting(p)
 
         pids.append(conn.info.backend_pid)
         conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
         assert conn.info.transaction_status == TransactionStatus.ACTIVE
         await p.putconn(conn)
-        await asyncio.gather(t)
+        await gather(t)
 
     assert pids[0] != pids[1]
     assert len(caplog.records) == 2
@@ -362,7 +365,7 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
 
     async with pool.AsyncNullConnectionPool(dsn, max_size=1) as p:
         conn = await p.getconn()
-        t = create_task(worker())
+        t = spawn(worker)
         await ensure_waiting(p)
 
         async def bad_rollback():
@@ -378,7 +381,7 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
             await conn.execute("wat")
         assert conn.info.transaction_status == TransactionStatus.INERROR
         await p.putconn(conn)
-        await asyncio.gather(t)
+        await gather(t)
 
     assert pids[0] != pids[1]
     assert len(caplog.records) == 3
@@ -410,11 +413,11 @@ async def test_max_lifetime(dsn):
     async def worker():
         async with p.connection() as conn:
             pids.append(conn.info.backend_pid)
-            await asyncio.sleep(0.1)
+            await asleep(0.1)
 
     async with pool.AsyncNullConnectionPool(dsn, max_size=1, max_lifetime=0.2) as p:
-        ts = [create_task(worker()) for i in range(5)]
-        await asyncio.gather(*ts)
+        ts = [spawn(worker) for i in range(5)]
+        await gather(*ts)
 
     assert pids[0] == pids[1] != pids[4], pids
 
@@ -438,6 +441,7 @@ async def test_stats_connect(dsn, proxy, monkeypatch):
         assert 200 <= stats["connections_ms"] < 300
 
 
+@pytest.mark.skipif(not is_async(__name__), reason="async test only")
 async def test_cancellation_in_queue(dsn):
     # https://github.com/psycopg/psycopg/issues/509
 
@@ -449,7 +453,7 @@ async def test_cancellation_in_queue(dsn):
         await p.wait()
 
         got_conns = []
-        ev = asyncio.Event()
+        ev = AEvent()
 
         async def worker(i):
             try:
@@ -465,27 +469,27 @@ async def test_cancellation_in_queue(dsn):
                     if len(got_conns) >= nconns:
                         ev.set()
 
-                    await asyncio.sleep(5)
+                    await asleep(5)
 
             except BaseException as ex:
                 logging.info("worker %s stopped: %r", i, ex)
                 raise
 
         # Start tasks taking up all the connections and getting in the queue
-        tasks = [asyncio.ensure_future(worker(i)) for i in range(nconns * 3)]
+        tasks = [spawn(worker, (i,)) for i in range(nconns * 3)]
 
         # wait until the pool has served all the connections and clients are queued.
-        await asyncio.wait_for(ev.wait(), 3.0)
+        await ev.wait_timeout(3.0)
         for i in range(10):
             if p.get_stats().get("requests_queued", 0):
                 break
             else:
-                await asyncio.sleep(0.1)
+                await asleep(0.1)
         else:
             pytest.fail("no client got in the queue")
 
         [task.cancel() for task in reversed(tasks)]
-        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 1.0)
+        await gather(*tasks, return_exceptions=True, timeout=1.0)
 
         stats = p.get_stats()
         assert stats.get("requests_waiting", 0) == 0
index 3ede5997faffde076bd825ab7044108aeb312f3d..e0ba2260765a122ba4225475076a4e7af040cd29 100755 (executable)
@@ -191,6 +191,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "find_insert_problem_async": "find_insert_problem",
         "psycopg_pool.pool_async": "psycopg_pool.pool",
         "psycopg_pool.sched_async": "psycopg_pool.sched",
+        "test_pool_common_async": "test_pool_common",
         "wait_async": "wait",
         "wait_conn_async": "wait_conn",
         "wait_timeout": "wait",
index 1b05fda1f65bd1f270240acdbd321ae14c233967..2017d21cafcf7b60fe1c4e450c555a72ac2963b3 100755 (executable)
@@ -23,6 +23,7 @@ for async in \
     psycopg_pool/psycopg_pool/sched_async.py \
     tests/pool/test_pool_async.py \
     tests/pool/test_pool_common_async.py \
+    tests/pool/test_pool_null_async.py \
     tests/pool/test_sched_async.py \
     tests/test_connection_async.py \
     tests/test_copy_async.py \