From 880874ef6b30c2264b09fb84ac1af11d579d3e4b Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 8 Dec 2021 20:54:34 +0100 Subject: [PATCH] Make the test to detect out-of-order transaction exiting reliable --- tests/test_transaction.py | 36 ++++++++++++++++++++++++--------- tests/test_transaction_async.py | 32 +++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index d231c9161..f569bfc61 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,5 +1,5 @@ -import concurrent.futures import logging +from threading import Thread, Event import pytest @@ -651,15 +651,31 @@ def test_str(conn): assert "(terminated)" in str(tx) -def test_concurrency(conn): +@pytest.mark.parametrize("fail", [False, True]) +def test_concurrency(conn, fail): conn.autocommit = True - def fn(value): - with conn.transaction(): - cur = conn.execute("select %s", (value,)) - return cur + e = [Event() for i in range(3)] - values = range(2) - with concurrent.futures.ThreadPoolExecutor() as e: - cursors = e.map(fn, values) - assert sum(cur.fetchone()[0] for cur in cursors) == sum(values) + def worker(unlock, wait_on): + with pytest.raises(ProgrammingError): + with conn.transaction(): + unlock.set() + wait_on.wait() + conn.execute("select 1") + if fail: + 1 / 0 + + # Start a first transaction in a thread + t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]}) + t1.start() + e[0].wait() + + # Start a nested transaction in a thread + t2 = Thread(target=worker, kwargs={"unlock": e[1], "wait_on": e[2]}) + t2.start() + + # Terminate the first transaction before the second does + t1.join() + e[2].set() + t2.join() diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index d09103b90..6b108d57a 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -4,6 +4,7 @@ import logging import pytest from psycopg import AsyncConnection, ProgrammingError, Rollback +from psycopg._compat import create_task from .test_transaction import in_transaction, insert_row, inserted from .test_transaction import ExpectedException @@ -618,14 +619,29 @@ async def test_str(aconn): assert "(terminated)" in str(tx) -async def test_concurrency(aconn): +@pytest.mark.parametrize("fail", [False, True]) +async def test_concurrency(aconn, fail): await aconn.set_autocommit(True) - async def fn(value): - async with aconn.transaction(): - cur = await aconn.execute("select %s", (value,)) - return cur + e = [asyncio.Event() for i in range(3)] - values = range(2) - cursors = await asyncio.gather(*[fn(value) for value in values]) - assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values) + async def worker(unlock, wait_on): + with pytest.raises(ProgrammingError): + async with aconn.transaction(): + unlock.set() + await wait_on.wait() + await aconn.execute("select 1") + if fail: + 1 / 0 + + # Start a first transaction in a task + t1 = create_task(worker(unlock=e[0], wait_on=e[1])) + await e[0].wait() + + # Start a nested transaction in a task + t2 = create_task(worker(unlock=e[1], wait_on=e[2])) + + # Terminate the first transaction before the second does + await asyncio.gather(t1) + e[2].set() + await asyncio.gather(t2) -- 2.47.2