]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Transaction tests cleanup
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 16:26:38 +0000 (16:26 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 Nov 2020 15:27:57 +0000 (15:27 +0000)
tests/test_transaction.py
tests/test_transaction_async.py

index c5d8890bde85e68c701e55cd959ae19aab2a0660..bc90e9faf1a896dc29e3d51503c38e21d474b9a2 100644 (file)
@@ -1,17 +1,11 @@
-import sys
-
 import pytest
 
 from psycopg3 import Connection, ProgrammingError, Rollback
-from psycopg3.sql import Composable
-from psycopg3.transaction import Transaction
 
 
 @pytest.fixture(autouse=True)
 def create_test_table(svcconn):
-    """
-    Creates a table called 'test_table' for use in tests.
-    """
+    """Creates a table called 'test_table' for use in tests."""
     cur = svcconn.cursor()
     cur.execute("drop table if exists test_table")
     cur.execute("create table test_table (id text primary key)")
@@ -20,10 +14,20 @@ def create_test_table(svcconn):
 
 
 def insert_row(conn, value):
-    conn.cursor().execute("INSERT INTO test_table VALUES (%s)", (value,))
+    sql = "INSERT INTO test_table VALUES (%s)"
+    if isinstance(conn, Connection):
+        conn.cursor().execute(sql, (value,))
+    else:
+
+        async def f():
+            cur = await conn.cursor()
+            await cur.execute(sql, (value,))
+
+        return f()
 
 
 def inserted(conn):
+    """Return the values inserted in the test table."""
     sql = "SELECT * FROM test_table"
     if isinstance(conn, Connection):
         rows = conn.cursor().execute(sql).fetchall()
@@ -39,26 +43,29 @@ def inserted(conn):
         return f()
 
 
-@pytest.fixture
-def commands(monkeypatch):
-    """The queue of commands issued internally by connections.
+class ListPopAll(list):
+    """A list, with a popall() method."""
 
-    Not concurrency safe as it mokkeypatches a class method, but good enough
-    for tests.
-    """
-    _orig_exec_command = Connection._exec_command
-    L = []
+    def popall(self):
+        out = self[:]
+        del self[:]
+        return out
 
-    def _exec_command(self, command):
+
+@pytest.fixture
+def commands(conn, monkeypatch):
+    """The queue of commands issued internally by the test connection."""
+    _orig_exec_command = conn._exec_command
+    L = ListPopAll()
+
+    def _exec_command(command):
         if isinstance(command, bytes):
-            command = command.decode(self.client_encoding)
-        elif isinstance(command, Composable):
-            command = command.as_string(self)
+            command = command.decode(conn.client_encoding)
 
         L.insert(0, command)
-        _orig_exec_command(self, command)
+        _orig_exec_command(command)
 
-    monkeypatch.setattr(Connection, "_exec_command", _exec_command)
+    monkeypatch.setattr(conn, "_exec_command", _exec_command)
     yield L
 
 
@@ -75,13 +82,6 @@ class ExpectedException(Exception):
     pass
 
 
-def some_exc_info():
-    try:
-        raise ExpectedException()
-    except ExpectedException:
-        return sys.exc_info()
-
-
 def test_basic(conn):
     """Basic use of transaction() to BEGIN and COMMIT a transaction."""
     assert not in_transaction(conn)
@@ -372,61 +372,45 @@ def test_named_savepoints_successful_exit(conn, commands):
     """
     # Case 1
     # Using Transaction explicitly becase conn.transaction() enters the contetx
-    tx = Transaction(conn)
     assert not commands
-    tx.__enter__()
-    assert commands.pop() == "begin"
-    assert not tx.savepoint_name
-    tx.__exit__(None, None, None)
-    assert commands.pop() == "commit"
+    with conn.transaction() as tx:
+        assert commands.popall() == ["begin"]
+        assert not tx.savepoint_name
+    assert commands.popall() == ["commit"]
 
     # Case 1 (with a transaction already started)
     conn.cursor().execute("select 1")
-    assert commands.pop() == "begin"
-    tx = Transaction(conn)
-    tx.__enter__()
-    assert commands.pop() == 'savepoint "s1"'
-    assert tx.savepoint_name == "s1"
-    tx.__exit__(None, None, None)
-    assert commands.pop() == 'release savepoint "s1"'
-    assert not commands
+    assert commands.popall() == ["begin"]
+    with conn.transaction() as tx:
+        assert commands.popall() == ['savepoint "s1"']
+        assert tx.savepoint_name == "s1"
+    assert commands.popall() == ['release savepoint "s1"']
     conn.rollback()
-    assert commands.pop() == "rollback"
-    assert not commands
+    assert commands.popall() == ["rollback"]
 
     # Case 2
-    tx = Transaction(conn, savepoint_name="foo")
-    tx.__enter__()
-    assert commands.pop() == 'begin; savepoint "foo"'
-    assert tx.savepoint_name == "foo"
-    tx.__exit__(None, None, None)
-    assert commands.pop() == "commit"
+    with conn.transaction(savepoint_name="foo") as tx:
+        assert commands.popall() == ['begin; savepoint "foo"']
+        assert tx.savepoint_name == "foo"
+    assert commands.popall() == ["commit"]
 
     # Case 3 (with savepoint name provided)
     with conn.transaction():
-        assert commands.pop() == "begin"
-        tx = Transaction(conn, savepoint_name="bar")
-        tx.__enter__()
-        assert commands.pop() == 'savepoint "bar"'
-        assert tx.savepoint_name == "bar"
-        tx.__exit__(None, None, None)
-        assert commands.pop() == 'release savepoint "bar"'
-        assert not commands
-    assert commands.pop() == "commit"
+        assert commands.popall() == ["begin"]
+        with conn.transaction(savepoint_name="bar") as tx:
+            assert commands.popall() == ['savepoint "bar"']
+            assert tx.savepoint_name == "bar"
+        assert commands.popall() == ['release savepoint "bar"']
+    assert commands.popall() == ["commit"]
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
-        assert commands.pop() == "begin"
-        tx = Transaction(conn)
-        tx.__enter__()
-        assert commands.pop() == 'savepoint "s2"'
-        assert tx.savepoint_name == "s2"
-        tx.__exit__(None, None, None)
-        assert commands.pop() == 'release savepoint "s2"'
-        assert not commands
-    assert commands.pop() == "commit"
-
-    assert not commands
+        assert commands.popall() == ["begin"]
+        with conn.transaction() as tx:
+            assert commands.popall() == ['savepoint "s2"']
+            assert tx.savepoint_name == "s2"
+        assert commands.popall() == ['release savepoint "s2"']
+    assert commands.popall() == ["commit"]
 
 
 def test_named_savepoints_exception_exit(conn, commands):
@@ -436,52 +420,46 @@ def test_named_savepoints_exception_exit(conn, commands):
     be rolled-back as appropriate.
     """
     # Case 1
-    tx = Transaction(conn)
-    tx.__enter__()
-    assert commands.pop() == "begin"
-    assert not tx.savepoint_name
-    tx.__exit__(*some_exc_info())
-    assert commands.pop() == "rollback"
+    with pytest.raises(ExpectedException):
+        with conn.transaction() as tx:
+            assert commands.popall() == ["begin"]
+            assert not tx.savepoint_name
+            raise ExpectedException
+    assert commands.popall() == ["rollback"]
 
     # Case 2
-    tx = Transaction(conn, savepoint_name="foo")
-    tx.__enter__()
-    assert commands.pop() == 'begin; savepoint "foo"'
-    assert tx.savepoint_name == "foo"
-    tx.__exit__(*some_exc_info())
-    assert commands.pop() == "rollback"
+    with pytest.raises(ExpectedException):
+        with conn.transaction(savepoint_name="foo") as tx:
+            assert commands.popall() == ['begin; savepoint "foo"']
+            assert tx.savepoint_name == "foo"
+            raise ExpectedException
+    assert commands.popall() == ["rollback"]
 
     # Case 3 (with savepoint name provided)
     with conn.transaction():
-        assert commands.pop() == "begin"
-        tx = Transaction(conn, savepoint_name="bar")
-        tx.__enter__()
-        assert commands.pop() == 'savepoint "bar"'
-        assert tx.savepoint_name == "bar"
-        tx.__exit__(*some_exc_info())
-        assert (
-            commands.pop()
-            == 'rollback to savepoint "bar"; release savepoint "bar"'
-        )
-        assert not commands
-    assert commands.pop() == "commit"
+        assert commands.popall() == ["begin"]
+        with pytest.raises(ExpectedException):
+            with conn.transaction(savepoint_name="bar") as tx:
+                assert commands.popall() == ['savepoint "bar"']
+                assert tx.savepoint_name == "bar"
+                raise ExpectedException
+        assert commands.popall() == [
+            'rollback to savepoint "bar"; release savepoint "bar"'
+        ]
+    assert commands.popall() == ["commit"]
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
-        assert commands.pop() == "begin"
-        tx = Transaction(conn)
-        tx.__enter__()
-        assert commands.pop() == 'savepoint "s2"'
-        assert tx.savepoint_name == "s2"
-        tx.__exit__(*some_exc_info())
-        assert (
-            commands.pop()
-            == 'rollback to savepoint "s2"; release savepoint "s2"'
-        )
-        assert not commands
-    assert commands.pop() == "commit"
-
-    assert not commands
+        assert commands.popall() == ["begin"]
+        with pytest.raises(ExpectedException):
+            with conn.transaction() as tx:
+                assert commands.popall() == ['savepoint "s2"']
+                assert tx.savepoint_name == "s2"
+                raise ExpectedException
+        assert commands.popall() == [
+            'rollback to savepoint "s2"; release savepoint "s2"'
+        ]
+    assert commands.popall() == ["commit"]
 
 
 def test_named_savepoints_with_repeated_names_works(conn):
index d4a3889d6cb006bb35b61edd8ab20b5564c4b232..09764551461e0990fe5b36ab317e02124e1bbef2 100644 (file)
@@ -1,49 +1,31 @@
 import pytest
 
-from psycopg3 import AsyncConnection, ProgrammingError, Rollback
-from psycopg3.sql import Composable
-from psycopg3.transaction import AsyncTransaction
-
-from .test_transaction import (
-    in_transaction,
-    ExpectedException,
-    some_exc_info,
-    inserted,
-)
+from psycopg3 import ProgrammingError, Rollback
+
+from .test_transaction import in_transaction, insert_row, inserted
+from .test_transaction import ExpectedException, ListPopAll
 from .test_transaction import create_test_table  # noqa  # autouse fixture
 
 pytestmark = pytest.mark.asyncio
 
 
 @pytest.fixture
-async def commands(monkeypatch):
-    """The queue of commands issued internally by connections.
-
-    Not concurrency safe as it mokkeypatches a class method, but good enough
-    for tests.
-    """
-    _orig_exec_command = AsyncConnection._exec_command
-    L = []
+async def commands(aconn, monkeypatch):
+    """The queue of commands issued internally by the test connection."""
+    _orig_exec_command = aconn._exec_command
+    L = ListPopAll()
 
-    async def _exec_command(self, command):
+    async def _exec_command(command):
         if isinstance(command, bytes):
-            command = command.decode(self.client_encoding)
-        elif isinstance(command, Composable):
-            command = command.as_string(self)
+            command = command.decode(aconn.client_encoding)
 
         L.insert(0, command)
-        await _orig_exec_command(self, command)
+        await _orig_exec_command(command)
 
-    monkeypatch.setattr(AsyncConnection, "_exec_command", _exec_command)
+    monkeypatch.setattr(aconn, "_exec_command", _exec_command)
     yield L
 
 
-async def insert_row(aconn, value):
-    await (await aconn.cursor()).execute(
-        "INSERT INTO test_table VALUES (%s)", (value,)
-    )
-
-
 async def test_basic(aconn):
     """Basic use of transaction() to BEGIN and COMMIT a transaction."""
     assert not in_transaction(aconn)
@@ -342,60 +324,45 @@ async def test_named_savepoints_successful_exit(aconn, commands):
     """
     # Case 1
     # Using Transaction explicitly becase conn.transaction() enters the contetx
-    tx = AsyncTransaction(aconn)
-    await tx.__aenter__()
-    assert commands.pop() == "begin"
-    assert not tx.savepoint_name
-    await tx.__aexit__(None, None, None)
-    assert commands.pop() == "commit"
+    async with aconn.transaction() as tx:
+        assert commands.popall() == ["begin"]
+        assert not tx.savepoint_name
+    assert commands.popall() == ["commit"]
 
     # Case 1 (with a transaction already started)
     await (await aconn.cursor()).execute("select 1")
-    assert commands.pop() == "begin"
-    tx = AsyncTransaction(aconn)
-    await tx.__aenter__()
-    assert commands.pop() == 'savepoint "s1"'
-    assert tx.savepoint_name == "s1"
-    await tx.__aexit__(None, None, None)
-    assert commands.pop() == 'release savepoint "s1"'
-    assert not commands
+    assert commands.popall() == ["begin"]
+    async with aconn.transaction() as tx:
+        assert commands.popall() == ['savepoint "s1"']
+        assert tx.savepoint_name == "s1"
+
+    assert commands.popall() == ['release savepoint "s1"']
     await aconn.rollback()
-    assert commands.pop() == "rollback"
-    assert not commands
+    assert commands.popall() == ["rollback"]
 
     # Case 2
-    tx = AsyncTransaction(aconn, savepoint_name="foo")
-    await tx.__aenter__()
-    assert commands.pop() == 'begin; savepoint "foo"'
-    assert tx.savepoint_name == "foo"
-    await tx.__aexit__(None, None, None)
-    assert commands.pop() == "commit"
+    async with aconn.transaction(savepoint_name="foo") as tx:
+        assert commands.popall() == ['begin; savepoint "foo"']
+        assert tx.savepoint_name == "foo"
+    assert commands.popall() == ["commit"]
 
     # Case 3 (with savepoint name provided)
     async with aconn.transaction():
-        assert commands.pop() == "begin"
-        tx = AsyncTransaction(aconn, savepoint_name="bar")
-        await tx.__aenter__()
-        assert commands.pop() == 'savepoint "bar"'
-        assert tx.savepoint_name == "bar"
-        await tx.__aexit__(None, None, None)
-        assert commands.pop() == 'release savepoint "bar"'
-        assert not commands
-    assert commands.pop() == "commit"
+        assert commands.popall() == ["begin"]
+        async with aconn.transaction(savepoint_name="bar") as tx:
+            assert commands.popall() == ['savepoint "bar"']
+            assert tx.savepoint_name == "bar"
+        assert commands.popall() == ['release savepoint "bar"']
+    assert commands.popall() == ["commit"]
 
     # Case 3 (with savepoint name auto-generated)
     async with aconn.transaction():
-        assert commands.pop() == "begin"
-        tx = AsyncTransaction(aconn)
-        await tx.__aenter__()
-        assert commands.pop() == 'savepoint "s2"'
-        assert tx.savepoint_name == "s2"
-        await tx.__aexit__(None, None, None)
-        assert commands.pop() == 'release savepoint "s2"'
-        assert not commands
-    assert commands.pop() == "commit"
-
-    assert not commands
+        assert commands.popall() == ["begin"]
+        async with aconn.transaction() as tx:
+            assert commands.popall() == ['savepoint "s2"']
+            assert tx.savepoint_name == "s2"
+        assert commands.popall() == ['release savepoint "s2"']
+    assert commands.popall() == ["commit"]
 
 
 async def test_named_savepoints_exception_exit(aconn, commands):
@@ -405,52 +372,46 @@ async def test_named_savepoints_exception_exit(aconn, commands):
     be rolled-back as appropriate.
     """
     # Case 1
-    tx = AsyncTransaction(aconn)
-    await tx.__aenter__()
-    assert commands.pop() == "begin"
-    assert not tx.savepoint_name
-    await tx.__aexit__(*some_exc_info())
-    assert commands.pop() == "rollback"
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction() as tx:
+            assert commands.popall() == ["begin"]
+            assert not tx.savepoint_name
+            raise ExpectedException
+    assert commands.popall() == ["rollback"]
 
     # Case 2
-    tx = AsyncTransaction(aconn, savepoint_name="foo")
-    await tx.__aenter__()
-    assert commands.pop() == 'begin; savepoint "foo"'
-    assert tx.savepoint_name == "foo"
-    await tx.__aexit__(*some_exc_info())
-    assert commands.pop() == "rollback"
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction(savepoint_name="foo") as tx:
+            assert commands.popall() == ['begin; savepoint "foo"']
+            assert tx.savepoint_name == "foo"
+            raise ExpectedException
+    assert commands.popall() == ["rollback"]
 
     # Case 3 (with savepoint name provided)
     async with aconn.transaction():
-        assert commands.pop() == "begin"
-        tx = AsyncTransaction(aconn, savepoint_name="bar")
-        await tx.__aenter__()
-        assert commands.pop() == 'savepoint "bar"'
-        assert tx.savepoint_name == "bar"
-        await tx.__aexit__(*some_exc_info())
-        assert (
-            commands.pop()
-            == 'rollback to savepoint "bar"; release savepoint "bar"'
-        )
-        assert not commands
-    assert commands.pop() == "commit"
+        assert commands.popall() == ["begin"]
+        with pytest.raises(ExpectedException):
+            async with aconn.transaction(savepoint_name="bar") as tx:
+                assert commands.popall() == ['savepoint "bar"']
+                assert tx.savepoint_name == "bar"
+                raise ExpectedException
+        assert commands.popall() == [
+            'rollback to savepoint "bar"; release savepoint "bar"'
+        ]
+    assert commands.popall() == ["commit"]
 
     # Case 3 (with savepoint name auto-generated)
     async with aconn.transaction():
-        assert commands.pop() == "begin"
-        tx = AsyncTransaction(aconn)
-        await tx.__aenter__()
-        assert commands.pop() == 'savepoint "s2"'
-        assert tx.savepoint_name == "s2"
-        await tx.__aexit__(*some_exc_info())
-        assert (
-            commands.pop()
-            == 'rollback to savepoint "s2"; release savepoint "s2"'
-        )
-        assert not commands
-    assert commands.pop() == "commit"
-
-    assert not commands
+        assert commands.popall() == ["begin"]
+        with pytest.raises(ExpectedException):
+            async with aconn.transaction() as tx:
+                assert commands.popall() == ['savepoint "s2"']
+                assert tx.savepoint_name == "s2"
+                raise ExpectedException
+        assert commands.popall() == [
+            'rollback to savepoint "s2"; release savepoint "s2"'
+        ]
+    assert commands.popall() == ["commit"]
 
 
 async def test_named_savepoints_with_repeated_names_works(aconn):