]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed transaction behaviour when there is a transaction already started
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 16:26:38 +0000 (16:26 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 Nov 2020 15:20:16 +0000 (15:20 +0000)
State management simplified too.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py

index 016af7f0114aab5d3e15a4d074b8a997a4ec58f6..2dd262de432a02fce369559165917e09234d5fbb 100644 (file)
@@ -106,10 +106,10 @@ class BaseConnection:
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
 
-        # stack of savepoint names managed by active Transaction() blocks
-        self._savepoints: Optional[List[str]] = None
-        # (None when there no active Transaction blocks; [] when there is only
-        # one Transaction block, with a top-level transaction and no savepoint)
+        # Stack of savepoint names managed by current transaction blocks.
+        # the first item is "" in case the outermost Transaction must manage
+        # only a begin/commit and not a savepoint.
+        self._savepoints: List[str] = []
 
         wself = ref(self)
 
@@ -135,7 +135,7 @@ class BaseConnection:
         # subclasses must call it holding a lock
         status = self.pgconn.transaction_status
         if status != TransactionStatus.IDLE:
-            if self._savepoints is not None:
+            if self._savepoints:
                 raise e.ProgrammingError(
                     "couldn't change autocommit state: "
                     "connection.transaction() context in progress"
@@ -295,7 +295,7 @@ class Connection(BaseConnection):
     def commit(self) -> None:
         """Commit any pending transaction to the database."""
         with self.lock:
-            if self._savepoints is not None:
+            if self._savepoints:
                 raise e.ProgrammingError(
                     "Explicit commit() forbidden within a Transaction "
                     "context. (Transaction will be automatically committed "
@@ -308,7 +308,7 @@ class Connection(BaseConnection):
     def rollback(self) -> None:
         """Roll back to the start of any pending transaction."""
         with self.lock:
-            if self._savepoints is not None:
+            if self._savepoints:
                 raise e.ProgrammingError(
                     "Explicit rollback() forbidden within a Transaction "
                     "context. (Either raise Transaction.Rollback() or allow "
@@ -447,7 +447,7 @@ class AsyncConnection(BaseConnection):
         async with self.lock:
             if self.pgconn.transaction_status == TransactionStatus.IDLE:
                 return
-            if self._savepoints is not None:
+            if self._savepoints:
                 raise e.ProgrammingError(
                     "Explicit commit() forbidden within a Transaction "
                     "context. (Transaction will be automatically committed "
@@ -459,7 +459,7 @@ class AsyncConnection(BaseConnection):
         async with self.lock:
             if self.pgconn.transaction_status == TransactionStatus.IDLE:
                 return
-            if self._savepoints is not None:
+            if self._savepoints:
                 raise e.ProgrammingError(
                     "Explicit rollback() forbidden within a Transaction "
                     "context. (Either raise Transaction.Rollback() or allow "
index b02d412bc1a41175551315649e1a655c068ce6d5..8e11ab04bad8d4f5620335b5662764480c026b07 100644 (file)
@@ -12,7 +12,6 @@ from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
 from . import sql
 from .pq import TransactionStatus
 from .proto import ConnectionType
-from .errors import ProgrammingError
 
 if TYPE_CHECKING:
     from .connection import Connection, AsyncConnection  # noqa: F401
@@ -47,18 +46,28 @@ class BaseTransaction(Generic[ConnectionType]):
         force_rollback: bool = False,
     ):
         self._conn = connection
-        self._savepoint_name = savepoint_name or ""
         self.force_rollback = force_rollback
-        self._outer_transaction = (
-            connection.pgconn.transaction_status == TransactionStatus.IDLE
-        )
+        self._yolo = True
+
+        if connection.pgconn.transaction_status == TransactionStatus.IDLE:
+            # outer transaction: if no name it's only a begin, else
+            # there will be an additional savepoint
+            self._outer_transaction = True
+            assert not connection._savepoints
+            self._savepoint_name = savepoint_name or ""
+        else:
+            # inner transaction: it always has a name
+            self._outer_transaction = False
+            self._savepoint_name = (
+                savepoint_name or f"s{len(self._conn._savepoints) + 1}"
+            )
 
     @property
     def connection(self) -> ConnectionType:
         return self._conn
 
     @property
-    def savepoint_name(self) -> str:
+    def savepoint_name(self) -> Optional[str]:
         return self._savepoint_name
 
     def __repr__(self) -> str:
@@ -69,23 +78,14 @@ class BaseTransaction(Generic[ConnectionType]):
             args.append("force_rollback=True")
         return f"{self.__class__.__qualname__}({', '.join(args)})"
 
-    _out_of_order_err = ProgrammingError(
-        "Out-of-order Transaction context exits. Are you "
-        "calling __exit__() manually and getting it wrong?"
-    )
-
     def _enter_commands(self) -> List[str]:
-        commands = []
+        assert self._yolo
+        self._yolo = False
 
+        commands = []
         if self._outer_transaction:
-            assert self._conn._savepoints is None, self._conn._savepoints
-            self._conn._savepoints = []
+            assert not self._conn._savepoints, self._conn._savepoints
             commands.append("begin")
-        else:
-            if self._conn._savepoints is None:
-                self._conn._savepoints = []
-            if not self._savepoint_name:
-                self._savepoint_name = f"s{len(self._conn._savepoints) + 1}"
 
         if self._savepoint_name:
             commands.append(
@@ -93,52 +93,46 @@ class BaseTransaction(Generic[ConnectionType]):
                 .format(sql.Identifier(self._savepoint_name))
                 .as_string(self._conn)
             )
-            self._conn._savepoints.append(self._savepoint_name)
 
+        self._conn._savepoints.append(self._savepoint_name)
         return commands
 
     def _commit_commands(self) -> List[str]:
-        commands = []
+        assert self._conn._savepoints[-1] == self._savepoint_name
+        self._conn._savepoints.pop()
 
-        self._pop_savepoint()
+        commands = []
         if self._savepoint_name and not self._outer_transaction:
             commands.append(
                 sql.SQL("release savepoint {}")
                 .format(sql.Identifier(self._savepoint_name))
                 .as_string(self._conn)
             )
+
         if self._outer_transaction:
+            assert not self._conn._savepoints
             commands.append("commit")
 
         return commands
 
     def _rollback_commands(self) -> List[str]:
-        commands = []
+        assert self._conn._savepoints[-1] == self._savepoint_name
+        self._conn._savepoints.pop()
 
-        self._pop_savepoint()
+        commands = []
         if self._savepoint_name and not self._outer_transaction:
             commands.append(
                 sql.SQL("rollback to savepoint {n}; release savepoint {n}")
                 .format(n=sql.Identifier(self._savepoint_name))
                 .as_string(self._conn)
             )
+
         if self._outer_transaction:
+            assert not self._conn._savepoints
             commands.append("rollback")
 
         return commands
 
-    def _pop_savepoint(self) -> None:
-        if self._savepoint_name:
-            if self._conn._savepoints is None:
-                raise self._out_of_order_err
-            actual = self._conn._savepoints.pop()
-            if actual != self._savepoint_name:
-                raise self._out_of_order_err
-        if self._outer_transaction:
-            if self._conn._savepoints is None or self._conn._savepoints:
-                raise self._out_of_order_err
-            self._conn._savepoints = None
-
 
 class Transaction(BaseTransaction["Connection"]):
     def __enter__(self) -> "Transaction":
index dfa66ab25f2239a8df7100bb371fa58ed4f1651e..c5d8890bde85e68c701e55cd959ae19aab2a0660 100644 (file)
@@ -135,6 +135,21 @@ def test_rollback_on_exception_exit(conn):
     assert not inserted(conn)
 
 
+def test_interaction_dbapi_transaction(conn):
+    insert_row(conn, "foo")
+
+    with conn.transaction():
+        insert_row(conn, "bar")
+        raise Rollback
+
+    with conn.transaction():
+        insert_row(conn, "baz")
+
+    assert in_transaction(conn)
+    conn.commit()
+    assert inserted(conn) == {"foo", "baz"}
+
+
 def test_prohibits_use_of_commit_rollback_autocommit(conn):
     """
     Within a Transaction block, it is forbidden to touch commit, rollback,
@@ -365,6 +380,20 @@ def test_named_savepoints_successful_exit(conn, commands):
     tx.__exit__(None, None, None)
     assert commands.pop() == "commit"
 
+    # Case 1 (with a transaction already started)
+    conn.cursor().execute("select 1")
+    assert commands.pop() == "begin"
+    tx = Transaction(conn)
+    tx.__enter__()
+    assert commands.pop() == 'savepoint "s1"'
+    assert tx.savepoint_name == "s1"
+    tx.__exit__(None, None, None)
+    assert commands.pop() == 'release savepoint "s1"'
+    assert not commands
+    conn.rollback()
+    assert commands.pop() == "rollback"
+    assert not commands
+
     # Case 2
     tx = Transaction(conn, savepoint_name="foo")
     tx.__enter__()
@@ -390,10 +419,10 @@ def test_named_savepoints_successful_exit(conn, commands):
         assert commands.pop() == "begin"
         tx = Transaction(conn)
         tx.__enter__()
-        assert commands.pop() == 'savepoint "s1"'
-        assert tx.savepoint_name == "s1"
+        assert commands.pop() == 'savepoint "s2"'
+        assert tx.savepoint_name == "s2"
         tx.__exit__(None, None, None)
-        assert commands.pop() == 'release savepoint "s1"'
+        assert commands.pop() == 'release savepoint "s2"'
         assert not commands
     assert commands.pop() == "commit"
 
@@ -442,12 +471,12 @@ def test_named_savepoints_exception_exit(conn, commands):
         assert commands.pop() == "begin"
         tx = Transaction(conn)
         tx.__enter__()
-        assert commands.pop() == 'savepoint "s1"'
-        assert tx.savepoint_name == "s1"
+        assert commands.pop() == 'savepoint "s2"'
+        assert tx.savepoint_name == "s2"
         tx.__exit__(*some_exc_info())
         assert (
             commands.pop()
-            == 'rollback to savepoint "s1"; release savepoint "s1"'
+            == 'rollback to savepoint "s2"; release savepoint "s2"'
         )
         assert not commands
     assert commands.pop() == "commit"
@@ -608,30 +637,3 @@ def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
         assert not inserted(svcconn)  # Not yet committed
     # Changes committed
     assert inserted(svcconn) == {"outer-before", "outer-after"}
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-def test_manual_exit_without_enter_asserts(conn, name, exc_info):
-    """
-    When user is calling __enter__() and __exit__() manually for some reason,
-    provide a helpful error message if they call __exit__() without first
-    having called __enter__()
-    """
-    tx = Transaction(conn, name)
-    with pytest.raises(ProgrammingError, match="Out-of-order"):
-        tx.__exit__(*exc_info)
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-def test_manual_exit_twice_asserts(conn, name, exc_info):
-    """
-    When user is calling __enter__() and __exit__() manually for some reason,
-    provide a helpful error message if they accidentally call __exit__() twice.
-    """
-    tx = Transaction(conn, name)
-    tx.__enter__()
-    tx.__exit__(*exc_info)
-    with pytest.raises(ProgrammingError, match="Out-of-order"):
-        tx.__exit__(*exc_info)
index c434fa6a01c2be80c1dad3fd7496461245c32b75..d4a3889d6cb006bb35b61edd8ab20b5564c4b232 100644 (file)
@@ -97,6 +97,21 @@ async def test_rollback_on_exception_exit(aconn):
     assert not await inserted(aconn)
 
 
+async def test_interaction_dbapi_transaction(aconn):
+    await insert_row(aconn, "foo")
+
+    async with aconn.transaction():
+        await insert_row(aconn, "bar")
+        raise Rollback
+
+    async with aconn.transaction():
+        await insert_row(aconn, "baz")
+
+    assert in_transaction(aconn)
+    await aconn.commit()
+    assert await inserted(aconn) == {"foo", "baz"}
+
+
 async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
     """
     Within a Transaction block, it is forbidden to touch commit, rollback,
@@ -334,6 +349,20 @@ async def test_named_savepoints_successful_exit(aconn, commands):
     await tx.__aexit__(None, None, None)
     assert commands.pop() == "commit"
 
+    # Case 1 (with a transaction already started)
+    await (await aconn.cursor()).execute("select 1")
+    assert commands.pop() == "begin"
+    tx = AsyncTransaction(aconn)
+    await tx.__aenter__()
+    assert commands.pop() == 'savepoint "s1"'
+    assert tx.savepoint_name == "s1"
+    await tx.__aexit__(None, None, None)
+    assert commands.pop() == 'release savepoint "s1"'
+    assert not commands
+    await aconn.rollback()
+    assert commands.pop() == "rollback"
+    assert not commands
+
     # Case 2
     tx = AsyncTransaction(aconn, savepoint_name="foo")
     await tx.__aenter__()
@@ -359,10 +388,10 @@ async def test_named_savepoints_successful_exit(aconn, commands):
         assert commands.pop() == "begin"
         tx = AsyncTransaction(aconn)
         await tx.__aenter__()
-        assert commands.pop() == 'savepoint "s1"'
-        assert tx.savepoint_name == "s1"
+        assert commands.pop() == 'savepoint "s2"'
+        assert tx.savepoint_name == "s2"
         await tx.__aexit__(None, None, None)
-        assert commands.pop() == 'release savepoint "s1"'
+        assert commands.pop() == 'release savepoint "s2"'
         assert not commands
     assert commands.pop() == "commit"
 
@@ -411,12 +440,12 @@ async def test_named_savepoints_exception_exit(aconn, commands):
         assert commands.pop() == "begin"
         tx = AsyncTransaction(aconn)
         await tx.__aenter__()
-        assert commands.pop() == 'savepoint "s1"'
-        assert tx.savepoint_name == "s1"
+        assert commands.pop() == 'savepoint "s2"'
+        assert tx.savepoint_name == "s2"
         await tx.__aexit__(*some_exc_info())
         assert (
             commands.pop()
-            == 'rollback to savepoint "s1"; release savepoint "s1"'
+            == 'rollback to savepoint "s2"; release savepoint "s2"'
         )
         assert not commands
     assert commands.pop() == "commit"
@@ -579,30 +608,3 @@ async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(
         assert not inserted(svcconn)  # Not yet committed
     # Changes committed
     assert inserted(svcconn) == {"outer-before", "outer-after"}
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):
-    """
-    When user is calling __enter__() and __exit__() manually for some reason,
-    provide a helpful error message if they call __exit__() without first
-    having called __enter__()
-    """
-    tx = AsyncTransaction(aconn, name)
-    with pytest.raises(ProgrammingError, match="Out-of-order"):
-        await tx.__aexit__(*exc_info)
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-async def test_manual_exit_twice_asserts(aconn, name, exc_info):
-    """
-    When user is calling __enter__() and __exit__() manually for some reason,
-    provide a helpful error message if they accidentally call __exit__() twice.
-    """
-    tx = AsyncTransaction(aconn, name)
-    await tx.__aenter__()
-    await tx.__aexit__(*exc_info)
-    with pytest.raises(ProgrammingError, match="Out-of-order"):
-        await tx.__aexit__(*exc_info)