]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use more of the pool context manager in the tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 27 Feb 2021 22:47:26 +0000 (23:47 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 44f093ae2017a4dff175627126c3db1612a0abf9..1356f8e08f2c80e43f26b820970d9b752f5b5c6b 100644 (file)
@@ -12,54 +12,43 @@ from psycopg3.pq import TransactionStatus
 
 
 def test_defaults(dsn):
-    p = pool.ConnectionPool(dsn)
-    assert p.minconn == p.maxconn == 4
-    assert p.timeout == 30
-    assert p.max_idle == 600
-    assert p.num_workers == 3
+    with pool.ConnectionPool(dsn) as p:
+        assert p.minconn == p.maxconn == 4
+        assert p.timeout == 30
+        assert p.max_idle == 600
+        assert p.num_workers == 3
 
 
 def test_minconn_maxconn(dsn):
-    p = pool.ConnectionPool(dsn, minconn=2)
-    assert p.minconn == p.maxconn == 2
+    with pool.ConnectionPool(dsn, minconn=2) as p:
+        assert p.minconn == p.maxconn == 2
 
-    p = pool.ConnectionPool(dsn, minconn=2, maxconn=4)
-    assert p.minconn == 2
-    assert p.maxconn == 4
+    with pool.ConnectionPool(dsn, minconn=2, maxconn=4) as p:
+        assert p.minconn == 2
+        assert p.maxconn == 4
 
     with pytest.raises(ValueError):
         pool.ConnectionPool(dsn, minconn=4, maxconn=2)
 
 
 def test_kwargs(dsn):
-    p = pool.ConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1)
-    with p.connection() as conn:
-        assert conn.autocommit
+    with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1) as p:
+        with p.connection() as conn:
+            assert conn.autocommit
 
 
 def test_its_really_a_pool(dsn):
-    p = pool.ConnectionPool(dsn, minconn=2)
-    with p.connection() as conn:
-        with conn.execute("select pg_backend_pid()") as cur:
-            (pid1,) = cur.fetchone()
-
-        with p.connection() as conn2:
-            with conn2.execute("select pg_backend_pid()") as cur:
-                (pid2,) = cur.fetchone()
-
-    with p.connection() as conn:
-        assert conn.pgconn.backend_pid in (pid1, pid2)
+    with pool.ConnectionPool(dsn, minconn=2) as p:
+        with p.connection() as conn:
+            with conn.execute("select pg_backend_pid()") as cur:
+                (pid1,) = cur.fetchone()
 
+            with p.connection() as conn2:
+                with conn2.execute("select pg_backend_pid()") as cur:
+                    (pid2,) = cur.fetchone()
 
-def test_connection_not_lost(dsn):
-    p = pool.ConnectionPool(dsn, minconn=1)
-    with pytest.raises(ZeroDivisionError):
         with p.connection() as conn:
-            pid = conn.pgconn.backend_pid
-            1 / 0
-
-    with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid == pid
+            assert conn.pgconn.backend_pid in (pid1, pid2)
 
 
 def test_context(dsn):
@@ -68,6 +57,17 @@ def test_context(dsn):
     assert p.closed
 
 
+def test_connection_not_lost(dsn):
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        with pytest.raises(ZeroDivisionError):
+            with p.connection() as conn:
+                pid = conn.pgconn.backend_pid
+                1 / 0
+
+        with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid == pid
+
+
 @pytest.mark.slow
 def test_concurrent_filling(dsn, monkeypatch, retries):
     delay_connection(monkeypatch, 0.1)
@@ -96,31 +96,32 @@ def test_concurrent_filling(dsn, monkeypatch, retries):
 def test_wait_ready(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
     with pytest.raises(pool.PoolTimeout):
-        p = pool.ConnectionPool(dsn, minconn=4, num_workers=1)
-        p.wait_ready(0.3)
+        with pool.ConnectionPool(dsn, minconn=4, num_workers=1) as p:
+            p.wait_ready(0.3)
 
-    p = pool.ConnectionPool(dsn, minconn=4, num_workers=1)
-    p.wait_ready(0.5)
-    p.close()
-    p = pool.ConnectionPool(dsn, minconn=4, num_workers=2)
-    p.wait_ready(0.3)
-    p.wait_ready(0.0001)  # idempotent
-    p.close()
+    with pool.ConnectionPool(dsn, minconn=4, num_workers=1) as p:
+        p.wait_ready(0.5)
+
+    with pool.ConnectionPool(dsn, minconn=4, num_workers=2) as p:
+        p.wait_ready(0.3)
+        p.wait_ready(0.0001)  # idempotent
 
 
 @pytest.mark.slow
 def test_setup_no_timeout(dsn, proxy):
     with pytest.raises(pool.PoolTimeout):
-        p = pool.ConnectionPool(proxy.client_dsn, minconn=1, num_workers=1)
-        p.wait_ready(0.2)
+        with pool.ConnectionPool(
+            proxy.client_dsn, minconn=1, num_workers=1
+        ) as p:
+            p.wait_ready(0.2)
 
-    p = pool.ConnectionPool(proxy.client_dsn, minconn=1, num_workers=1)
-    sleep(0.5)
-    assert not p._pool
-    proxy.start()
+    with pool.ConnectionPool(proxy.client_dsn, minconn=1, num_workers=1) as p:
+        sleep(0.5)
+        assert not p._pool
+        proxy.start()
 
-    with p.connection() as conn:
-        conn.execute("select 1")
+        with p.connection() as conn:
+            conn.execute("select 1")
 
 
 @pytest.mark.slow
@@ -137,30 +138,21 @@ def test_queue(dsn, retries):
     for retry in retries:
         with retry:
             results = []
-            ts = []
             with pool.ConnectionPool(dsn, minconn=2) as p:
-                for i in range(6):
-                    t = Thread(target=worker, args=(i,))
-                    t.start()
-                    ts.append(t)
-
-                for t in ts:
-                    t.join()
+                ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+                [t.start() for t in ts]
+                [t.join() for t in ts]
 
             times = [item[1] for item in results]
             want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
             for got, want in zip(times, want_times):
-                assert got == pytest.approx(want, 0.2), times
+                assert got == pytest.approx(want, 0.1), times
 
-            assert len(set(r[2] for r in results)) == 2
+            assert len(set(r[2] for r in results)) == 2, results
 
 
 @pytest.mark.slow
 def test_queue_timeout(dsn):
-    p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1)
-    results = []
-    errors = []
-
     def worker(n):
         t0 = time()
         try:
@@ -175,14 +167,13 @@ def test_queue_timeout(dsn):
             t1 = time()
             results.append((n, t1 - t0, pid))
 
-    ts = []
-    for i in range(4):
-        t = Thread(target=worker, args=(i,))
-        t.start()
-        ts.append(t)
+    results = []
+    errors = []
 
-    for t in ts:
-        t.join()
+    with pool.ConnectionPool(dsn, minconn=2, timeout=0.1) as p:
+        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+        [t.start() for t in ts]
+        [t.join() for t in ts]
 
     assert len(results) == 2
     assert len(errors) == 2
@@ -192,10 +183,6 @@ def test_queue_timeout(dsn):
 
 @pytest.mark.slow
 def test_dead_client(dsn):
-    p = pool.ConnectionPool(dsn, minconn=2)
-
-    results = []
-
     def worker(i, timeout):
         try:
             with p.connection(timeout=timeout) as conn:
@@ -205,26 +192,22 @@ def test_dead_client(dsn):
             if timeout > 0.2:
                 raise
 
-    ts = []
-    for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]):
-        t = Thread(target=worker, args=(i, timeout))
-        t.start()
-        ts.append(t)
-
-    for t in ts:
-        t.join()
+    results = []
 
-    sleep(0.2)
-    assert set(results) == set([0, 1, 3, 4])
-    assert len(p._pool) == 2  # no connection was lost
+    with pool.ConnectionPool(dsn, minconn=2) as p:
+        ts = [
+            Thread(target=worker, args=(i, timeout))
+            for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+        ]
+        [t.start() for t in ts]
+        [t.join() for t in ts]
+        sleep(0.2)
+        assert set(results) == set([0, 1, 3, 4])
+        assert len(p._pool) == 2  # no connection was lost
 
 
 @pytest.mark.slow
 def test_queue_timeout_override(dsn):
-    p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1)
-    results = []
-    errors = []
-
     def worker(n):
         t0 = time()
         timeout = 0.25 if n == 3 else None
@@ -240,14 +223,13 @@ def test_queue_timeout_override(dsn):
             t1 = time()
             results.append((n, t1 - t0, pid))
 
-    ts = []
-    for i in range(4):
-        t = Thread(target=worker, args=(i,))
-        t.start()
-        ts.append(t)
+    results = []
+    errors = []
 
-    for t in ts:
-        t.join()
+    with pool.ConnectionPool(dsn, minconn=2, timeout=0.1) as p:
+        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+        [t.start() for t in ts]
+        [t.join() for t in ts]
 
     assert len(results) == 3
     assert len(errors) == 1
@@ -255,37 +237,37 @@ def test_queue_timeout_override(dsn):
         assert 0.1 < e[1] < 0.15
 
 
-def test_broken_reconnect(dsn, caplog):
-    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
-    p = pool.ConnectionPool(dsn, minconn=1)
-    with pytest.raises(psycopg3.OperationalError):
-        with p.connection() as conn:
-            with conn.execute("select pg_backend_pid()") as cur:
-                (pid1,) = cur.fetchone()
-            conn.close()
+def test_broken_reconnect(dsn):
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        with pytest.raises(psycopg3.OperationalError):
+            with p.connection() as conn:
+                with conn.execute("select pg_backend_pid()") as cur:
+                    (pid1,) = cur.fetchone()
+                conn.close()
 
-    with p.connection() as conn2:
-        with conn2.execute("select pg_backend_pid()") as cur:
-            (pid2,) = cur.fetchone()
+        with p.connection() as conn2:
+            with conn2.execute("select pg_backend_pid()") as cur:
+                (pid2,) = cur.fetchone()
 
     assert pid1 != pid2
 
 
 def test_intrans_rollback(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
-    p = pool.ConnectionPool(dsn, minconn=1)
-    conn = p.getconn()
-    pid = conn.pgconn.backend_pid
-    conn.execute("create table test_intrans_rollback ()")
-    assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
-    p.putconn(conn)
-
-    with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid == pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
-        assert not conn.execute(
-            "select 1 from pg_class where relname = 'test_intrans_rollback'"
-        ).fetchone()
+
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        conn = p.getconn()
+        pid = conn.pgconn.backend_pid
+        conn.execute("create table test_intrans_rollback ()")
+        assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+        p.putconn(conn)
+
+        with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid == pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+            assert not conn.execute(
+                "select 1 from pg_class where relname = 'test_intrans_rollback'"
+            ).fetchone()
 
     assert len(caplog.records) == 1
     assert "INTRANS" in caplog.records[0].message
@@ -293,17 +275,18 @@ def test_intrans_rollback(dsn, caplog):
 
 def test_inerror_rollback(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
-    p = pool.ConnectionPool(dsn, minconn=1)
-    conn = p.getconn()
-    pid = conn.pgconn.backend_pid
-    with pytest.raises(psycopg3.ProgrammingError):
-        conn.execute("wat")
-    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
-    p.putconn(conn)
 
-    with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid == pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        conn = p.getconn()
+        pid = conn.pgconn.backend_pid
+        with pytest.raises(psycopg3.ProgrammingError):
+            conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        p.putconn(conn)
+
+        with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid == pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
 
     assert len(caplog.records) == 1
     assert "INERROR" in caplog.records[0].message
@@ -311,18 +294,19 @@ def test_inerror_rollback(dsn, caplog):
 
 def test_active_close(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
-    p = pool.ConnectionPool(dsn, minconn=1)
-    conn = p.getconn()
-    pid = conn.pgconn.backend_pid
-    cur = conn.cursor()
-    with cur.copy("copy (select * from generate_series(1, 10)) to stdout"):
-        pass
-    assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
-    p.putconn(conn)
 
-    with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid != pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        conn = p.getconn()
+        pid = conn.pgconn.backend_pid
+        cur = conn.cursor()
+        with cur.copy("copy (select * from generate_series(1, 10)) to stdout"):
+            pass
+        assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+        p.putconn(conn)
+
+        with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid != pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
 
     assert len(caplog.records) == 2
     assert "ACTIVE" in caplog.records[0].message
@@ -331,26 +315,27 @@ def test_active_close(dsn, caplog):
 
 def test_fail_rollback_close(dsn, caplog, monkeypatch):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
-    p = pool.ConnectionPool(dsn, minconn=1)
-    conn = p.getconn()
 
-    def bad_rollback():
-        conn.pgconn.finish()
-        orig_rollback()
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        conn = p.getconn()
+
+        def bad_rollback():
+            conn.pgconn.finish()
+            orig_rollback()
 
-    # Make the rollback fail
-    orig_rollback = conn.rollback
-    monkeypatch.setattr(conn, "rollback", bad_rollback)
+        # Make the rollback fail
+        orig_rollback = conn.rollback
+        monkeypatch.setattr(conn, "rollback", bad_rollback)
 
-    pid = conn.pgconn.backend_pid
-    with pytest.raises(psycopg3.ProgrammingError):
-        conn.execute("wat")
-    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
-    p.putconn(conn)
+        pid = conn.pgconn.backend_pid
+        with pytest.raises(psycopg3.ProgrammingError):
+            conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        p.putconn(conn)
 
-    with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid != pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+        with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid != pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
 
     assert len(caplog.records) == 3
     assert "INERROR" in caplog.records[0].message
@@ -371,18 +356,18 @@ def test_close_no_threads(dsn):
 
 
 def test_putconn_no_pool(dsn):
-    p = pool.ConnectionPool(dsn, minconn=1)
-    conn = psycopg3.connect(dsn)
-    with pytest.raises(ValueError):
-        p.putconn(conn)
+    with pool.ConnectionPool(dsn, minconn=1) as p:
+        conn = psycopg3.connect(dsn)
+        with pytest.raises(ValueError):
+            p.putconn(conn)
 
 
 def test_putconn_wrong_pool(dsn):
-    p1 = pool.ConnectionPool(dsn, minconn=1)
-    p2 = pool.ConnectionPool(dsn, minconn=1)
-    conn = p1.getconn()
-    with pytest.raises(ValueError):
-        p2.putconn(conn)
+    with pool.ConnectionPool(dsn, minconn=1) as p1:
+        with pool.ConnectionPool(dsn, minconn=1) as p2:
+            conn = p1.getconn()
+            with pytest.raises(ValueError):
+                p2.putconn(conn)
 
 
 def test_del_no_warning(dsn, recwarn):
@@ -402,7 +387,7 @@ def test_del_stop_threads(dsn):
     p = pool.ConnectionPool(dsn)
     ts = [p._sched_runner] + p._workers
     del p
-    sleep(0.2)
+    sleep(0.1)
     for t in ts:
         assert not t.is_alive()
 
@@ -479,21 +464,16 @@ def test_grow(dsn, monkeypatch, retries):
                 dsn, minconn=2, maxconn=4, num_workers=3
             ) as p:
                 p.wait_ready(1.0)
-                ts = []
                 results = []
 
-                for i in range(6):
-                    t = Thread(target=worker, args=(i,))
-                    t.start()
-                    ts.append(t)
-
-                for t in ts:
-                    t.join()
+                ts = [Thread(target=worker, args=(i,)) for i in range(6)]
+                [t.start() for t in ts]
+                [t.join() for t in ts]
 
-                want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
-                times = [item[1] for item in results]
-                for got, want in zip(times, want_times):
-                    assert got == pytest.approx(want, 0.1), times
+            want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
+            times = [item[1] for item in results]
+            for got, want in zip(times, want_times):
+                assert got == pytest.approx(want, 0.1), times
 
 
 @pytest.mark.slow
@@ -512,24 +492,19 @@ def test_shrink(dsn, monkeypatch):
     orig_run = ShrinkPool._run
     monkeypatch.setattr(ShrinkPool, "_run", run_hacked)
 
-    p = pool.ConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2)
-    p.wait_ready(5.0)
-    assert p.max_idle == 0.2
-
     def worker(n):
         with p.connection() as conn:
             conn.execute("select pg_sleep(0.1)")
 
-    ts = []
-    for i in range(4):
-        t = Thread(target=worker, args=(i,))
-        t.start()
-        ts.append(t)
+    with pool.ConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2) as p:
+        p.wait_ready(5.0)
+        assert p.max_idle == 0.2
 
-    for t in ts:
-        t.join()
+        ts = [Thread(target=worker, args=(i,)) for i in range(4)]
+        [t.start() for t in ts]
+        [t.join() for t in ts]
+        sleep(1)
 
-    sleep(1)
     assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
 
 
@@ -543,20 +518,20 @@ def test_reconnect(proxy, caplog, monkeypatch):
     monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
 
     proxy.start()
-    p = pool.ConnectionPool(proxy.client_dsn, minconn=1)
-    p.wait_ready(2.0)
-    proxy.stop()
+    with pool.ConnectionPool(proxy.client_dsn, minconn=1) as p:
+        p.wait_ready(2.0)
+        proxy.stop()
 
-    with pytest.raises(psycopg3.OperationalError):
-        with p.connection() as conn:
-            conn.execute("select 1")
+        with pytest.raises(psycopg3.OperationalError):
+            with p.connection() as conn:
+                conn.execute("select 1")
 
-    sleep(1.0)
-    proxy.start()
-    p.wait_ready()
+        sleep(1.0)
+        proxy.start()
+        p.wait_ready()
 
-    with p.connection() as conn:
-        conn.execute("select 1")
+        with p.connection() as conn:
+            conn.execute("select 1")
 
     assert "BAD" in caplog.messages[0]
     times = [rec.created for rec in caplog.records]
@@ -580,42 +555,42 @@ def test_reconnect_failure(proxy):
         nonlocal t1
         t1 = time()
 
-    p = pool.ConnectionPool(
+    with pool.ConnectionPool(
         proxy.client_dsn,
         name="this-one",
         minconn=1,
         reconnect_timeout=1.0,
         reconnect_failed=failed,
-    )
-    p.wait_ready(2.0)
-    proxy.stop()
+    ) as p:
+        p.wait_ready(2.0)
+        proxy.stop()
 
-    with pytest.raises(psycopg3.OperationalError):
-        with p.connection() as conn:
-            conn.execute("select 1")
+        with pytest.raises(psycopg3.OperationalError):
+            with p.connection() as conn:
+                conn.execute("select 1")
 
-    t0 = time()
-    sleep(1.5)
-    assert t1
-    assert t1 - t0 == pytest.approx(1.0, 0.1)
-    assert p._nconns == 0
+        t0 = time()
+        sleep(1.5)
+        assert t1
+        assert t1 - t0 == pytest.approx(1.0, 0.1)
+        assert p._nconns == 0
 
-    proxy.start()
-    t0 = time()
-    with p.connection() as conn:
-        conn.execute("select 1")
-    t1 = time()
-    assert t1 - t0 < 0.2
+        proxy.start()
+        t0 = time()
+        with p.connection() as conn:
+            conn.execute("select 1")
+        t1 = time()
+        assert t1 - t0 < 0.2
 
 
 @pytest.mark.slow
 def test_uniform_use(dsn):
-    p = pool.ConnectionPool(dsn, minconn=4)
-    counts = Counter()
-    for i in range(8):
-        with p.connection() as conn:
-            sleep(0.1)
-            counts[id(conn)] += 1
+    with pool.ConnectionPool(dsn, minconn=4) as p:
+        counts = Counter()
+        for i in range(8):
+            with p.connection() as conn:
+                sleep(0.1)
+                counts[id(conn)] += 1
 
     assert len(counts) == 4
     assert set(counts.values()) == set([2])
@@ -623,9 +598,6 @@ def test_uniform_use(dsn):
 
 @pytest.mark.slow
 def test_resize(dsn):
-    p = pool.ConnectionPool(dsn, minconn=2, max_idle=0.2)
-    size = []
-
     def sampler():
         sleep(0.05)  # ensure sampling happens after shrink check
         while True:
@@ -638,28 +610,29 @@ def test_resize(dsn):
         with p.connection() as conn:
             conn.execute("select pg_sleep(%s)", [t])
 
-    s = Thread(target=sampler)
-    s.start()
+    size = []
 
-    sleep(0.3)
+    with pool.ConnectionPool(dsn, minconn=2, max_idle=0.2) as p:
+        s = Thread(target=sampler)
+        s.start()
 
-    c = Thread(target=client, args=(0.4,))
-    c.start()
+        sleep(0.3)
+        c = Thread(target=client, args=(0.4,))
+        c.start()
 
-    sleep(0.2)
-    p.resize(4)
-    assert p.minconn == 4
-    assert p.maxconn == 4
+        sleep(0.2)
+        p.resize(4)
+        assert p.minconn == 4
+        assert p.maxconn == 4
 
-    sleep(0.4)
-    p.resize(2)
-    assert p.minconn == 2
-    assert p.maxconn == 2
+        sleep(0.4)
+        p.resize(2)
+        assert p.minconn == 2
+        assert p.maxconn == 2
 
-    sleep(0.6)
-    p.close()
-    s.join()
+        sleep(0.6)
 
+    s.join()
     assert size == [2, 1, 3, 4, 3, 2, 2]
 
 
index f4c282ff57229d6cd7f4ca6a507eced09545942f..0425947d3dec09cdb69189b52f8c35854ccd9780 100644 (file)
@@ -21,50 +21,45 @@ pytestmark = pytest.mark.asyncio
 
 
 async def test_defaults(dsn):
-    p = pool.AsyncConnectionPool(dsn)
-    assert p.minconn == p.maxconn == 4
-    assert p.timeout == 30
-    assert p.max_idle == 600
-    assert p.num_workers == 3
-    await p.close()
+    async with pool.AsyncConnectionPool(dsn) as p:
+        assert p.minconn == p.maxconn == 4
+        assert p.timeout == 30
+        assert p.max_idle == 600
+        assert p.num_workers == 3
 
 
 async def test_minconn_maxconn(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=2)
-    assert p.minconn == p.maxconn == 2
-    await p.close()
+    async with pool.AsyncConnectionPool(dsn, minconn=2) as p:
+        assert p.minconn == p.maxconn == 2
 
-    p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4)
-    assert p.minconn == 2
-    assert p.maxconn == 4
-    await p.close()
+    async with pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4) as p:
+        assert p.minconn == 2
+        assert p.maxconn == 4
 
     with pytest.raises(ValueError):
         pool.AsyncConnectionPool(dsn, minconn=4, maxconn=2)
 
 
 async def test_kwargs(dsn):
-    p = pool.AsyncConnectionPool(dsn, kwargs={"autocommit": True}, minconn=1)
-    async with p.connection() as conn:
-        assert conn.autocommit
-
-    await p.close()
+    async with pool.AsyncConnectionPool(
+        dsn, kwargs={"autocommit": True}, minconn=1
+    ) as p:
+        async with p.connection() as conn:
+            assert conn.autocommit
 
 
 async def test_its_really_a_pool(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=2)
-    async with p.connection() as conn:
-        cur = await conn.execute("select pg_backend_pid()")
-        (pid1,) = await cur.fetchone()
-
-        async with p.connection() as conn2:
-            cur = await conn2.execute("select pg_backend_pid()")
-            (pid2,) = await cur.fetchone()
+    async with pool.AsyncConnectionPool(dsn, minconn=2) as p:
+        async with p.connection() as conn:
+            cur = await conn.execute("select pg_backend_pid()")
+            (pid1,) = await cur.fetchone()
 
-    async with p.connection() as conn:
-        assert conn.pgconn.backend_pid in (pid1, pid2)
+            async with p.connection() as conn2:
+                cur = await conn2.execute("select pg_backend_pid()")
+                (pid2,) = await cur.fetchone()
 
-    await p.close()
+        async with p.connection() as conn:
+            assert conn.pgconn.backend_pid in (pid1, pid2)
 
 
 async def test_context(dsn):
@@ -74,16 +69,14 @@ async def test_context(dsn):
 
 
 async def test_connection_not_lost(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    with pytest.raises(ZeroDivisionError):
-        async with p.connection() as conn:
-            pid = conn.pgconn.backend_pid
-            1 / 0
-
-    async with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid == pid
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        with pytest.raises(ZeroDivisionError):
+            async with p.connection() as conn:
+                pid = conn.pgconn.backend_pid
+                1 / 0
 
-    await p.close()
+        async with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid == pid
 
 
 @pytest.mark.slow
@@ -116,35 +109,36 @@ async def test_concurrent_filling(dsn, monkeypatch, retries):
 async def test_wait_ready(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
     with pytest.raises(pool.PoolTimeout):
-        p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1)
-        await p.wait_ready(0.3)
+        async with pool.AsyncConnectionPool(
+            dsn, minconn=4, num_workers=1
+        ) as p:
+            await p.wait_ready(0.3)
 
-    p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1)
-    await p.wait_ready(0.5)
-    await p.close()
-    p = pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2)
-    await p.wait_ready(0.3)
-    await p.wait_ready(0.0001)  # idempotent
-    await p.close()
+    async with pool.AsyncConnectionPool(dsn, minconn=4, num_workers=1) as p:
+        await p.wait_ready(0.5)
+
+    async with pool.AsyncConnectionPool(dsn, minconn=4, num_workers=2) as p:
+        await p.wait_ready(0.3)
+        await p.wait_ready(0.0001)  # idempotent
 
 
 @pytest.mark.slow
 async def test_setup_no_timeout(dsn, proxy):
     with pytest.raises(pool.PoolTimeout):
-        p = pool.AsyncConnectionPool(
+        async with pool.AsyncConnectionPool(
             proxy.client_dsn, minconn=1, num_workers=1
-        )
-        await p.wait_ready(0.2)
+        ) as p:
+            await p.wait_ready(0.2)
 
-    p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1, num_workers=1)
-    await asyncio.sleep(0.5)
-    assert not p._pool
-    proxy.start()
-
-    async with p.connection() as conn:
-        await conn.execute("select 1")
+    async with pool.AsyncConnectionPool(
+        proxy.client_dsn, minconn=1, num_workers=1
+    ) as p:
+        await asyncio.sleep(0.5)
+        assert not p._pool
+        proxy.start()
 
-    await p.close()
+        async with p.connection() as conn:
+            await conn.execute("select 1")
 
 
 @pytest.mark.slow
@@ -176,10 +170,6 @@ async def test_queue(dsn, retries):
 
 @pytest.mark.slow
 async def test_queue_timeout(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1)
-    results = []
-    errors = []
-
     async def worker(n):
         t0 = time()
         try:
@@ -195,23 +185,21 @@ async def test_queue_timeout(dsn):
             t1 = time()
             results.append((n, t1 - t0, pid))
 
-    ts = [create_task(worker(i)) for i in range(4)]
-    await asyncio.gather(*ts)
+    results = []
+    errors = []
+
+    async with pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) as p:
+        ts = [create_task(worker(i)) for i in range(4)]
+        await asyncio.gather(*ts)
 
     assert len(results) == 2
     assert len(errors) == 2
     for e in errors:
         assert 0.1 < e[1] < 0.15
 
-    await p.close()
-
 
 @pytest.mark.slow
 async def test_dead_client(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=2)
-
-    results = []
-
     async def worker(i, timeout):
         try:
             async with p.connection(timeout=timeout) as conn:
@@ -221,24 +209,21 @@ async def test_dead_client(dsn):
             if timeout > 0.2:
                 raise
 
-    ts = [
-        create_task(worker(i, timeout))
-        for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
-    ]
-    await asyncio.gather(*ts)
+    async with pool.AsyncConnectionPool(dsn, minconn=2) as p:
+        results = []
+        ts = [
+            create_task(worker(i, timeout))
+            for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4])
+        ]
+        await asyncio.gather(*ts)
 
-    await asyncio.sleep(0.2)
-    assert set(results) == set([0, 1, 3, 4])
-    assert len(p._pool) == 2  # no connection was lost
-    await p.close()
+        await asyncio.sleep(0.2)
+        assert set(results) == set([0, 1, 3, 4])
+        assert len(p._pool) == 2  # no connection was lost
 
 
 @pytest.mark.slow
 async def test_queue_timeout_override(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1)
-    results = []
-    errors = []
-
     async def worker(n):
         t0 = time()
         timeout = 0.25 if n == 3 else None
@@ -255,9 +240,12 @@ async def test_queue_timeout_override(dsn):
             t1 = time()
             results.append((n, t1 - t0, pid))
 
-    ts = [create_task(worker(i)) for i in range(4)]
-    await asyncio.gather(*ts)
-    await p.close()
+    results = []
+    errors = []
+
+    async with pool.AsyncConnectionPool(dsn, minconn=2, timeout=0.1) as p:
+        ts = [create_task(worker(i)) for i in range(4)]
+        await asyncio.gather(*ts)
 
     assert len(results) == 3
     assert len(errors) == 1
@@ -266,38 +254,36 @@ async def test_queue_timeout_override(dsn):
 
 
 async def test_broken_reconnect(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    with pytest.raises(psycopg3.OperationalError):
-        async with p.connection() as conn:
-            cur = await conn.execute("select pg_backend_pid()")
-            (pid1,) = await cur.fetchone()
-            await conn.close()
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        with pytest.raises(psycopg3.OperationalError):
+            async with p.connection() as conn:
+                cur = await conn.execute("select pg_backend_pid()")
+                (pid1,) = await cur.fetchone()
+                await conn.close()
 
-    async with p.connection() as conn2:
-        cur = await conn2.execute("select pg_backend_pid()")
-        (pid2,) = await cur.fetchone()
+        async with p.connection() as conn2:
+            cur = await conn2.execute("select pg_backend_pid()")
+            (pid2,) = await cur.fetchone()
 
-    await p.close()
     assert pid1 != pid2
 
 
 async def test_intrans_rollback(dsn, caplog):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    conn = await p.getconn()
-    pid = conn.pgconn.backend_pid
-    await conn.execute("create table test_intrans_rollback ()")
-    assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
-    await p.putconn(conn)
-
-    async with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid == pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
-        cur = await conn.execute(
-            "select 1 from pg_class where relname = 'test_intrans_rollback'"
-        )
-        assert not await cur.fetchone()
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        conn = await p.getconn()
+        pid = conn.pgconn.backend_pid
+        await conn.execute("create table test_intrans_rollback ()")
+        assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+        await p.putconn(conn)
+
+        async with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid == pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+            cur = await conn.execute(
+                "select 1 from pg_class where relname = 'test_intrans_rollback'"
+            )
+            assert not await cur.fetchone()
 
-    await p.close()
     recs = [
         r
         for r in caplog.records
@@ -308,17 +294,17 @@ async def test_intrans_rollback(dsn, caplog):
 
 
 async def test_inerror_rollback(dsn, caplog):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    conn = await p.getconn()
-    pid = conn.pgconn.backend_pid
-    with pytest.raises(psycopg3.ProgrammingError):
-        await conn.execute("wat")
-    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
-    await p.putconn(conn)
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        conn = await p.getconn()
+        pid = conn.pgconn.backend_pid
+        with pytest.raises(psycopg3.ProgrammingError):
+            await conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        await p.putconn(conn)
 
-    async with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid == pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+        async with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid == pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
 
     recs = [
         r
@@ -328,26 +314,23 @@ async def test_inerror_rollback(dsn, caplog):
     assert len(recs) == 1
     assert "INERROR" in recs[0].message
 
-    await p.close()
-
 
 async def test_active_close(dsn, caplog):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    conn = await p.getconn()
-    pid = conn.pgconn.backend_pid
-    cur = conn.cursor()
-    async with cur.copy(
-        "copy (select * from generate_series(1, 10)) to stdout"
-    ):
-        pass
-    assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
-    await p.putconn(conn)
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        conn = await p.getconn()
+        pid = conn.pgconn.backend_pid
+        cur = conn.cursor()
+        async with cur.copy(
+            "copy (select * from generate_series(1, 10)) to stdout"
+        ):
+            pass
+        assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+        await p.putconn(conn)
 
-    async with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid != pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+        async with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid != pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
 
-    await p.close()
     recs = [
         r
         for r in caplog.records
@@ -359,28 +342,26 @@ async def test_active_close(dsn, caplog):
 
 
 async def test_fail_rollback_close(dsn, caplog, monkeypatch):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    conn = await p.getconn()
-
-    async def bad_rollback():
-        conn.pgconn.finish()
-        await orig_rollback()
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        conn = await p.getconn()
 
-    # Make the rollback fail
-    orig_rollback = conn.rollback
-    monkeypatch.setattr(conn, "rollback", bad_rollback)
+        async def bad_rollback():
+            conn.pgconn.finish()
+            await orig_rollback()
 
-    pid = conn.pgconn.backend_pid
-    with pytest.raises(psycopg3.ProgrammingError):
-        await conn.execute("wat")
-    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
-    await p.putconn(conn)
+        # Make the rollback fail
+        orig_rollback = conn.rollback
+        monkeypatch.setattr(conn, "rollback", bad_rollback)
 
-    async with p.connection() as conn2:
-        assert conn2.pgconn.backend_pid != pid
-        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+        pid = conn.pgconn.backend_pid
+        with pytest.raises(psycopg3.ProgrammingError):
+            await conn.execute("wat")
+        assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+        await p.putconn(conn)
 
-    await p.close()
+        async with p.connection() as conn2:
+            assert conn2.pgconn.backend_pid != pid
+            assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
 
     recs = [
         r
@@ -406,21 +387,18 @@ async def test_close_no_threads(dsn):
 
 
 async def test_putconn_no_pool(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=1)
-    conn = psycopg3.connect(dsn)
-    with pytest.raises(ValueError):
-        await p.putconn(conn)
-    await p.close()
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p:
+        conn = psycopg3.connect(dsn)
+        with pytest.raises(ValueError):
+            await p.putconn(conn)
 
 
 async def test_putconn_wrong_pool(dsn):
-    p1 = pool.AsyncConnectionPool(dsn, minconn=1)
-    p2 = pool.AsyncConnectionPool(dsn, minconn=1)
-    conn = await p1.getconn()
-    with pytest.raises(ValueError):
-        await p2.putconn(conn)
-    await p1.close()
-    await p2.close()
+    async with pool.AsyncConnectionPool(dsn, minconn=1) as p1:
+        async with pool.AsyncConnectionPool(dsn, minconn=1) as p2:
+            conn = await p1.getconn()
+            with pytest.raises(ValueError):
+                await p2.putconn(conn)
 
 
 async def test_del_no_warning(dsn, recwarn):
@@ -519,12 +497,11 @@ async def test_grow(dsn, monkeypatch, retries):
 
                 ts = [create_task(worker(i)) for i in range(6)]
                 await asyncio.gather(*ts)
-                await p.close()
 
-                want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
-                times = [item[1] for item in results]
-                for got, want in zip(times, want_times):
-                    assert got == pytest.approx(want, 0.1), times
+            want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
+            times = [item[1] for item in results]
+            for got, want in zip(times, want_times):
+                assert got == pytest.approx(want, 0.1), times
 
 
 @pytest.mark.slow
@@ -543,19 +520,21 @@ async def test_shrink(dsn, monkeypatch):
     orig_run = ShrinkPool._run_async
     monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked)
 
-    p = pool.AsyncConnectionPool(dsn, minconn=2, maxconn=4, max_idle=0.2)
-    await p.wait_ready(5.0)
-    assert p.max_idle == 0.2
-
     async def worker(n):
         async with p.connection() as conn:
             await conn.execute("select pg_sleep(0.1)")
 
-    ts = [create_task(worker(i)) for i in range(4)]
-    await asyncio.gather(*ts)
+    async with pool.AsyncConnectionPool(
+        dsn, minconn=2, maxconn=4, max_idle=0.2
+    ) as p:
+        await p.wait_ready(5.0)
+        assert p.max_idle == 0.2
+
+        ts = [create_task(worker(i)) for i in range(4)]
+        await asyncio.gather(*ts)
+
+        await asyncio.sleep(1)
 
-    await asyncio.sleep(1)
-    await p.close()
     assert results == [(4, 4), (4, 3), (3, 2), (2, 2), (2, 2)]
 
 
@@ -567,22 +546,20 @@ async def test_reconnect(proxy, caplog, monkeypatch):
     monkeypatch.setattr(pool.base.ConnectionAttempt, "DELAY_JITTER", 0.0)
 
     proxy.start()
-    p = pool.AsyncConnectionPool(proxy.client_dsn, minconn=1)
-    await p.wait_ready(2.0)
-    proxy.stop()
+    async with pool.AsyncConnectionPool(proxy.client_dsn, minconn=1) as p:
+        await p.wait_ready(2.0)
+        proxy.stop()
 
-    with pytest.raises(psycopg3.OperationalError):
-        async with p.connection() as conn:
-            await conn.execute("select 1")
-
-    await asyncio.sleep(1.0)
-    proxy.start()
-    await p.wait_ready()
+        with pytest.raises(psycopg3.OperationalError):
+            async with p.connection() as conn:
+                await conn.execute("select 1")
 
-    async with p.connection() as conn:
-        await conn.execute("select 1")
+        await asyncio.sleep(1.0)
+        proxy.start()
+        await p.wait_ready()
 
-    await p.close()
+        async with p.connection() as conn:
+            await conn.execute("select 1")
 
     recs = [
         r
@@ -611,54 +588,49 @@ async def test_reconnect_failure(proxy):
         nonlocal t1
         t1 = time()
 
-    p = pool.AsyncConnectionPool(
+    async with pool.AsyncConnectionPool(
         proxy.client_dsn,
         name="this-one",
         minconn=1,
         reconnect_timeout=1.0,
         reconnect_failed=failed,
-    )
-    await p.wait_ready(2.0)
-    proxy.stop()
+    ) as p:
+        await p.wait_ready(2.0)
+        proxy.stop()
 
-    with pytest.raises(psycopg3.OperationalError):
-        async with p.connection() as conn:
-            await conn.execute("select 1")
+        with pytest.raises(psycopg3.OperationalError):
+            async with p.connection() as conn:
+                await conn.execute("select 1")
 
-    t0 = time()
-    await asyncio.sleep(1.5)
-    assert t1
-    assert t1 - t0 == pytest.approx(1.0, 0.1)
-    assert p._nconns == 0
+        t0 = time()
+        await asyncio.sleep(1.5)
+        assert t1
+        assert t1 - t0 == pytest.approx(1.0, 0.1)
+        assert p._nconns == 0
 
-    proxy.start()
-    t0 = time()
-    async with p.connection() as conn:
-        await conn.execute("select 1")
-    t1 = time()
-    assert t1 - t0 < 0.2
-    await p.close()
+        proxy.start()
+        t0 = time()
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+        t1 = time()
+        assert t1 - t0 < 0.2
 
 
 @pytest.mark.slow
 async def test_uniform_use(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=4)
-    counts = Counter()
-    for i in range(8):
-        async with p.connection() as conn:
-            await asyncio.sleep(0.1)
-            counts[id(conn)] += 1
+    async with pool.AsyncConnectionPool(dsn, minconn=4) as p:
+        counts = Counter()
+        for i in range(8):
+            async with p.connection() as conn:
+                await asyncio.sleep(0.1)
+                counts[id(conn)] += 1
 
-    await p.close()
     assert len(counts) == 4
     assert set(counts.values()) == set([2])
 
 
 @pytest.mark.slow
 async def test_resize(dsn):
-    p = pool.AsyncConnectionPool(dsn, minconn=2, max_idle=0.2)
-    size = []
-
     async def sampler():
         await asyncio.sleep(0.05)  # ensure sampling happens after shrink check
         while True:
@@ -671,26 +643,28 @@ async def test_resize(dsn):
         async with p.connection() as conn:
             await conn.execute("select pg_sleep(%s)", [t])
 
-    s = create_task(sampler())
+    size = []
 
-    await asyncio.sleep(0.3)
+    async with pool.AsyncConnectionPool(dsn, minconn=2, max_idle=0.2) as p:
+        s = create_task(sampler())
 
-    c = create_task(client(0.4))
+        await asyncio.sleep(0.3)
 
-    await asyncio.sleep(0.2)
-    await p.resize(4)
-    assert p.minconn == 4
-    assert p.maxconn == 4
+        c = create_task(client(0.4))
 
-    await asyncio.sleep(0.4)
-    await p.resize(2)
-    assert p.minconn == 2
-    assert p.maxconn == 2
+        await asyncio.sleep(0.2)
+        await p.resize(4)
+        assert p.minconn == 4
+        assert p.maxconn == 4
 
-    await asyncio.sleep(0.6)
-    await p.close()
-    await asyncio.gather(s, c)
+        await asyncio.sleep(0.4)
+        await p.resize(2)
+        assert p.minconn == 2
+        assert p.maxconn == 2
 
+        await asyncio.sleep(0.6)
+
+    await asyncio.gather(s, c)
     assert size == [2, 1, 3, 4, 3, 2, 2]