]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add Transaction attribute to record the transaction state
authorimmortalcodes <21112002mj@gmail.com>
Tue, 14 Oct 2025 12:41:57 +0000 (18:11 +0530)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 22 Nov 2025 18:12:55 +0000 (19:12 +0100)
psycopg/psycopg/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py

index 22d40d6ba6b06b1cf74a603db2c6a055ee1f4edd..e7713aba2c3a8beb297ffc8fc32c352cfa09ca2e 100644 (file)
@@ -7,6 +7,7 @@ Transaction context managers returned by Connection.transaction()
 from __future__ import annotations
 
 import logging
+from enum import Enum
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Generic
 from collections.abc import Iterator
@@ -51,6 +52,14 @@ class OutOfOrderTransactionNesting(e.ProgrammingError):
 
 
 class BaseTransaction(Generic[ConnectionType]):
+    class Status(Enum):
+        NOT_STARTED = "not_started"
+        ACTIVE = "active"
+        COMMITTED = "committed"
+        FAILED = "failed"
+        ROLLED_BACK_EXPLICITLY = "rolled_back_explicitly"
+        ROLLED_BACK_WITH_ERROR = "rolled_back_with_error"
+
     def __init__(
         self,
         connection: ConnectionType,
@@ -62,6 +71,7 @@ class BaseTransaction(Generic[ConnectionType]):
         self._savepoint_name = savepoint_name or ""
         self.force_rollback = force_rollback
         self._entered = self._exited = False
+        self.status = self.Status.NOT_STARTED
         self._outer_transaction = False
         self._stack_index = -1
 
@@ -77,20 +87,14 @@ class BaseTransaction(Generic[ConnectionType]):
     def __repr__(self) -> str:
         cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
         info = connection_summary(self.pgconn)
-        if not self._entered:
-            status = "inactive"
-        elif not self._exited:
-            status = "active"
-        else:
-            status = "terminated"
-
         sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
-        return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
+        return f"<{cls} {sp}({self.status.value}) {info} at 0x{id(self):x}>"
 
     def _enter_gen(self) -> PQGen[None]:
         if self._entered:
             raise TypeError("transaction blocks can be used only once")
         self._entered = True
+        self.status = self.Status.ACTIVE
 
         self._push_savepoint()
         for command in self._get_enter_commands():
@@ -124,6 +128,7 @@ class BaseTransaction(Generic[ConnectionType]):
     def _commit_gen(self) -> PQGen[None]:
         ex = self._pop_savepoint("commit")
         self._exited = True
+        self.status = self.Status.COMMITTED
         if ex:
             raise ex
 
@@ -136,6 +141,12 @@ class BaseTransaction(Generic[ConnectionType]):
 
         ex = self._pop_savepoint("rollback")
         self._exited = True
+
+        if isinstance(exc_val, Rollback) or self.force_rollback:
+            self.status = self.Status.ROLLED_BACK_EXPLICITLY
+        else:
+            self.status = self.Status.ROLLED_BACK_WITH_ERROR
+
         if ex:
             raise ex
 
@@ -253,6 +264,7 @@ class Transaction(BaseTransaction["Connection[Any]"]):
             with self._conn.lock:
                 return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
         else:
+            self.status = self.Status.FAILED
             return False
 
 
@@ -282,4 +294,5 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
             async with self._conn.lock:
                 return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
         else:
+            self.status = self.Status.FAILED
             return False
index 0e94ff4c5536fa2a7ee0db0738f2854122c7b1fe..8ae55d5e03bac49045a823f860f1abaa2bdacd47 100644 (file)
@@ -523,6 +523,128 @@ def test_force_rollback_exception_exit(conn, svcconn):
     assert not inserted(svcconn)
 
 
+def test_transaction_status(conn_cls, dsn):
+    conn = conn_cls.connect(dsn)
+
+    """
+    The Transaction.status property ends up in committed state when no exceptions
+    are raised and force_rollback is False(default).
+    """
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+    with conn.transaction() as tx:
+        assert tx.status == tx.Status.ACTIVE
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert tx.status == tx.Status.COMMITTED
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in rolled_back_with_error state when an
+    exception is raised within the transaction block.
+    """
+    try:
+        with conn.transaction() as tx:
+            assert tx.status == tx.Status.ACTIVE
+            assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            1 / 0
+    except ZeroDivisionError:
+        pass
+    assert tx.status == tx.Status.ROLLED_BACK_WITH_ERROR
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in rolled_back_explicitly state when a
+    Rollback exception is raised within the transaction block.
+    """
+    try:
+        with conn.transaction() as tx:
+            assert tx.status == tx.Status.ACTIVE
+            assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            raise Rollback()
+    except Rollback:
+        pass
+    assert tx.status == tx.Status.ROLLED_BACK_EXPLICITLY
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in rolled_back_explicitly state when a
+    Transaction is created with force_rollback=True.
+    """
+    with conn.transaction(force_rollback=True) as tx:
+        assert tx.status == tx.Status.ACTIVE
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert tx.status == tx.Status.ROLLED_BACK_EXPLICITLY
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in FAILED state when the connection
+    is broken within the transaction block.
+    """
+    with conn.transaction() as tx:
+        assert tx.status == tx.Status.ACTIVE
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        conn.close()
+        assert conn.pgconn.status == pq.ConnStatus.BAD
+    assert tx.status == tx.Status.FAILED
+
+
+def test_nested_transaction_status(conn_cls, dsn):
+    conn = conn_cls.connect(dsn)
+
+    """
+    Testing nested transactions status property behavior.
+    This is a basic test case where the outer transaction commits successfully.
+    """
+    with conn.transaction() as tx1:
+        assert tx1.status == tx1.Status.ACTIVE
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        with conn.transaction() as tx2:
+            assert tx2.status == tx2.Status.ACTIVE
+            assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        assert tx2.status == tx2.Status.COMMITTED
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+        try:
+            with conn.transaction() as tx3:
+                assert tx3.status == tx3.Status.ACTIVE
+                assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+                1 / 0
+        except ZeroDivisionError:
+            pass
+        assert tx3.status == tx3.Status.ROLLED_BACK_WITH_ERROR
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+        with conn.transaction() as tx4:
+            assert tx4.status == tx4.Status.ACTIVE
+            assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            raise Rollback()
+        assert tx4.status == tx4.Status.ROLLED_BACK_EXPLICITLY
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+        with conn.transaction(force_rollback=True) as tx5:
+            assert tx5.status == tx5.Status.ACTIVE
+            assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        assert tx5.status == tx5.Status.ROLLED_BACK_EXPLICITLY
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+    assert tx1.status == tx1.Status.COMMITTED
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    Testing nested transactions status property behavior.
+    This test case checks the scenario where the inner transaction fails
+    """
+    with conn.transaction() as tx6:
+        assert tx6.status == tx6.Status.ACTIVE
+        assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        with conn.transaction() as tx7:
+            assert tx7.status == tx7.Status.ACTIVE
+            assert conn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            conn.close()
+            assert conn.pgconn.status == pq.ConnStatus.BAD
+    assert tx7.status == tx7.Status.FAILED
+    assert tx6.status == tx6.Status.FAILED
+
+
 @crdb_skip_external_observer
 def test_explicit_rollback_discards_changes(conn, svcconn):
     """
@@ -629,13 +751,13 @@ def test_str(conn, pipeline):
         assert "[IDLE, pipeline=ON]" in str(tx)
     else:
         assert "[IDLE]" in str(tx)
-    assert "(terminated)" in str(tx)
+    assert "(committed)" in str(tx)
 
     with pytest.raises(ZeroDivisionError):
         with conn.transaction() as tx:
             1 / 0
 
-    assert "(terminated)" in str(tx)
+    assert "(rolled_back_with_error)" in str(tx)
 
 
 @pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])
index a2ad4765b53c75c29d7ccd54f17b8c1e09128da9..97f2b7afb01216ecddf6bee8d511782257a1e2b1 100644 (file)
@@ -530,6 +530,128 @@ async def test_force_rollback_exception_exit(aconn, svcconn):
     assert not inserted(svcconn)
 
 
+async def test_transaction_status(aconn_cls, dsn):
+    aconn = await aconn_cls.connect(dsn)
+
+    """
+    The Transaction.status property ends up in committed state when no exceptions
+    are raised and force_rollback is False(default).
+    """
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+    async with aconn.transaction() as tx:
+        assert tx.status == tx.Status.ACTIVE
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert tx.status == tx.Status.COMMITTED
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in rolled_back_with_error state when an
+    exception is raised within the transaction block.
+    """
+    try:
+        async with aconn.transaction() as tx:
+            assert tx.status == tx.Status.ACTIVE
+            assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            1 / 0
+    except ZeroDivisionError:
+        pass
+    assert tx.status == tx.Status.ROLLED_BACK_WITH_ERROR
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in rolled_back_explicitly state when a
+    Rollback exception is raised within the transaction block.
+    """
+    try:
+        async with aconn.transaction() as tx:
+            assert tx.status == tx.Status.ACTIVE
+            assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            raise Rollback()
+    except Rollback:
+        pass
+    assert tx.status == tx.Status.ROLLED_BACK_EXPLICITLY
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in rolled_back_explicitly state when a
+    Transaction is created with force_rollback=True.
+    """
+    async with aconn.transaction(force_rollback=True) as tx:
+        assert tx.status == tx.Status.ACTIVE
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+    assert tx.status == tx.Status.ROLLED_BACK_EXPLICITLY
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    The Transaction.status property ends up in FAILED state when the connection
+    is broken within the transaction block.
+    """
+    async with aconn.transaction() as tx:
+        assert tx.status == tx.Status.ACTIVE
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        await aconn.close()
+        assert aconn.pgconn.status == pq.ConnStatus.BAD
+    assert tx.status == tx.Status.FAILED
+
+
+async def test_nested_transaction_status(aconn_cls, dsn):
+    aconn = await aconn_cls.connect(dsn)
+
+    """
+    Testing nested transactions status property behavior.
+    This is a basic test case where the outer transaction commits successfully.
+    """
+    async with aconn.transaction() as tx1:
+        assert tx1.status == tx1.Status.ACTIVE
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        async with aconn.transaction() as tx2:
+            assert tx2.status == tx2.Status.ACTIVE
+            assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        assert tx2.status == tx2.Status.COMMITTED
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+        try:
+            async with aconn.transaction() as tx3:
+                assert tx3.status == tx3.Status.ACTIVE
+                assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+                1 / 0
+        except ZeroDivisionError:
+            pass
+        assert tx3.status == tx3.Status.ROLLED_BACK_WITH_ERROR
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+        async with aconn.transaction() as tx4:
+            assert tx4.status == tx4.Status.ACTIVE
+            assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            raise Rollback()
+        assert tx4.status == tx4.Status.ROLLED_BACK_EXPLICITLY
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+        async with aconn.transaction(force_rollback=True) as tx5:
+            assert tx5.status == tx5.Status.ACTIVE
+            assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        assert tx5.status == tx5.Status.ROLLED_BACK_EXPLICITLY
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+
+    assert tx1.status == tx1.Status.COMMITTED
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
+
+    """
+    Testing nested transactions status property behavior.
+    This test case checks the scenario where the inner transaction fails
+    """
+    async with aconn.transaction() as tx6:
+        assert tx6.status == tx6.Status.ACTIVE
+        assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+        async with aconn.transaction() as tx7:
+            assert tx7.status == tx7.Status.ACTIVE
+            assert aconn.pgconn.transaction_status == pq.TransactionStatus.INTRANS
+            await aconn.close()
+            assert aconn.pgconn.status == pq.ConnStatus.BAD
+    assert tx7.status == tx7.Status.FAILED
+    assert tx6.status == tx6.Status.FAILED
+
+
 @crdb_skip_external_observer
 async def test_explicit_rollback_discards_changes(aconn, svcconn):
     """
@@ -636,13 +758,13 @@ async def test_str(aconn, apipeline):
         assert "[IDLE, pipeline=ON]" in str(tx)
     else:
         assert "[IDLE]" in str(tx)
-    assert "(terminated)" in str(tx)
+    assert "(committed)" in str(tx)
 
     with pytest.raises(ZeroDivisionError):
         async with aconn.transaction() as tx:
             1 / 0
 
-    assert "(terminated)" in str(tx)
+    assert "(rolled_back_with_error)" in str(tx)
 
 
 @pytest.mark.parametrize("exit_error", [None, ZeroDivisionError, Rollback])