-import concurrent.futures
import logging
+from threading import Thread, Event
import pytest
assert "(terminated)" in str(tx)
-def test_concurrency(conn):
+@pytest.mark.parametrize("fail", [False, True])
+def test_concurrency(conn, fail):
conn.autocommit = True
- def fn(value):
- with conn.transaction():
- cur = conn.execute("select %s", (value,))
- return cur
+ e = [Event() for i in range(3)]
- values = range(2)
- with concurrent.futures.ThreadPoolExecutor() as e:
- cursors = e.map(fn, values)
- assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
+ def worker(unlock, wait_on):
+ with pytest.raises(ProgrammingError):
+ with conn.transaction():
+ unlock.set()
+ wait_on.wait()
+ conn.execute("select 1")
+ if fail:
+ 1 / 0
+
+ # Start a first transaction in a thread
+ t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]})
+ t1.start()
+ e[0].wait()
+
+ # Start a nested transaction in a thread
+ t2 = Thread(target=worker, kwargs={"unlock": e[1], "wait_on": e[2]})
+ t2.start()
+
+ # Terminate the first transaction before the second does
+ t1.join()
+ e[2].set()
+ t2.join()
import pytest
from psycopg import AsyncConnection, ProgrammingError, Rollback
+from psycopg._compat import create_task
from .test_transaction import in_transaction, insert_row, inserted
from .test_transaction import ExpectedException
assert "(terminated)" in str(tx)
-async def test_concurrency(aconn):
+@pytest.mark.parametrize("fail", [False, True])
+async def test_concurrency(aconn, fail):
await aconn.set_autocommit(True)
- async def fn(value):
- async with aconn.transaction():
- cur = await aconn.execute("select %s", (value,))
- return cur
+ e = [asyncio.Event() for i in range(3)]
- values = range(2)
- cursors = await asyncio.gather(*[fn(value) for value in values])
- assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+ async def worker(unlock, wait_on):
+ with pytest.raises(ProgrammingError):
+ async with aconn.transaction():
+ unlock.set()
+ await wait_on.wait()
+ await aconn.execute("select 1")
+ if fail:
+ 1 / 0
+
+ # Start a first transaction in a task
+ t1 = create_task(worker(unlock=e[0], wait_on=e[1]))
+ await e[0].wait()
+
+ # Start a nested transaction in a task
+ t2 = create_task(worker(unlock=e[1], wait_on=e[2]))
+
+ # Terminate the first transaction before the second does
+ await asyncio.gather(t1)
+ e[2].set()
+ await asyncio.gather(t2)