]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Detect out-of-order transactions exit when they have the same name too
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 9 Dec 2021 14:24:57 +0000 (15:24 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 9 Dec 2021 14:37:51 +0000 (15:37 +0100)
Close #177

docs/news.rst
psycopg/psycopg/connection.py
psycopg/psycopg/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py

index c02bf5c9d5dcddc54ec586a004a3d65ddfc4cf9c..bc266a405e2683112b6757bc7e3871eb48d8cabf 100644 (file)
@@ -23,7 +23,7 @@ Psycopg 3.0.6 (unreleased)
   (:ticket:`#173`).
 - Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`).
 - Raise `ProgrammingError` if out-of-order exit from transaction contexts is
-  detected (:ticket:`#176`).
+  detected (:tickets:`#176, #177`).
 - Add `!CHECK_STANDBY` value to `~pq.ConnStatus` enum.
 
 
index 7cf508428174731073ee8f439fd4ec808887436c..c0b695df0174ad2bd25035f22710812d696f3444 100644 (file)
@@ -110,10 +110,8 @@ class BaseConnection(Generic[Row]):
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
 
-        # Stack of savepoint names managed by current transaction blocks.
-        # the first item is "" in case the outermost Transaction must manage
-        # only a begin/commit and not a savepoint.
-        self._savepoints: List[str] = []
+        # Number of transaction blocks currently entered
+        self._num_transactions = 0
 
         self._closed = False  # closed by an explicit close()
         self._prepared: PrepareManager = PrepareManager()
@@ -249,7 +247,7 @@ class BaseConnection(Generic[Row]):
         # Raise an exception if we are in a transaction
         status = self.pgconn.transaction_status
         if status != TransactionStatus.IDLE:
-            if self._savepoints:
+            if self._num_transactions:
                 raise e.ProgrammingError(
                     f"can't change {attribute!r} now: "
                     "connection.transaction() context in progress"
@@ -483,7 +481,7 @@ class BaseConnection(Generic[Row]):
 
     def _commit_gen(self) -> PQGen[None]:
         """Generator implementing `Connection.commit()`."""
-        if self._savepoints:
+        if self._num_transactions:
             raise e.ProgrammingError(
                 "Explicit commit() forbidden within a Transaction "
                 "context. (Transaction will be automatically committed "
@@ -500,7 +498,7 @@ class BaseConnection(Generic[Row]):
 
     def _rollback_gen(self) -> PQGen[None]:
         """Generator implementing `Connection.rollback()`."""
-        if self._savepoints:
+        if self._num_transactions:
             raise e.ProgrammingError(
                 "Explicit rollback() forbidden within a Transaction "
                 "context. (Either raise Rollback() or allow "
index 3ea0a0b96f3109728606d26b2f923b75a25effdb..b86517e50d991e78377593d5c5f2f6340b43b864 100644 (file)
@@ -61,6 +61,7 @@ class BaseTransaction(Generic[ConnectionType]):
         self.force_rollback = force_rollback
         self._entered = self._exited = False
         self._outer_transaction = False
+        self._stack_index = -1
 
     @property
     def savepoint_name(self) -> Optional[str]:
@@ -135,7 +136,7 @@ class BaseTransaction(Generic[ConnectionType]):
                 f"{self._conn}: Explicit rollback from: ", exc_info=True
             )
 
-        ex = self._pop_savepoint("roll back")
+        ex = self._pop_savepoint("rollback")
         self._exited = True
         if ex:
             raise ex
@@ -172,7 +173,7 @@ class BaseTransaction(Generic[ConnectionType]):
             )
 
         if self._outer_transaction:
-            assert not self._conn._savepoints
+            assert not self._conn._num_transactions
             commands.append(b"COMMIT")
 
         return commands
@@ -187,7 +188,7 @@ class BaseTransaction(Generic[ConnectionType]):
             )
 
         if self._outer_transaction:
-            assert not self._conn._savepoints
+            assert not self._conn._num_transactions
             commands.append(b"ROLLBACK")
 
         # Also clear the prepared statements cache.
@@ -209,14 +210,16 @@ class BaseTransaction(Generic[ConnectionType]):
         if self._outer_transaction:
             # outer transaction: if no name it's only a begin, else
             # there will be an additional savepoint
-            assert not self._conn._savepoints
+            assert not self._conn._num_transactions
         else:
             # inner transaction: it always has a name
             if not self._savepoint_name:
                 self._savepoint_name = (
-                    f"_pg3_{len(self._conn._savepoints) + 1}"
+                    f"_pg3_{self._conn._num_transactions + 1}"
                 )
-        self._conn._savepoints.append(self._savepoint_name)
+
+        self._stack_index = self._conn._num_transactions
+        self._conn._num_transactions += 1
 
     def _pop_savepoint(self, action: str) -> Optional[Exception]:
         """
@@ -224,14 +227,12 @@ class BaseTransaction(Generic[ConnectionType]):
 
         Also verify the state consistency.
         """
-        sp = self._conn._savepoints.pop()
-        if sp == self._savepoint_name:
+        self._conn._num_transactions -= 1
+        if self._conn._num_transactions == self._stack_index:
             return None
 
-        other = f"the savepoint {sp!r}" if sp else "the top-level transaction"
         return OutOfOrderTransactionNesting(
-            f"transactions not correctly nested: {self} would {action}"
-            f" in the wrong order compared to {other}"
+            f"transaction {action} at the wrong nesting level: {self}"
         )
 
 
index bc7c374f9352d668299036cc3d1368ff1bb029da..74dc8c5c3058fe8c09a4aa8228381007646379ac 100644 (file)
@@ -533,15 +533,6 @@ def test_named_savepoints_with_repeated_names_works(conn):
             assert inserted(conn) == {"tx1"}
         assert inserted(conn) == {"tx1"}
 
-    # Will not (always) catch out-of-order exits
-    with conn.transaction(force_rollback=True):
-        tx1 = conn.transaction("s1")
-        tx2 = conn.transaction("s1")
-        tx1.__enter__()
-        tx2.__enter__()
-        tx1.__exit__(None, None, None)
-        tx2.__exit__(None, None, None)
-
 
 def test_force_rollback_successful_exit(conn, svcconn):
     """
@@ -702,6 +693,22 @@ def test_out_of_order_implicit_begin(conn, exit_error):
         t2.__exit__(*get_exc_info(exit_error))
 
 
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+def test_out_of_order_exit_same_name(conn, exit_error):
+    conn.autocommit = True
+
+    t1 = conn.transaction("save")
+    t1.__enter__()
+    t2 = conn.transaction("save")
+    t2.__enter__()
+
+    with pytest.raises(ProgrammingError):
+        t1.__exit__(*get_exc_info(exit_error))
+
+    with pytest.raises(ProgrammingError):
+        t2.__exit__(*get_exc_info(exit_error))
+
+
 @pytest.mark.parametrize("what", ["commit", "rollback", "error"])
 def test_concurrency(conn, what):
     conn.autocommit = True
@@ -723,13 +730,13 @@ def test_concurrency(conn, what):
                     assert what == "commit"
 
         if what == "error":
-            assert "would roll back" in str(ex.value)
+            assert "transaction rollback" in str(ex.value)
             assert isinstance(ex.value.__context__, ZeroDivisionError)
         elif what == "rollback":
-            assert "would roll back" in str(ex.value)
+            assert "transaction rollback" in str(ex.value)
             assert isinstance(ex.value.__context__, Rollback)
         else:
-            assert "would commit" in str(ex.value)
+            assert "transaction commit" in str(ex.value)
 
     # Start a first transaction in a thread
     t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]})
index 82df069bb222614b25612ae6225603faf570f4e7..c36552c11a5498f79b6d5db1fb66cd4746010e05 100644 (file)
@@ -488,15 +488,6 @@ async def test_named_savepoints_with_repeated_names_works(aconn):
             assert await inserted(aconn) == {"tx1"}
         assert await inserted(aconn) == {"tx1"}
 
-    # Will not (always) catch out-of-order exits
-    async with aconn.transaction(force_rollback=True):
-        tx1 = aconn.transaction("s1")
-        tx2 = aconn.transaction("s1")
-        await tx1.__aenter__()
-        await tx2.__aenter__()
-        await tx1.__aexit__(None, None, None)
-        await tx2.__aexit__(None, None, None)
-
 
 async def test_force_rollback_successful_exit(aconn, svcconn):
     """
@@ -659,6 +650,22 @@ async def test_out_of_order_implicit_begin(aconn, exit_error):
         await t2.__aexit__(*get_exc_info(exit_error))
 
 
+@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
+async def test_out_of_order_exit_same_name(aconn, exit_error):
+    await aconn.set_autocommit(True)
+
+    t1 = aconn.transaction("save")
+    await t1.__aenter__()
+    t2 = aconn.transaction("save")
+    await t2.__aenter__()
+
+    with pytest.raises(ProgrammingError):
+        await t1.__aexit__(*get_exc_info(exit_error))
+
+    with pytest.raises(ProgrammingError):
+        await t2.__aexit__(*get_exc_info(exit_error))
+
+
 @pytest.mark.parametrize("what", ["commit", "rollback", "error"])
 async def test_concurrency(aconn, what):
     await aconn.set_autocommit(True)
@@ -680,13 +687,13 @@ async def test_concurrency(aconn, what):
                     assert what == "commit"
 
         if what == "error":
-            assert "would roll back" in str(ex.value)
+            assert "transaction rollback" in str(ex.value)
             assert isinstance(ex.value.__context__, ZeroDivisionError)
         elif what == "rollback":
-            assert "would roll back" in str(ex.value)
+            assert "transaction rollback" in str(ex.value)
             assert isinstance(ex.value.__context__, Rollback)
         else:
-            assert "would commit" in str(ex.value)
+            assert "transaction commit" in str(ex.value)
 
     # Start a first transaction in a task
     t1 = create_task(worker(unlock=e[0], wait_on=e[1]))