]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Refactor Transaction adding an internal push method
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 9 Dec 2021 13:30:33 +0000 (14:30 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 9 Dec 2021 14:30:24 +0000 (15:30 +0100)
Nice and symmetric w.r.t. the pop method recently introduced.

psycopg/psycopg/transaction.py

index 63aa01e34e93553e7c93f1bd53f608efc7b66ba0..fd27a7507416d691fb6befb5b46a80f3fcd84352 100644 (file)
@@ -60,6 +60,7 @@ class BaseTransaction(Generic[ConnectionType]):
         self._savepoint_name = savepoint_name or ""
         self.force_rollback = force_rollback
         self._entered = self._exited = False
+        self._outer_transaction = False
 
     @property
     def savepoint_name(self) -> Optional[str]:
@@ -88,23 +89,10 @@ class BaseTransaction(Generic[ConnectionType]):
             raise TypeError("transaction blocks can be used only once")
         self._entered = True
 
-        self._outer_transaction = (
-            self._conn.pgconn.transaction_status == TransactionStatus.IDLE
-        )
-        if self._outer_transaction:
-            # outer transaction: if no name it's only a begin, else
-            # there will be an additional savepoint
-            assert not self._conn._savepoints
-        else:
-            # inner transaction: it always has a name
-            if not self._savepoint_name:
-                self._savepoint_name = (
-                    f"_pg3_{len(self._conn._savepoints) + 1}"
-                )
+        self._push_savepoint()
 
         commands = []
         if self._outer_transaction:
-            assert not self._conn._savepoints, self._conn._savepoints
             commands.append(self._conn._get_tx_start_command())
 
         if self._savepoint_name:
@@ -114,7 +102,6 @@ class BaseTransaction(Generic[ConnectionType]):
                 .as_bytes(self._conn)
             )
 
-        self._conn._savepoints.append(self._savepoint_name)
         return self._conn._exec_command(b"; ".join(commands))
 
     def _exit_gen(
@@ -200,7 +187,33 @@ class BaseTransaction(Generic[ConnectionType]):
 
         return False
 
+    def _push_savepoint(self) -> None:
+        """
+        Push the transaction on the connection transactions stack.
+
+        Also set the internal state of the object and verify consistency.
+        """
+        self._outer_transaction = (
+            self._conn.pgconn.transaction_status == TransactionStatus.IDLE
+        )
+        if self._outer_transaction:
+            # outer transaction: if no name it's only a begin, else
+            # there will be an additional savepoint
+            assert not self._conn._savepoints
+        else:
+            # inner transaction: it always has a name
+            if not self._savepoint_name:
+                self._savepoint_name = (
+                    f"_pg3_{len(self._conn._savepoints) + 1}"
+                )
+        self._conn._savepoints.append(self._savepoint_name)
+
     def _pop_savepoint(self, action: str) -> Optional[Exception]:
+        """
+        Pop the transaction from the connection transactions stack.
+
+        Also verify the state consistency.
+        """
         sp = self._conn._savepoints.pop()
         if sp == self._savepoint_name:
             return None