]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Raise ProgrammingError on out-of-order exit from transactions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Dec 2021 23:28:44 +0000 (00:28 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Dec 2021 23:28:44 +0000 (00:28 +0100)
Previously it would have failed an assert.

Further back in time, the condition was checked and reported as
ProgrammingError already. The check was dropped when
`conn.transaction()` started to return an entered transaction, so the
possibility of calling enter/exit manually was taken out of the public
API. However, we didn't consider the possibility of concurrent
threads operating on transaction independently.

Also fix the Transaction representation, which wouldn't have reported
`(terminated)` exiting on rollback, but only on commit.

Close #176.

docs/news.rst
psycopg/psycopg/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py

index aa22ec6f56d30c4b1cf4214949b146d8ee1b8f4c..c02bf5c9d5dcddc54ec586a004a3d65ddfc4cf9c 100644 (file)
@@ -22,6 +22,8 @@ Psycopg 3.0.6 (unreleased)
 - Don't raise exceptions on `ServerCursor.close()` if the connection is closed
   (:ticket:`#173`).
 - Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`).
+- Raise `ProgrammingError` if out-of-order exit from transaction contexts is
+  detected (:ticket:`#176`).
 - Add `!CHECK_STANDBY` value to `~pq.ConnStatus` enum.
 
 
index 4d01e836f68a3ad9ffda55aa1729a4718f168494..63aa01e34e93553e7c93f1bd53f608efc7b66ba0 100644 (file)
@@ -11,6 +11,7 @@ from typing import Generic, Optional, Type, Union, TYPE_CHECKING
 
 from . import pq
 from . import sql
+from . import errors as e
 from .pq import TransactionStatus
 from .abc import ConnectionType, PQGen
 from .pq.abc import PGresult
@@ -44,6 +45,10 @@ class Rollback(Exception):
         return f"{self.__class__.__qualname__}({self.transaction!r})"
 
 
+class OutOfOrderTransactionNesting(e.ProgrammingError):
+    """Out-of-order transaction nesting detected"""
+
+
 class BaseTransaction(Generic[ConnectionType]):
     def __init__(
         self,
@@ -126,18 +131,24 @@ class BaseTransaction(Generic[ConnectionType]):
             # state) just warn without clobbering the exception bubbling up.
             try:
                 return (yield from self._rollback_gen(exc_val))
+            except OutOfOrderTransactionNesting:
+                # Clobber an exception happened in the block with the exception
+                # caused by out-of-order transaction detected, so make the
+                # behaviour consistent with _commit_gen and to make sure the
+                # user fixes this condition, which is unrelated from
+                # operational error that might arise in the block.
+                raise
             except Exception as exc2:
                 logger.warning(
-                    "error ignored in rollback of %s: %s",
-                    self,
-                    exc2,
+                    "error ignored in rollback of %s: %s", self, exc2
                 )
                 return False
 
     def _commit_gen(self) -> PQGen[PGresult]:
-        assert self._conn._savepoints[-1] == self._savepoint_name
-        self._conn._savepoints.pop()
+        ex = self._pop_savepoint("commit")
         self._exited = True
+        if ex:
+            raise ex
 
         commands = []
         if self._savepoint_name and not self._outer_transaction:
@@ -159,8 +170,10 @@ class BaseTransaction(Generic[ConnectionType]):
                 f"{self._conn}: Explicit rollback from: ", exc_info=True
             )
 
-        assert self._conn._savepoints[-1] == self._savepoint_name
-        self._conn._savepoints.pop()
+        ex = self._pop_savepoint("roll back")
+        self._exited = True
+        if ex:
+            raise ex
 
         commands = []
         if self._savepoint_name and not self._outer_transaction:
@@ -187,6 +200,17 @@ class BaseTransaction(Generic[ConnectionType]):
 
         return False
 
+    def _pop_savepoint(self, action: str) -> Optional[Exception]:
+        sp = self._conn._savepoints.pop()
+        if sp == self._savepoint_name:
+            return None
+
+        other = f"the savepoint {sp!r}" if sp else "the top-level transaction"
+        return OutOfOrderTransactionNesting(
+            f"transactions not correctly nested: {self} would {action}"
+            f" in the wrong order compared to {other}"
+        )
+
 
 class Transaction(BaseTransaction["Connection[Any]"]):
     """
index f569bfc61eca7dfe53c28027264d56567e03d530..07a05bcbd1f382a4ec8c7905a49275fce5feb0ff 100644 (file)
@@ -650,21 +650,41 @@ def test_str(conn):
     assert "[IDLE]" in str(tx)
     assert "(terminated)" in str(tx)
 
+    with pytest.raises(ZeroDivisionError):
+        with conn.transaction() as tx:
+            1 / 0
+
+    assert "(terminated)" in str(tx)
 
-@pytest.mark.parametrize("fail", [False, True])
-def test_concurrency(conn, fail):
+
+@pytest.mark.parametrize("what", ["commit", "rollback", "error"])
+def test_concurrency(conn, what):
     conn.autocommit = True
 
     e = [Event() for i in range(3)]
 
     def worker(unlock, wait_on):
-        with pytest.raises(ProgrammingError):
+        with pytest.raises(ProgrammingError) as ex:
             with conn.transaction():
                 unlock.set()
                 wait_on.wait()
                 conn.execute("select 1")
-                if fail:
+
+                if what == "error":
                     1 / 0
+                elif what == "rollback":
+                    raise Rollback()
+                else:
+                    assert what == "commit"
+
+        if what == "error":
+            assert "would roll back" in str(ex.value)
+            assert isinstance(ex.value.__context__, ZeroDivisionError)
+        elif what == "rollback":
+            assert "would roll back" in str(ex.value)
+            assert isinstance(ex.value.__context__, Rollback)
+        else:
+            assert "would commit" in str(ex.value)
 
     # Start a first transaction in a thread
     t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]})
index 6b108d57a1a066cc4e88d1bb733d7541a9cf3585..4bfce1595363df89310c4c99d67a09abb5095626 100644 (file)
@@ -618,21 +618,41 @@ async def test_str(aconn):
     assert "[IDLE]" in str(tx)
     assert "(terminated)" in str(tx)
 
+    with pytest.raises(ZeroDivisionError):
+        async with aconn.transaction() as tx:
+            1 / 0
+
+    assert "(terminated)" in str(tx)
 
-@pytest.mark.parametrize("fail", [False, True])
-async def test_concurrency(aconn, fail):
+
+@pytest.mark.parametrize("what", ["commit", "rollback", "error"])
+async def test_concurrency(aconn, what):
     await aconn.set_autocommit(True)
 
     e = [asyncio.Event() for i in range(3)]
 
     async def worker(unlock, wait_on):
-        with pytest.raises(ProgrammingError):
+        with pytest.raises(ProgrammingError) as ex:
             async with aconn.transaction():
                 unlock.set()
                 await wait_on.wait()
                 await aconn.execute("select 1")
-                if fail:
+
+                if what == "error":
                     1 / 0
+                elif what == "rollback":
+                    raise Rollback()
+                else:
+                    assert what == "commit"
+
+        if what == "error":
+            assert "would roll back" in str(ex.value)
+            assert isinstance(ex.value.__context__, ZeroDivisionError)
+        elif what == "rollback":
+            assert "would roll back" in str(ex.value)
+            assert isinstance(ex.value.__context__, Rollback)
+        else:
+            assert "would commit" in str(ex.value)
 
     # Start a first transaction in a task
     t1 = create_task(worker(unlock=e[0], wait_on=e[1]))