From: Daniele Varrazzo Date: Sat, 27 Feb 2021 22:47:26 +0000 (+0100) Subject: Use more of the pool context manager in the tests X-Git-Tag: 3.0.dev0~87^2~36 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff273b6f8f03280318a145431115d2156015fd73;p=thirdparty%2Fpsycopg.git Use more of the pool context manager in the tests --- diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 44f093ae2..1356f8e08 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -12,54 +12,43 @@ from psycopg3.pq import TransactionStatus def test_defaults(dsn): - p = pool.ConnectionPool(dsn) - assert p.minconn == p.maxconn == 4 - assert p.timeout == 30 - assert p.max_idle == 600 - assert p.num_workers == 3 + with pool.ConnectionPool(dsn) as p: + assert p.minconn == p.maxconn == 4 + assert p.timeout == 30 + assert p.max_idle == 600 + assert p.num_workers == 3 def test_minconn_maxconn(dsn): - p = pool.ConnectionPool(dsn, minconn=2) - assert p.minconn == p.maxconn == 2 + with pool.ConnectionPool(dsn, minconn=2) as p: + assert p.minconn == p.maxconn == 2 - p = pool.ConnectionPool(dsn, minconn=2, maxconn=4) - assert p.minconn == 2 - assert p.maxconn == 4 + with pool.ConnectionPool(dsn, minconn=2, maxconn=4) as p: + assert p.minconn == 2 + assert p.maxconn == 4 with pytest.raises(ValueError): pool.ConnectionPool(dsn, minconn=4, maxconn=2) def test_kwargs(dsn): - p = pool.ConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) - with p.connection() as conn: - assert conn.autocommit + with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) as p: + with p.connection() as conn: + assert conn.autocommit def test_its_really_a_pool(dsn): - p = pool.ConnectionPool(dsn, minconn=2) - with p.connection() as conn: - with conn.execute("select pg_backend_pid()") as cur: - (pid1,) = cur.fetchone() - - with p.connection() as conn2: - with conn2.execute("select pg_backend_pid()") as cur: - (pid2,) = cur.fetchone() - - with p.connection() as conn: - assert conn.pgconn.backend_pid in (pid1, pid2) + with pool.ConnectionPool(dsn, minconn=2) as p: + with p.connection() as conn: + with conn.execute("select pg_backend_pid()") as cur: + (pid1,) = cur.fetchone() + with p.connection() as conn2: + with conn2.execute("select pg_backend_pid()") as cur: + (pid2,) = cur.fetchone() -def test_connection_not_lost(dsn): - p = pool.ConnectionPool(dsn, minconn=1) - with pytest.raises(ZeroDivisionError): with p.connection() as conn: - pid = conn.pgconn.backend_pid - 1 / 0 - - with p.connection() as conn2: - assert conn2.pgconn.backend_pid == pid + assert conn.pgconn.backend_pid in (pid1, pid2) def test_context(dsn): @@ -68,6 +57,17 @@ def test_context(dsn): assert p.closed +def test_connection_not_lost(dsn): + with pool.ConnectionPool(dsn, minconn=1) as p: + with pytest.raises(ZeroDivisionError): + with p.connection() as conn: + pid = conn.pgconn.backend_pid + 1 / 0 + + with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + + @pytest.mark.slow def test_concurrent_filling(dsn, monkeypatch, retries): delay_connection(monkeypatch, 0.1) @@ -96,31 +96,32 @@ def test_concurrent_filling(dsn, monkeypatch, retries): def test_wait_ready(dsn, monkeypatch): delay_connection(monkeypatch, 0.1) with pytest.raises(pool.PoolTimeout): - p = pool.ConnectionPool(dsn, minconn=4, num_workers=1) - p.wait_ready(0.3) + with pool.ConnectionPool(dsn, minconn=4, num_workers=1) as p: + p.wait_ready(0.3) - p = pool.ConnectionPool(dsn, minconn=4, num_workers=1) - p.wait_ready(0.5) - p.close() - p = pool.ConnectionPool(dsn, minconn=4, num_workers=2) - p.wait_ready(0.3) - p.wait_ready(0.0001) # idempotent - p.close() + with pool.ConnectionPool(dsn, minconn=4, num_workers=1) as p: + p.wait_ready(0.5) + + with pool.ConnectionPool(dsn, minconn=4, num_workers=2) as p: + p.wait_ready(0.3) + p.wait_ready(0.0001) # idempotent @pytest.mark.slow def test_setup_no_timeout(dsn, proxy): with pytest.raises(pool.PoolTimeout): - p = pool.ConnectionPool(proxy.client_dsn, minconn=1, num_workers=1) - p.wait_ready(0.2) + with pool.ConnectionPool( + proxy.client_dsn, minconn=1, num_workers=1 + ) as p: + p.wait_ready(0.2) - p = pool.ConnectionPool(proxy.client_dsn, minconn=1, num_workers=1) - sleep(0.5) - assert not p._pool - proxy.start() + with pool.ConnectionPool(proxy.client_dsn, minconn=1, num_workers=1) as p: + sleep(0.5) + assert not p._pool + proxy.start() - with p.connection() as conn: - conn.execute("select 1") + with p.connection() as conn: + conn.execute("select 1") @pytest.mark.slow @@ -137,30 +138,21 @@ def test_queue(dsn, retries): for retry in retries: with retry: results = [] - ts = [] with pool.ConnectionPool(dsn, minconn=2) as p: - for i in range(6): - t = Thread(target=worker, args=(i,)) - t.start() - ts.append(t) - - for t in ts: - t.join() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + [t.start() for t in ts] + [t.join() for t in ts] times = [item[1] for item in results] want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6] for got, want in zip(times, want_times): - assert got == pytest.approx(want, 0.2), times + assert got == pytest.approx(want, 0.1), times - assert len(set(r[2] for r in results)) == 2 + assert len(set(r[2] for r in results)) == 2, results @pytest.mark.slow def test_queue_timeout(dsn): - p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1) - results = [] - errors = [] - def worker(n): t0 = time() try: @@ -175,14 +167,13 @@ def test_queue_timeout(dsn): t1 = time() results.append((n, t1 - t0, pid)) - ts = [] - for i in range(4): - t = Thread(target=worker, args=(i,)) - t.start() - ts.append(t) + results = [] + errors = [] - for t in ts: - t.join() + with pool.ConnectionPool(dsn, minconn=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + [t.start() for t in ts] + [t.join() for t in ts] assert len(results) == 2 assert len(errors) == 2 @@ -192,10 +183,6 @@ def test_queue_timeout(dsn): @pytest.mark.slow def test_dead_client(dsn): - p = pool.ConnectionPool(dsn, minconn=2) - - results = [] - def worker(i, timeout): try: with p.connection(timeout=timeout) as conn: @@ -205,26 +192,22 @@ def test_dead_client(dsn): if timeout > 0.2: raise - ts = [] - for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]): - t = Thread(target=worker, args=(i, timeout)) - t.start() - ts.append(t) - - for t in ts: - t.join() + results = [] - sleep(0.2) - assert set(results) == set([0, 1, 3, 4]) - assert len(p._pool) == 2 # no connection was lost + with pool.ConnectionPool(dsn, minconn=2) as p: + ts = [ + Thread(target=worker, args=(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + [t.start() for t in ts] + [t.join() for t in ts] + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost @pytest.mark.slow def test_queue_timeout_override(dsn): - p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1) - results = [] - errors = [] - def worker(n): t0 = time() timeout = 0.25 if n == 3 else None @@ -240,14 +223,13 @@ def test_queue_timeout_override(dsn): t1 = time() results.append((n, t1 - t0, pid)) - ts = [] - for i in range(4): - t = Thread(target=worker, args=(i,)) - t.start() - ts.append(t) + results = [] + errors = [] - for t in ts: - t.join() + with pool.ConnectionPool(dsn, minconn=2, timeout=0.1) as p: + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + [t.start() for t in ts] + [t.join() for t in ts] assert len(results) == 3 assert len(errors) == 1 @@ -255,37 +237,37 @@ def test_queue_timeout_override(dsn): assert 0.1 < e[1] < 0.15 -def test_broken_reconnect(dsn, caplog): - caplog.set_level(logging.WARNING, logger="psycopg3.pool") - p = pool.ConnectionPool(dsn, minconn=1) - with pytest.raises(psycopg3.OperationalError): - with p.connection() as conn: - with conn.execute("select pg_backend_pid()") as cur: - (pid1,) = cur.fetchone() - conn.close() +def test_broken_reconnect(dsn): + with pool.ConnectionPool(dsn, minconn=1) as p: + with pytest.raises(psycopg3.OperationalError): + with p.connection() as conn: + with conn.execute("select pg_backend_pid()") as cur: + (pid1,) = cur.fetchone() + conn.close() - with p.connection() as conn2: - with conn2.execute("select pg_backend_pid()") as cur: - (pid2,) = cur.fetchone() + with p.connection() as conn2: + with conn2.execute("select pg_backend_pid()") as cur: + (pid2,) = cur.fetchone() assert pid1 != pid2 def test_intrans_rollback(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg3.pool") - p = pool.ConnectionPool(dsn, minconn=1) - conn = p.getconn() - pid = conn.pgconn.backend_pid - conn.execute("create table test_intrans_rollback ()") - assert conn.pgconn.transaction_status == TransactionStatus.INTRANS - p.putconn(conn) - - with p.connection() as conn2: - assert conn2.pgconn.backend_pid == pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE - assert not conn.execute( - "select 1 from pg_class where relname = 'test_intrans_rollback'" - ).fetchone() + + with pool.ConnectionPool(dsn, minconn=1) as p: + conn = p.getconn() + pid = conn.pgconn.backend_pid + conn.execute("create table test_intrans_rollback ()") + assert conn.pgconn.transaction_status == TransactionStatus.INTRANS + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + assert not conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ).fetchone() assert len(caplog.records) == 1 assert "INTRANS" in caplog.records[0].message @@ -293,17 +275,18 @@ def test_intrans_rollback(dsn, caplog): def test_inerror_rollback(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg3.pool") - p = pool.ConnectionPool(dsn, minconn=1) - conn = p.getconn() - pid = conn.pgconn.backend_pid - with pytest.raises(psycopg3.ProgrammingError): - conn.execute("wat") - assert conn.pgconn.transaction_status == TransactionStatus.INERROR - p.putconn(conn) - with p.connection() as conn2: - assert conn2.pgconn.backend_pid == pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + with pool.ConnectionPool(dsn, minconn=1) as p: + conn = p.getconn() + pid = conn.pgconn.backend_pid + with pytest.raises(psycopg3.ProgrammingError): + conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE assert len(caplog.records) == 1 assert "INERROR" in caplog.records[0].message @@ -311,18 +294,19 @@ def test_inerror_rollback(dsn, caplog): def test_active_close(dsn, caplog): caplog.set_level(logging.WARNING, logger="psycopg3.pool") - p = pool.ConnectionPool(dsn, minconn=1) - conn = p.getconn() - pid = conn.pgconn.backend_pid - cur = conn.cursor() - with cur.copy("copy (select * from generate_series(1, 10)) to stdout"): - pass - assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE - p.putconn(conn) - with p.connection() as conn2: - assert conn2.pgconn.backend_pid != pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + with pool.ConnectionPool(dsn, minconn=1) as p: + conn = p.getconn() + pid = conn.pgconn.backend_pid + cur = conn.cursor() + with cur.copy("copy (select * from generate_series(1, 10)) to stdout"): + pass + assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE + p.putconn(conn) + + with p.connection() as conn2: + assert conn2.pgconn.backend_pid != pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE assert len(caplog.records) == 2 assert "ACTIVE" in caplog.records[0].message @@ -331,26 +315,27 @@ def test_active_close(dsn, caplog): def test_fail_rollback_close(dsn, caplog, monkeypatch): caplog.set_level(logging.WARNING, logger="psycopg3.pool") - p = pool.ConnectionPool(dsn, minconn=1) - conn = p.getconn() - def bad_rollback(): - conn.pgconn.finish() - orig_rollback() + with pool.ConnectionPool(dsn, minconn=1) as p: + conn = p.getconn() + + def bad_rollback(): + conn.pgconn.finish() + orig_rollback() - # Make the rollback fail - orig_rollback = conn.rollback - monkeypatch.setattr(conn, "rollback", bad_rollback) + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) - pid = conn.pgconn.backend_pid - with pytest.raises(psycopg3.ProgrammingError): - conn.execute("wat") - assert conn.pgconn.transaction_status == TransactionStatus.INERROR - p.putconn(conn) + pid = conn.pgconn.backend_pid + with pytest.raises(psycopg3.ProgrammingError): + conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + p.putconn(conn) - with p.connection() as conn2: - assert conn2.pgconn.backend_pid != pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + with p.connection() as conn2: + assert conn2.pgconn.backend_pid != pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE assert len(caplog.records) == 3 assert "INERROR" in caplog.records[0].message @@ -371,18 +356,18 @@ def test_close_no_threads(dsn): def test_putconn_no_pool(dsn): - p = pool.ConnectionPool(dsn, minconn=1) - conn = psycopg3.connect(dsn) - with pytest.raises(ValueError): - p.putconn(conn) + with pool.ConnectionPool(dsn, minconn=1) as p: + conn = psycopg3.connect(dsn) + with pytest.raises(ValueError): + p.putconn(conn) def test_putconn_wrong_pool(dsn): - p1 = pool.ConnectionPool(dsn, minconn=1) - p2 = pool.ConnectionPool(dsn, minconn=1) - conn = p1.getconn() - with pytest.raises(ValueError): - p2.putconn(conn) + with pool.ConnectionPool(dsn, minconn=1) as p1: + with pool.ConnectionPool(dsn, minconn=1) as p2: + conn = p1.getconn() + with pytest.raises(ValueError): + p2.putconn(conn) def test_del_no_warning(dsn, recwarn): @@ -402,7 +387,7 @@ def test_del_stop_threads(dsn): p = pool.ConnectionPool(dsn) ts = [p._sched_runner] + p._workers del p - sleep(0.2) + sleep(0.1) for t in ts: assert not t.is_alive() @@ -479,21 +464,16 @@ def test_grow(dsn, monkeypatch, retries): dsn, minconn=2, maxconn=4, num_workers=3 ) as p: p.wait_ready(1.0) - ts = [] results = [] - for i in range(6): - t = Thread(target=worker, args=(i,)) - t.start() - ts.append(t) - - for t in ts: - t.join() + ts = [Thread(target=worker, args=(i,)) for i in range(6)] + [t.start() for t in ts] + [t.join() for t in ts] - want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4] - times = [item[1] for item in results] - for got, want in zip(times, want_times): - assert got == pytest.approx(want, 0.1), times + want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4] + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times @pytest.mark.slow @@ -512,24 +492,19 @@ def test_shrink(dsn, monkeypatch): orig_run = ShrinkPool._run monkeypatch.setattr(ShrinkPool, "_run", run_hacked) - p = pool.ConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2) - p.wait_ready(5.0) - assert p.max_idle == 0.2 - def worker(n): with p.connection() as conn: conn.execute("select pg_sleep(0.1)") - ts = [] - for i in range(4): - t = Thread(target=worker, args=(i,)) - t.start() - ts.append(t) + with pool.ConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2) as p: + p.wait_ready(5.0) + assert p.max_idle == 0.2 - for t in ts: - t.join() + ts = [Thread(target=worker, args=(i,)) for i in range(4)] + [t.start() for t in ts] + [t.join() for t in ts] + sleep(1) - sleep(1) assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] @@ -543,20 +518,20 @@ def test_reconnect(proxy, caplog, monkeypatch): monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) proxy.start() - p = pool.ConnectionPool(proxy.client_dsn, minconn=1) - p.wait_ready(2.0) - proxy.stop() + with pool.ConnectionPool(proxy.client_dsn, minconn=1) as p: + p.wait_ready(2.0) + proxy.stop() - with pytest.raises(psycopg3.OperationalError): - with p.connection() as conn: - conn.execute("select 1") + with pytest.raises(psycopg3.OperationalError): + with p.connection() as conn: + conn.execute("select 1") - sleep(1.0) - proxy.start() - p.wait_ready() + sleep(1.0) + proxy.start() + p.wait_ready() - with p.connection() as conn: - conn.execute("select 1") + with p.connection() as conn: + conn.execute("select 1") assert "BAD" in caplog.messages[0] times = [rec.created for rec in caplog.records] @@ -580,42 +555,42 @@ def test_reconnect_failure(proxy): nonlocal t1 t1 = time() - p = pool.ConnectionPool( + with pool.ConnectionPool( proxy.client_dsn, name="this-one", minconn=1, reconnect_timeout=1.0, reconnect_failed=failed, - ) - p.wait_ready(2.0) - proxy.stop() + ) as p: + p.wait_ready(2.0) + proxy.stop() - with pytest.raises(psycopg3.OperationalError): - with p.connection() as conn: - conn.execute("select 1") + with pytest.raises(psycopg3.OperationalError): + with p.connection() as conn: + conn.execute("select 1") - t0 = time() - sleep(1.5) - assert t1 - assert t1 - t0 == pytest.approx(1.0, 0.1) - assert p._nconns == 0 + t0 = time() + sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 - proxy.start() - t0 = time() - with p.connection() as conn: - conn.execute("select 1") - t1 = time() - assert t1 - t0 < 0.2 + proxy.start() + t0 = time() + with p.connection() as conn: + conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 @pytest.mark.slow def test_uniform_use(dsn): - p = pool.ConnectionPool(dsn, minconn=4) - counts = Counter() - for i in range(8): - with p.connection() as conn: - sleep(0.1) - counts[id(conn)] += 1 + with pool.ConnectionPool(dsn, minconn=4) as p: + counts = Counter() + for i in range(8): + with p.connection() as conn: + sleep(0.1) + counts[id(conn)] += 1 assert len(counts) == 4 assert set(counts.values()) == set([2]) @@ -623,9 +598,6 @@ def test_uniform_use(dsn): @pytest.mark.slow def test_resize(dsn): - p = pool.ConnectionPool(dsn, minconn=2, max_idle=0.2) - size = [] - def sampler(): sleep(0.05) # ensure sampling happens after shrink check while True: @@ -638,28 +610,29 @@ def test_resize(dsn): with p.connection() as conn: conn.execute("select pg_sleep(%s)", [t]) - s = Thread(target=sampler) - s.start() + size = [] - sleep(0.3) + with pool.ConnectionPool(dsn, minconn=2, max_idle=0.2) as p: + s = Thread(target=sampler) + s.start() - c = Thread(target=client, args=(0.4,)) - c.start() + sleep(0.3) + c = Thread(target=client, args=(0.4,)) + c.start() - sleep(0.2) - p.resize(4) - assert p.minconn == 4 - assert p.maxconn == 4 + sleep(0.2) + p.resize(4) + assert p.minconn == 4 + assert p.maxconn == 4 - sleep(0.4) - p.resize(2) - assert p.minconn == 2 - assert p.maxconn == 2 + sleep(0.4) + p.resize(2) + assert p.minconn == 2 + assert p.maxconn == 2 - sleep(0.6) - p.close() - s.join() + sleep(0.6) + s.join() assert size == [2, 1, 3, 4, 3, 2, 2] diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index f4c282ff5..0425947d3 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -21,50 +21,45 @@ pytestmark = pytest.mark.asyncio async def test_defaults(dsn): - p = pool.AsyncConnectionPool(dsn) - assert p.minconn == p.maxconn == 4 - assert p.timeout == 30 - assert p.max_idle == 600 - assert p.num_workers == 3 - await p.close() + async with pool.AsyncConnectionPool(dsn) as p: + assert p.minconn == p.maxconn == 4 + assert p.timeout == 30 + assert p.max_idle == 600 + assert p.num_workers == 3 async def test_minconn_maxconn(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=2) - assert p.minconn == p.maxconn == 2 - await p.close() + async with pool.AsyncConnectionPool(dsn, minconn=2) as p: + assert p.minconn == p.maxconn == 2 - p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4) - assert p.minconn == 2 - assert p.maxconn == 4 - await p.close() + async with pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4) as p: + assert p.minconn == 2 + assert p.maxconn == 4 with pytest.raises(ValueError): pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2) async def test_kwargs(dsn): - p = pool.AsyncConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) - async with p.connection() as conn: - assert conn.autocommit - - await p.close() + async with pool.AsyncConnectionPool( + dsn, kwargs={"autocommit": True}, minconn=1 + ) as p: + async with p.connection() as conn: + assert conn.autocommit async def test_its_really_a_pool(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=2) - async with p.connection() as conn: - cur = await conn.execute("select pg_backend_pid()") - (pid1,) = await cur.fetchone() - - async with p.connection() as conn2: - cur = await conn2.execute("select pg_backend_pid()") - (pid2,) = await cur.fetchone() + async with pool.AsyncConnectionPool(dsn, minconn=2) as p: + async with p.connection() as conn: + cur = await conn.execute("select pg_backend_pid()") + (pid1,) = await cur.fetchone() - async with p.connection() as conn: - assert conn.pgconn.backend_pid in (pid1, pid2) + async with p.connection() as conn2: + cur = await conn2.execute("select pg_backend_pid()") + (pid2,) = await cur.fetchone() - await p.close() + async with p.connection() as conn: + assert conn.pgconn.backend_pid in (pid1, pid2) async def test_context(dsn): @@ -74,16 +69,14 @@ async def test_context(dsn): async def test_connection_not_lost(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=1) - with pytest.raises(ZeroDivisionError): - async with p.connection() as conn: - pid = conn.pgconn.backend_pid - 1 / 0 - - async with p.connection() as conn2: - assert conn2.pgconn.backend_pid == pid + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + with pytest.raises(ZeroDivisionError): + async with p.connection() as conn: + pid = conn.pgconn.backend_pid + 1 / 0 - await p.close() + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid @pytest.mark.slow @@ -116,35 +109,36 @@ async def test_concurrent_filling(dsn, monkeypatch, retries): async def test_wait_ready(dsn, monkeypatch): delay_connection(monkeypatch, 0.1) with pytest.raises(pool.PoolTimeout): - p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1) - await p.wait_ready(0.3) + async with pool.AsyncConnectionPool( + dsn, minconn=4, num_workers=1 + ) as p: + await p.wait_ready(0.3) - p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1) - await p.wait_ready(0.5) - await p.close() - p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2) - await p.wait_ready(0.3) - await p.wait_ready(0.0001) # idempotent - await p.close() + async with pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1) as p: + await p.wait_ready(0.5) + + async with pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2) as p: + await p.wait_ready(0.3) + await p.wait_ready(0.0001) # idempotent @pytest.mark.slow async def test_setup_no_timeout(dsn, proxy): with pytest.raises(pool.PoolTimeout): - p = pool.AsyncConnectionPool( + async with pool.AsyncConnectionPool( proxy.client_dsn, minconn=1, num_workers=1 - ) - await p.wait_ready(0.2) + ) as p: + await p.wait_ready(0.2) - p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1, num_workers=1) - await asyncio.sleep(0.5) - assert not p._pool - proxy.start() - - async with p.connection() as conn: - await conn.execute("select 1") + async with pool.AsyncConnectionPool( + proxy.client_dsn, minconn=1, num_workers=1 + ) as p: + await asyncio.sleep(0.5) + assert not p._pool + proxy.start() - await p.close() + async with p.connection() as conn: + await conn.execute("select 1") @pytest.mark.slow @@ -176,10 +170,6 @@ async def test_queue(dsn, retries): @pytest.mark.slow async def test_queue_timeout(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) - results = [] - errors = [] - async def worker(n): t0 = time() try: @@ -195,23 +185,21 @@ async def test_queue_timeout(dsn): t1 = time() results.append((n, t1 - t0, pid)) - ts = [create_task(worker(i)) for i in range(4)] - await asyncio.gather(*ts) + results = [] + errors = [] + + async with pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) assert len(results) == 2 assert len(errors) == 2 for e in errors: assert 0.1 < e[1] < 0.15 - await p.close() - @pytest.mark.slow async def test_dead_client(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=2) - - results = [] - async def worker(i, timeout): try: async with p.connection(timeout=timeout) as conn: @@ -221,24 +209,21 @@ async def test_dead_client(dsn): if timeout > 0.2: raise - ts = [ - create_task(worker(i, timeout)) - for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) - ] - await asyncio.gather(*ts) + async with pool.AsyncConnectionPool(dsn, minconn=2) as p: + results = [] + ts = [ + create_task(worker(i, timeout)) + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]) + ] + await asyncio.gather(*ts) - await asyncio.sleep(0.2) - assert set(results) == set([0, 1, 3, 4]) - assert len(p._pool) == 2 # no connection was lost - await p.close() + await asyncio.sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost @pytest.mark.slow async def test_queue_timeout_override(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) - results = [] - errors = [] - async def worker(n): t0 = time() timeout = 0.25 if n == 3 else None @@ -255,9 +240,12 @@ async def test_queue_timeout_override(dsn): t1 = time() results.append((n, t1 - t0, pid)) - ts = [create_task(worker(i)) for i in range(4)] - await asyncio.gather(*ts) - await p.close() + results = [] + errors = [] + + async with pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) as p: + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) assert len(results) == 3 assert len(errors) == 1 @@ -266,38 +254,36 @@ async def test_queue_timeout_override(dsn): async def test_broken_reconnect(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=1) - with pytest.raises(psycopg3.OperationalError): - async with p.connection() as conn: - cur = await conn.execute("select pg_backend_pid()") - (pid1,) = await cur.fetchone() - await conn.close() + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + with pytest.raises(psycopg3.OperationalError): + async with p.connection() as conn: + cur = await conn.execute("select pg_backend_pid()") + (pid1,) = await cur.fetchone() + await conn.close() - async with p.connection() as conn2: - cur = await conn2.execute("select pg_backend_pid()") - (pid2,) = await cur.fetchone() + async with p.connection() as conn2: + cur = await conn2.execute("select pg_backend_pid()") + (pid2,) = await cur.fetchone() - await p.close() assert pid1 != pid2 async def test_intrans_rollback(dsn, caplog): - p = pool.AsyncConnectionPool(dsn, minconn=1) - conn = await p.getconn() - pid = conn.pgconn.backend_pid - await conn.execute("create table test_intrans_rollback ()") - assert conn.pgconn.transaction_status == TransactionStatus.INTRANS - await p.putconn(conn) - - async with p.connection() as conn2: - assert conn2.pgconn.backend_pid == pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE - cur = await conn.execute( - "select 1 from pg_class where relname = 'test_intrans_rollback'" - ) - assert not await cur.fetchone() + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + conn = await p.getconn() + pid = conn.pgconn.backend_pid + await conn.execute("create table test_intrans_rollback ()") + assert conn.pgconn.transaction_status == TransactionStatus.INTRANS + await p.putconn(conn) + + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + cur = await conn.execute( + "select 1 from pg_class where relname = 'test_intrans_rollback'" + ) + assert not await cur.fetchone() - await p.close() recs = [ r for r in caplog.records @@ -308,17 +294,17 @@ async def test_intrans_rollback(dsn, caplog): async def test_inerror_rollback(dsn, caplog): - p = pool.AsyncConnectionPool(dsn, minconn=1) - conn = await p.getconn() - pid = conn.pgconn.backend_pid - with pytest.raises(psycopg3.ProgrammingError): - await conn.execute("wat") - assert conn.pgconn.transaction_status == TransactionStatus.INERROR - await p.putconn(conn) + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + conn = await p.getconn() + pid = conn.pgconn.backend_pid + with pytest.raises(psycopg3.ProgrammingError): + await conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) - async with p.connection() as conn2: - assert conn2.pgconn.backend_pid == pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid == pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE recs = [ r @@ -328,26 +314,23 @@ async def test_inerror_rollback(dsn, caplog): assert len(recs) == 1 assert "INERROR" in recs[0].message - await p.close() - async def test_active_close(dsn, caplog): - p = pool.AsyncConnectionPool(dsn, minconn=1) - conn = await p.getconn() - pid = conn.pgconn.backend_pid - cur = conn.cursor() - async with cur.copy( - "copy (select * from generate_series(1, 10)) to stdout" - ): - pass - assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE - await p.putconn(conn) + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + conn = await p.getconn() + pid = conn.pgconn.backend_pid + cur = conn.cursor() + async with cur.copy( + "copy (select * from generate_series(1, 10)) to stdout" + ): + pass + assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE + await p.putconn(conn) - async with p.connection() as conn2: - assert conn2.pgconn.backend_pid != pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid != pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE - await p.close() recs = [ r for r in caplog.records @@ -359,28 +342,26 @@ async def test_active_close(dsn, caplog): async def test_fail_rollback_close(dsn, caplog, monkeypatch): - p = pool.AsyncConnectionPool(dsn, minconn=1) - conn = await p.getconn() - - async def bad_rollback(): - conn.pgconn.finish() - await orig_rollback() + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + conn = await p.getconn() - # Make the rollback fail - orig_rollback = conn.rollback - monkeypatch.setattr(conn, "rollback", bad_rollback) + async def bad_rollback(): + conn.pgconn.finish() + await orig_rollback() - pid = conn.pgconn.backend_pid - with pytest.raises(psycopg3.ProgrammingError): - await conn.execute("wat") - assert conn.pgconn.transaction_status == TransactionStatus.INERROR - await p.putconn(conn) + # Make the rollback fail + orig_rollback = conn.rollback + monkeypatch.setattr(conn, "rollback", bad_rollback) - async with p.connection() as conn2: - assert conn2.pgconn.backend_pid != pid - assert conn2.pgconn.transaction_status == TransactionStatus.IDLE + pid = conn.pgconn.backend_pid + with pytest.raises(psycopg3.ProgrammingError): + await conn.execute("wat") + assert conn.pgconn.transaction_status == TransactionStatus.INERROR + await p.putconn(conn) - await p.close() + async with p.connection() as conn2: + assert conn2.pgconn.backend_pid != pid + assert conn2.pgconn.transaction_status == TransactionStatus.IDLE recs = [ r @@ -406,21 +387,18 @@ async def test_close_no_threads(dsn): async def test_putconn_no_pool(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=1) - conn = psycopg3.connect(dsn) - with pytest.raises(ValueError): - await p.putconn(conn) - await p.close() + async with pool.AsyncConnectionPool(dsn, minconn=1) as p: + conn = psycopg3.connect(dsn) + with pytest.raises(ValueError): + await p.putconn(conn) async def test_putconn_wrong_pool(dsn): - p1 = pool.AsyncConnectionPool(dsn, minconn=1) - p2 = pool.AsyncConnectionPool(dsn, minconn=1) - conn = await p1.getconn() - with pytest.raises(ValueError): - await p2.putconn(conn) - await p1.close() - await p2.close() + async with pool.AsyncConnectionPool(dsn, minconn=1) as p1: + async with pool.AsyncConnectionPool(dsn, minconn=1) as p2: + conn = await p1.getconn() + with pytest.raises(ValueError): + await p2.putconn(conn) async def test_del_no_warning(dsn, recwarn): @@ -519,12 +497,11 @@ async def test_grow(dsn, monkeypatch, retries): ts = [create_task(worker(i)) for i in range(6)] await asyncio.gather(*ts) - await p.close() - want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4] - times = [item[1] for item in results] - for got, want in zip(times, want_times): - assert got == pytest.approx(want, 0.1), times + want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4] + times = [item[1] for item in results] + for got, want in zip(times, want_times): + assert got == pytest.approx(want, 0.1), times @pytest.mark.slow @@ -543,19 +520,21 @@ async def test_shrink(dsn, monkeypatch): orig_run = ShrinkPool._run_async monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked) - p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2) - await p.wait_ready(5.0) - assert p.max_idle == 0.2 - async def worker(n): async with p.connection() as conn: await conn.execute("select pg_sleep(0.1)") - ts = [create_task(worker(i)) for i in range(4)] - await asyncio.gather(*ts) + async with pool.AsyncConnectionPool( + dsn, minconn=2, maxconn=4, max_idle=0.2 + ) as p: + await p.wait_ready(5.0) + assert p.max_idle == 0.2 + + ts = [create_task(worker(i)) for i in range(4)] + await asyncio.gather(*ts) + + await asyncio.sleep(1) - await asyncio.sleep(1) - await p.close() assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)] @@ -567,22 +546,20 @@ async def test_reconnect(proxy, caplog, monkeypatch): monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0) proxy.start() - p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1) - await p.wait_ready(2.0) - proxy.stop() + async with pool.AsyncConnectionPool(proxy.client_dsn, minconn=1) as p: + await p.wait_ready(2.0) + proxy.stop() - with pytest.raises(psycopg3.OperationalError): - async with p.connection() as conn: - await conn.execute("select 1") - - await asyncio.sleep(1.0) - proxy.start() - await p.wait_ready() + with pytest.raises(psycopg3.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") - async with p.connection() as conn: - await conn.execute("select 1") + await asyncio.sleep(1.0) + proxy.start() + await p.wait_ready() - await p.close() + async with p.connection() as conn: + await conn.execute("select 1") recs = [ r @@ -611,54 +588,49 @@ async def test_reconnect_failure(proxy): nonlocal t1 t1 = time() - p = pool.AsyncConnectionPool( + async with pool.AsyncConnectionPool( proxy.client_dsn, name="this-one", minconn=1, reconnect_timeout=1.0, reconnect_failed=failed, - ) - await p.wait_ready(2.0) - proxy.stop() + ) as p: + await p.wait_ready(2.0) + proxy.stop() - with pytest.raises(psycopg3.OperationalError): - async with p.connection() as conn: - await conn.execute("select 1") + with pytest.raises(psycopg3.OperationalError): + async with p.connection() as conn: + await conn.execute("select 1") - t0 = time() - await asyncio.sleep(1.5) - assert t1 - assert t1 - t0 == pytest.approx(1.0, 0.1) - assert p._nconns == 0 + t0 = time() + await asyncio.sleep(1.5) + assert t1 + assert t1 - t0 == pytest.approx(1.0, 0.1) + assert p._nconns == 0 - proxy.start() - t0 = time() - async with p.connection() as conn: - await conn.execute("select 1") - t1 = time() - assert t1 - t0 < 0.2 - await p.close() + proxy.start() + t0 = time() + async with p.connection() as conn: + await conn.execute("select 1") + t1 = time() + assert t1 - t0 < 0.2 @pytest.mark.slow async def test_uniform_use(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=4) - counts = Counter() - for i in range(8): - async with p.connection() as conn: - await asyncio.sleep(0.1) - counts[id(conn)] += 1 + async with pool.AsyncConnectionPool(dsn, minconn=4) as p: + counts = Counter() + for i in range(8): + async with p.connection() as conn: + await asyncio.sleep(0.1) + counts[id(conn)] += 1 - await p.close() assert len(counts) == 4 assert set(counts.values()) == set([2]) @pytest.mark.slow async def test_resize(dsn): - p = pool.AsyncConnectionPool(dsn, minconn=2, max_idle=0.2) - size = [] - async def sampler(): await asyncio.sleep(0.05) # ensure sampling happens after shrink check while True: @@ -671,26 +643,28 @@ async def test_resize(dsn): async with p.connection() as conn: await conn.execute("select pg_sleep(%s)", [t]) - s = create_task(sampler()) + size = [] - await asyncio.sleep(0.3) + async with pool.AsyncConnectionPool(dsn, minconn=2, max_idle=0.2) as p: + s = create_task(sampler()) - c = create_task(client(0.4)) + await asyncio.sleep(0.3) - await asyncio.sleep(0.2) - await p.resize(4) - assert p.minconn == 4 - assert p.maxconn == 4 + c = create_task(client(0.4)) - await asyncio.sleep(0.4) - await p.resize(2) - assert p.minconn == 2 - assert p.maxconn == 2 + await asyncio.sleep(0.2) + await p.resize(4) + assert p.minconn == 4 + assert p.maxconn == 4 - await asyncio.sleep(0.6) - await p.close() - await asyncio.gather(s, c) + await asyncio.sleep(0.4) + await p.resize(2) + assert p.minconn == 2 + assert p.maxconn == 2 + await asyncio.sleep(0.6) + + await asyncio.gather(s, c) assert size == [2, 1, 3, 4, 3, 2, 2]