]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Solve race conditions in test
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 13 Nov 2021 17:15:58 +0000 (18:15 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 13 Nov 2021 21:54:15 +0000 (22:54 +0100)
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/test_concurrency.py

index c0779dbb33b122f79d92c62be8470c3cc0b82146..177044decac8e6db20fefa70cf8d03a11501d807 100644 (file)
@@ -168,7 +168,7 @@ def test_configure(dsn):
             conn.execute("set default_transaction_read_only to on")
 
     with pool.ConnectionPool(dsn, min_size=1, configure=configure) as p:
-        p.wait(timeout=1.0)
+        p.wait()
         with p.connection() as conn:
             assert inits == 1
             res = conn.execute("show default_transaction_read_only")
@@ -636,34 +636,47 @@ def test_closed_putconn(dsn):
     assert conn.closed
 
 
-@pytest.mark.slow
-def test_closed_queue(dsn, retries):
+def test_closed_queue(dsn):
     def w1():
         with p.connection() as conn:
-            cur = conn.execute("select 1 from pg_sleep(0.2)")
-            assert cur.fetchone()[0] == 1  # type: ignore[index]
+            e1.set()  # Tell w0 that w1 got a connection
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+            e2.wait()  # Wait until w0 has tested w2
         success.append("w1")
 
     def w2():
-        with pytest.raises(pool.PoolClosed):
+        try:
             with p.connection():
-                pass
-        success.append("w2")
+                pass  # unexpected
+        except pool.PoolClosed:
+            success.append("w2")
 
-    for retry in retries:
-        with retry:
-            p = pool.ConnectionPool(dsn, min_size=1)
-            success: List[str] = []
-
-            t1 = Thread(target=w1)
-            t2 = Thread(target=w2)
-            t1.start()
-            sleep(0.1)
-            t2.start()
-            p.close()
-            t1.join()
-            t2.join()
-            assert len(success) == 2
+    e1 = Event()
+    e2 = Event()
+
+    p = pool.ConnectionPool(dsn, min_size=1)
+    p.wait()
+    success: List[str] = []
+
+    t1 = Thread(target=w1)
+    t1.start()
+    # Wait until w1 has received a connection
+    e1.wait()
+
+    t2 = Thread(target=w2)
+    t2.start()
+    # Wait until w2 is in the queue
+    while not p._waiting:
+        sleep(0)
+
+    p.close(0)
+
+    # Wait for the workers to finish
+    e2.set()
+    t1.join()
+    t2.join()
+    assert len(success) == 2
 
 
 @pytest.mark.slow
index 96c39ed90cdc7db6e4f0a14a4166f42137a72e22..d7d2fdca4aa5164d7da6d4e42f00ecf0ce4dd8e3 100644 (file)
@@ -629,31 +629,44 @@ async def test_closed_putconn(dsn):
     assert conn.closed
 
 
-@pytest.mark.slow
-async def test_closed_queue(dsn, retries):
+async def test_closed_queue(dsn):
     async def w1():
         async with p.connection() as conn:
-            res = await conn.execute("select 1 from pg_sleep(0.2)")
-            assert await res.fetchone() == (1,)
+            e1.set()  # Tell w0 that w1 got a connection
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+            await e2.wait()  # Wait until w0 has tested w2
         success.append("w1")
 
     async def w2():
-        with pytest.raises(pool.PoolClosed):
+        try:
             async with p.connection():
-                pass
-        success.append("w2")
+                pass  # unexpected
+        except pool.PoolClosed:
+            success.append("w2")
 
-    async for retry in retries:
-        with retry:
-            p = pool.AsyncConnectionPool(dsn, min_size=1)
-            success: List[str] = []
-
-            t1 = create_task(w1())
-            await asyncio.sleep(0.1)
-            t2 = create_task(w2())
-            await p.close()
-            await asyncio.gather(t1, t2)
-            assert len(success) == 2
+    e1 = asyncio.Event()
+    e2 = asyncio.Event()
+
+    p = pool.AsyncConnectionPool(dsn, min_size=1)
+    await p.wait()
+    success: List[str] = []
+
+    t1 = create_task(w1())
+    # Wait until w1 has received a connection
+    await e1.wait()
+
+    t2 = create_task(w2())
+    # Wait until w2 is in the queue
+    while not p._waiting:
+        await asyncio.sleep(0)
+
+    await p.close()
+
+    # Wait for the workers to finish
+    e2.set()
+    await asyncio.gather(t1, t2)
+    assert len(success) == 2
 
 
 @pytest.mark.slow
index 625e89d5a70cbab734fd2a473638493f91bbaab7..0fd628ab8f879102ab60e5b37ac7fb617ffbdcf4 100644 (file)
@@ -59,6 +59,7 @@ def test_commit_concurrency(conn):
 
     # Stop the committer thread
     stop = True
+    t1.join()
 
     assert notices.empty(), "%d notices raised" % notices.qsize()