From ccd8349cb9229d4c67a564b4496e0e5e90a9fde5 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 30 Oct 2021 19:44:37 +0200 Subject: [PATCH] Move two-phase transaction fixture to a common place Importing it across test cases requires too silly workarounds. Also do without attrgetter. --- tests/fix_psycopg.py | 56 +++++++++++++++++++++++++++++++++ tests/test_psycopg_dbapi20.py | 2 -- tests/test_tpc.py | 59 +---------------------------------- tests/test_tpc_async.py | 7 +---- 4 files changed, 58 insertions(+), 66 deletions(-) diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py index f036c9857..ff52c11f9 100644 --- a/tests/fix_psycopg.py +++ b/tests/fix_psycopg.py @@ -21,3 +21,59 @@ def global_adapters(): adapters.types.clear() for t in types: adapters.types.add(t) + + +@pytest.fixture +def tpc(svcconn): + tpc = Tpc(svcconn) + tpc.check_tpc() + tpc.clear_test_xacts() + tpc.make_test_table() + yield tpc + tpc.clear_test_xacts() + + +class Tpc: + """Helper object to test two-phase transactions""" + + def __init__(self, conn): + assert conn.autocommit + self.conn = conn + + def check_tpc(self): + val = int( + self.conn.execute("show max_prepared_transactions").fetchone()[0] + ) + if not val: + pytest.skip("prepared transactions disabled in the database") + + def clear_test_xacts(self): + """Rollback all the prepared transaction in the testing db.""" + from psycopg import sql + + cur = self.conn.execute( + "select gid from pg_prepared_xacts where database = %s", + (self.conn.info.dbname,), + ) + gids = [r[0] for r in cur] + for gid in gids: + self.conn.execute(sql.SQL("rollback prepared {}").format(gid)) + + def make_test_table(self): + self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)") + self.conn.execute("TRUNCATE test_tpc") + + def count_xacts(self): + """Return the number of prepared xacts currently in the test db.""" + cur = self.conn.execute( + """ + select count(*) from pg_prepared_xacts + where database = %s""", + (self.conn.info.dbname,), + ) + return cur.fetchone()[0] + + def count_test_records(self): + """Return the number of records in the test table.""" + cur = self.conn.execute("select count(*) from test_tpc") + return cur.fetchone()[0] diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 0a40e4a20..4aa1cc7c7 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -8,8 +8,6 @@ from psycopg.conninfo import conninfo_to_dict from . import dbapi20 from . import dbapi20_tpc -from .test_tpc import tpc # noqa F401 # fixture - @pytest.fixture(scope="class") def with_dsn(request, dsn): diff --git a/tests/test_tpc.py b/tests/test_tpc.py index 127e70a8a..29dd7dd28 100644 --- a/tests/test_tpc.py +++ b/tests/test_tpc.py @@ -1,9 +1,6 @@ -from operator import attrgetter - import pytest import psycopg -from psycopg import sql def test_tpc_disabled(conn): @@ -177,7 +174,7 @@ class TestTPC: xids = conn.tpc_recover() xids = [xid for xid in xids if xid.database == conn.info.dbname] - xids.sort(key=attrgetter("gtrid")) + xids.sort(key=lambda x: x.gtrid) # check the values returned assert len(okvals) == len(xids) @@ -327,57 +324,3 @@ class TestXidObject: x2 = psycopg.Xid.from_string("99_xxx_yyy") str(x2) == "99_xxx_yyy" - - -@pytest.fixture -def tpc(svcconn): - tpc = Tpc(svcconn) - tpc.check_tpc() - tpc.clear_test_xacts() - tpc.make_test_table() - yield tpc - tpc.clear_test_xacts() - - -class Tpc: - """Helper object to test two-phase transactions""" - - def __init__(self, conn): - assert conn.autocommit - self.conn = conn - - def check_tpc(self): - val = int( - self.conn.execute("show max_prepared_transactions").fetchone()[0] - ) - if not val: - pytest.skip("prepared transactions disabled in the database") - - def clear_test_xacts(self): - """Rollback all the prepared transaction in the testing db.""" - cur = self.conn.execute( - "select gid from pg_prepared_xacts where database = %s", - (self.conn.info.dbname,), - ) - gids = [r[0] for r in cur] - for gid in gids: - self.conn.execute(sql.SQL("rollback prepared {}").format(gid)) - - def make_test_table(self): - self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)") - self.conn.execute("TRUNCATE test_tpc") - - def count_xacts(self): - """Return the number of prepared xacts currently in the test db.""" - cur = self.conn.execute( - """ - select count(*) from pg_prepared_xacts - where database = %s""", - (self.conn.info.dbname,), - ) - return cur.fetchone()[0] - - def count_test_records(self): - """Return the number of records in the test table.""" - cur = self.conn.execute("select count(*) from test_tpc") - return cur.fetchone()[0] diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py index d8fdaf433..76aa6459b 100644 --- a/tests/test_tpc_async.py +++ b/tests/test_tpc_async.py @@ -1,13 +1,8 @@ -from operator import attrgetter - import pytest import psycopg -from .test_tpc import tpc # noqa: F401 # fixture - pytestmark = [pytest.mark.asyncio] -tpc = tpc # Silence F811 in the rest of the file async def test_tpc_disabled(aconn): @@ -188,7 +183,7 @@ class TestTPC: xids = await aconn.tpc_recover() xids = [xid for xid in xids if xid.database == aconn.info.dbname] - xids.sort(key=attrgetter("gtrid")) + xids.sort(key=lambda x: x.gtrid) # check the values returned assert len(okvals) == len(xids) -- 2.47.2