]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make the test to detect out-of-order transaction exiting reliable
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Dec 2021 19:54:34 +0000 (20:54 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Dec 2021 19:54:34 +0000 (20:54 +0100)
tests/test_transaction.py
tests/test_transaction_async.py

index d231c916131eb7fa443c34ccf9f512349426c1bf..f569bfc61eca7dfe53c28027264d56567e03d530 100644 (file)
@@ -1,5 +1,5 @@
-import concurrent.futures
 import logging
+from threading import Thread, Event
 
 import pytest
 
@@ -651,15 +651,31 @@ def test_str(conn):
     assert "(terminated)" in str(tx)
 
 
-def test_concurrency(conn):
+@pytest.mark.parametrize("fail", [False, True])
+def test_concurrency(conn, fail):
     conn.autocommit = True
 
-    def fn(value):
-        with conn.transaction():
-            cur = conn.execute("select %s", (value,))
-        return cur
+    e = [Event() for i in range(3)]
 
-    values = range(2)
-    with concurrent.futures.ThreadPoolExecutor() as e:
-        cursors = e.map(fn, values)
-    assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+    def worker(unlock, wait_on):
+        with pytest.raises(ProgrammingError):
+            with conn.transaction():
+                unlock.set()
+                wait_on.wait()
+                conn.execute("select 1")
+                if fail:
+                    1 / 0
+
+    # Start a first transaction in a thread
+    t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]})
+    t1.start()
+    e[0].wait()
+
+    # Start a nested transaction in a thread
+    t2 = Thread(target=worker, kwargs={"unlock": e[1], "wait_on": e[2]})
+    t2.start()
+
+    # Terminate the first transaction before the second does
+    t1.join()
+    e[2].set()
+    t2.join()
index d09103b90d605bfebc53e85d602ea22cb45dfeec..6b108d57a1a066cc4e88d1bb733d7541a9cf3585 100644 (file)
@@ -4,6 +4,7 @@ import logging
 import pytest
 
 from psycopg import AsyncConnection, ProgrammingError, Rollback
+from psycopg._compat import create_task
 
 from .test_transaction import in_transaction, insert_row, inserted
 from .test_transaction import ExpectedException
@@ -618,14 +619,29 @@ async def test_str(aconn):
     assert "(terminated)" in str(tx)
 
 
-async def test_concurrency(aconn):
+@pytest.mark.parametrize("fail", [False, True])
+async def test_concurrency(aconn, fail):
     await aconn.set_autocommit(True)
 
-    async def fn(value):
-        async with aconn.transaction():
-            cur = await aconn.execute("select %s", (value,))
-        return cur
+    e = [asyncio.Event() for i in range(3)]
 
-    values = range(2)
-    cursors = await asyncio.gather(*[fn(value) for value in values])
-    assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+    async def worker(unlock, wait_on):
+        with pytest.raises(ProgrammingError):
+            async with aconn.transaction():
+                unlock.set()
+                await wait_on.wait()
+                await aconn.execute("select 1")
+                if fail:
+                    1 / 0
+
+    # Start a first transaction in a task
+    t1 = create_task(worker(unlock=e[0], wait_on=e[1]))
+    await e[0].wait()
+
+    # Start a nested transaction in a task
+    t2 = create_task(worker(unlock=e[1], wait_on=e[2]))
+
+    # Terminate the first transaction before the second does
+    await asyncio.gather(t1)
+    e[2].set()
+    await asyncio.gather(t2)