From: Daniele Varrazzo Date: Thu, 9 Dec 2021 14:24:57 +0000 (+0100) Subject: Detect out-of-order transactions exit when they have the same name too X-Git-Tag: pool-3.1~79^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4364cdd6299efcb0cee5f198177d4363fa2fc93d;p=thirdparty%2Fpsycopg.git Detect out-of-order transactions exit when they have the same name too Close #177 --- diff --git a/docs/news.rst b/docs/news.rst index c02bf5c9d..bc266a405 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -23,7 +23,7 @@ Psycopg 3.0.6 (unreleased) (:ticket:`#173`). - Fail on `Connection.cursor()` if the connection is closed (:ticket:`#174`). - Raise `ProgrammingError` if out-of-order exit from transaction contexts is - detected (:ticket:`#176`). + detected (:tickets:`#176, #177`). - Add `!CHECK_STANDBY` value to `~pq.ConnStatus` enum. diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 7cf508428..c0b695df0 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -110,10 +110,8 @@ class BaseConnection(Generic[Row]): self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] - # Stack of savepoint names managed by current transaction blocks. - # the first item is "" in case the outermost Transaction must manage - # only a begin/commit and not a savepoint. - self._savepoints: List[str] = [] + # Number of transaction blocks currently entered + self._num_transactions = 0 self._closed = False # closed by an explicit close() self._prepared: PrepareManager = PrepareManager() @@ -249,7 +247,7 @@ class BaseConnection(Generic[Row]): # Raise an exception if we are in a transaction status = self.pgconn.transaction_status if status != TransactionStatus.IDLE: - if self._savepoints: + if self._num_transactions: raise e.ProgrammingError( f"can't change {attribute!r} now: " "connection.transaction() context in progress" @@ -483,7 +481,7 @@ class BaseConnection(Generic[Row]): def _commit_gen(self) -> PQGen[None]: """Generator implementing `Connection.commit()`.""" - if self._savepoints: + if self._num_transactions: raise e.ProgrammingError( "Explicit commit() forbidden within a Transaction " "context. (Transaction will be automatically committed " @@ -500,7 +498,7 @@ class BaseConnection(Generic[Row]): def _rollback_gen(self) -> PQGen[None]: """Generator implementing `Connection.rollback()`.""" - if self._savepoints: + if self._num_transactions: raise e.ProgrammingError( "Explicit rollback() forbidden within a Transaction " "context. (Either raise Rollback() or allow " diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index 3ea0a0b96..b86517e50 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -61,6 +61,7 @@ class BaseTransaction(Generic[ConnectionType]): self.force_rollback = force_rollback self._entered = self._exited = False self._outer_transaction = False + self._stack_index = -1 @property def savepoint_name(self) -> Optional[str]: @@ -135,7 +136,7 @@ class BaseTransaction(Generic[ConnectionType]): f"{self._conn}: Explicit rollback from: ", exc_info=True ) - ex = self._pop_savepoint("roll back") + ex = self._pop_savepoint("rollback") self._exited = True if ex: raise ex @@ -172,7 +173,7 @@ class BaseTransaction(Generic[ConnectionType]): ) if self._outer_transaction: - assert not self._conn._savepoints + assert not self._conn._num_transactions commands.append(b"COMMIT") return commands @@ -187,7 +188,7 @@ class BaseTransaction(Generic[ConnectionType]): ) if self._outer_transaction: - assert not self._conn._savepoints + assert not self._conn._num_transactions commands.append(b"ROLLBACK") # Also clear the prepared statements cache. @@ -209,14 +210,16 @@ class BaseTransaction(Generic[ConnectionType]): if self._outer_transaction: # outer transaction: if no name it's only a begin, else # there will be an additional savepoint - assert not self._conn._savepoints + assert not self._conn._num_transactions else: # inner transaction: it always has a name if not self._savepoint_name: self._savepoint_name = ( - f"_pg3_{len(self._conn._savepoints) + 1}" + f"_pg3_{self._conn._num_transactions + 1}" ) - self._conn._savepoints.append(self._savepoint_name) + + self._stack_index = self._conn._num_transactions + self._conn._num_transactions += 1 def _pop_savepoint(self, action: str) -> Optional[Exception]: """ @@ -224,14 +227,12 @@ class BaseTransaction(Generic[ConnectionType]): Also verify the state consistency. """ - sp = self._conn._savepoints.pop() - if sp == self._savepoint_name: + self._conn._num_transactions -= 1 + if self._conn._num_transactions == self._stack_index: return None - other = f"the savepoint {sp!r}" if sp else "the top-level transaction" return OutOfOrderTransactionNesting( - f"transactions not correctly nested: {self} would {action}" - f" in the wrong order compared to {other}" + f"transaction {action} at the wrong nesting level: {self}" ) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index bc7c374f9..74dc8c5c3 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -533,15 +533,6 @@ def test_named_savepoints_with_repeated_names_works(conn): assert inserted(conn) == {"tx1"} assert inserted(conn) == {"tx1"} - # Will not (always) catch out-of-order exits - with conn.transaction(force_rollback=True): - tx1 = conn.transaction("s1") - tx2 = conn.transaction("s1") - tx1.__enter__() - tx2.__enter__() - tx1.__exit__(None, None, None) - tx2.__exit__(None, None, None) - def test_force_rollback_successful_exit(conn, svcconn): """ @@ -702,6 +693,22 @@ def test_out_of_order_implicit_begin(conn, exit_error): t2.__exit__(*get_exc_info(exit_error)) +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +def test_out_of_order_exit_same_name(conn, exit_error): + conn.autocommit = True + + t1 = conn.transaction("save") + t1.__enter__() + t2 = conn.transaction("save") + t2.__enter__() + + with pytest.raises(ProgrammingError): + t1.__exit__(*get_exc_info(exit_error)) + + with pytest.raises(ProgrammingError): + t2.__exit__(*get_exc_info(exit_error)) + + @pytest.mark.parametrize("what", ["commit", "rollback", "error"]) def test_concurrency(conn, what): conn.autocommit = True @@ -723,13 +730,13 @@ def test_concurrency(conn, what): assert what == "commit" if what == "error": - assert "would roll back" in str(ex.value) + assert "transaction rollback" in str(ex.value) assert isinstance(ex.value.__context__, ZeroDivisionError) elif what == "rollback": - assert "would roll back" in str(ex.value) + assert "transaction rollback" in str(ex.value) assert isinstance(ex.value.__context__, Rollback) else: - assert "would commit" in str(ex.value) + assert "transaction commit" in str(ex.value) # Start a first transaction in a thread t1 = Thread(target=worker, kwargs={"unlock": e[0], "wait_on": e[1]}) diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index 82df069bb..c36552c11 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -488,15 +488,6 @@ async def test_named_savepoints_with_repeated_names_works(aconn): assert await inserted(aconn) == {"tx1"} assert await inserted(aconn) == {"tx1"} - # Will not (always) catch out-of-order exits - async with aconn.transaction(force_rollback=True): - tx1 = aconn.transaction("s1") - tx2 = aconn.transaction("s1") - await tx1.__aenter__() - await tx2.__aenter__() - await tx1.__aexit__(None, None, None) - await tx2.__aexit__(None, None, None) - async def test_force_rollback_successful_exit(aconn, svcconn): """ @@ -659,6 +650,22 @@ async def test_out_of_order_implicit_begin(aconn, exit_error): await t2.__aexit__(*get_exc_info(exit_error)) +@pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback]) +async def test_out_of_order_exit_same_name(aconn, exit_error): + await aconn.set_autocommit(True) + + t1 = aconn.transaction("save") + await t1.__aenter__() + t2 = aconn.transaction("save") + await t2.__aenter__() + + with pytest.raises(ProgrammingError): + await t1.__aexit__(*get_exc_info(exit_error)) + + with pytest.raises(ProgrammingError): + await t2.__aexit__(*get_exc_info(exit_error)) + + @pytest.mark.parametrize("what", ["commit", "rollback", "error"]) async def test_concurrency(aconn, what): await aconn.set_autocommit(True) @@ -680,13 +687,13 @@ async def test_concurrency(aconn, what): assert what == "commit" if what == "error": - assert "would roll back" in str(ex.value) + assert "transaction rollback" in str(ex.value) assert isinstance(ex.value.__context__, ZeroDivisionError) elif what == "rollback": - assert "would roll back" in str(ex.value) + assert "transaction rollback" in str(ex.value) assert isinstance(ex.value.__context__, Rollback) else: - assert "would commit" in str(ex.value) + assert "transaction commit" in str(ex.value) # Start a first transaction in a task t1 = create_task(worker(unlock=e[0], wait_on=e[1]))