]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move two-phase transaction fixture to a common place
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 30 Oct 2021 17:44:37 +0000 (19:44 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 28 Nov 2021 17:04:31 +0000 (18:04 +0100)
Importing it across test cases requires too silly workarounds.

Also do without attrgetter.

tests/fix_psycopg.py
tests/test_psycopg_dbapi20.py
tests/test_tpc.py
tests/test_tpc_async.py

index f036c9857073397fd46a131a872ab5d45585f59b..ff52c11f9eed118aec80bc568788b28c15bcae7d 100644 (file)
@@ -21,3 +21,59 @@ def global_adapters():
     adapters.types.clear()
     for t in types:
         adapters.types.add(t)
+
+
+@pytest.fixture
+def tpc(svcconn):
+    tpc = Tpc(svcconn)
+    tpc.check_tpc()
+    tpc.clear_test_xacts()
+    tpc.make_test_table()
+    yield tpc
+    tpc.clear_test_xacts()
+
+
+class Tpc:
+    """Helper object to test two-phase transactions"""
+
+    def __init__(self, conn):
+        assert conn.autocommit
+        self.conn = conn
+
+    def check_tpc(self):
+        val = int(
+            self.conn.execute("show max_prepared_transactions").fetchone()[0]
+        )
+        if not val:
+            pytest.skip("prepared transactions disabled in the database")
+
+    def clear_test_xacts(self):
+        """Rollback all the prepared transaction in the testing db."""
+        from psycopg import sql
+
+        cur = self.conn.execute(
+            "select gid from pg_prepared_xacts where database = %s",
+            (self.conn.info.dbname,),
+        )
+        gids = [r[0] for r in cur]
+        for gid in gids:
+            self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
+
+    def make_test_table(self):
+        self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
+        self.conn.execute("TRUNCATE test_tpc")
+
+    def count_xacts(self):
+        """Return the number of prepared xacts currently in the test db."""
+        cur = self.conn.execute(
+            """
+            select count(*) from pg_prepared_xacts
+            where database = %s""",
+            (self.conn.info.dbname,),
+        )
+        return cur.fetchone()[0]
+
+    def count_test_records(self):
+        """Return the number of records in the test table."""
+        cur = self.conn.execute("select count(*) from test_tpc")
+        return cur.fetchone()[0]
index 0a40e4a2035534b6c305762876c21f6631af19df..4aa1cc7c7d3bd8a2f84282323acb0b71a970750f 100644 (file)
@@ -8,8 +8,6 @@ from psycopg.conninfo import conninfo_to_dict
 from . import dbapi20
 from . import dbapi20_tpc
 
-from .test_tpc import tpc  # noqa F401  # fixture
-
 
 @pytest.fixture(scope="class")
 def with_dsn(request, dsn):
index 127e70a8a770a51550ef70770e18dcab1c302027..29dd7dd280722f07caa463ce05422090f6838d1e 100644 (file)
@@ -1,9 +1,6 @@
-from operator import attrgetter
-
 import pytest
 
 import psycopg
-from psycopg import sql
 
 
 def test_tpc_disabled(conn):
@@ -177,7 +174,7 @@ class TestTPC:
 
         xids = conn.tpc_recover()
         xids = [xid for xid in xids if xid.database == conn.info.dbname]
-        xids.sort(key=attrgetter("gtrid"))
+        xids.sort(key=lambda x: x.gtrid)
 
         # check the values returned
         assert len(okvals) == len(xids)
@@ -327,57 +324,3 @@ class TestXidObject:
 
         x2 = psycopg.Xid.from_string("99_xxx_yyy")
         str(x2) == "99_xxx_yyy"
-
-
-@pytest.fixture
-def tpc(svcconn):
-    tpc = Tpc(svcconn)
-    tpc.check_tpc()
-    tpc.clear_test_xacts()
-    tpc.make_test_table()
-    yield tpc
-    tpc.clear_test_xacts()
-
-
-class Tpc:
-    """Helper object to test two-phase transactions"""
-
-    def __init__(self, conn):
-        assert conn.autocommit
-        self.conn = conn
-
-    def check_tpc(self):
-        val = int(
-            self.conn.execute("show max_prepared_transactions").fetchone()[0]
-        )
-        if not val:
-            pytest.skip("prepared transactions disabled in the database")
-
-    def clear_test_xacts(self):
-        """Rollback all the prepared transaction in the testing db."""
-        cur = self.conn.execute(
-            "select gid from pg_prepared_xacts where database = %s",
-            (self.conn.info.dbname,),
-        )
-        gids = [r[0] for r in cur]
-        for gid in gids:
-            self.conn.execute(sql.SQL("rollback prepared {}").format(gid))
-
-    def make_test_table(self):
-        self.conn.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
-        self.conn.execute("TRUNCATE test_tpc")
-
-    def count_xacts(self):
-        """Return the number of prepared xacts currently in the test db."""
-        cur = self.conn.execute(
-            """
-            select count(*) from pg_prepared_xacts
-            where database = %s""",
-            (self.conn.info.dbname,),
-        )
-        return cur.fetchone()[0]
-
-    def count_test_records(self):
-        """Return the number of records in the test table."""
-        cur = self.conn.execute("select count(*) from test_tpc")
-        return cur.fetchone()[0]
index d8fdaf433bba8e26ea07eef150d9aeeefcd1e77b..76aa6459bbd2d1475eb41e717714c095951ad1c6 100644 (file)
@@ -1,13 +1,8 @@
-from operator import attrgetter
-
 import pytest
 
 import psycopg
 
-from .test_tpc import tpc  # noqa: F401  # fixture
-
 pytestmark = [pytest.mark.asyncio]
-tpc = tpc  # Silence F811 in the rest of the file
 
 
 async def test_tpc_disabled(aconn):
@@ -188,7 +183,7 @@ class TestTPC:
 
         xids = await aconn.tpc_recover()
         xids = [xid for xid in xids if xid.database == aconn.info.dbname]
-        xids.sort(key=attrgetter("gtrid"))
+        xids.sort(key=lambda x: x.gtrid)
 
         # check the values returned
         assert len(okvals) == len(xids)