]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Refactor Transaction queries generation into internal methods
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 9 Dec 2021 13:55:28 +0000 (14:55 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 9 Dec 2021 14:30:24 +0000 (15:30 +0100)
psycopg/psycopg/transaction.py

index fd27a7507416d691fb6befb5b46a80f3fcd84352..90cc69a22a5cd7a8a5b22ffac57d524fc40d54b3 100644 (file)
@@ -7,7 +7,7 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Generic, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
 
 from . import pq
 from . import sql
@@ -90,18 +90,7 @@ class BaseTransaction(Generic[ConnectionType]):
         self._entered = True
 
         self._push_savepoint()
-
-        commands = []
-        if self._outer_transaction:
-            commands.append(self._conn._get_tx_start_command())
-
-        if self._savepoint_name:
-            commands.append(
-                sql.SQL("SAVEPOINT {}")
-                .format(sql.Identifier(self._savepoint_name))
-                .as_bytes(self._conn)
-            )
-
+        commands = self._get_enter_commands()
         return self._conn._exec_command(b"; ".join(commands))
 
     def _exit_gen(
@@ -137,18 +126,7 @@ class BaseTransaction(Generic[ConnectionType]):
         if ex:
             raise ex
 
-        commands = []
-        if self._savepoint_name and not self._outer_transaction:
-            commands.append(
-                sql.SQL("RELEASE {}")
-                .format(sql.Identifier(self._savepoint_name))
-                .as_bytes(self._conn)
-            )
-
-        if self._outer_transaction:
-            assert not self._conn._savepoints
-            commands.append(b"COMMIT")
-
+        commands = self._get_commit_commands()
         return self._conn._exec_command(b"; ".join(commands))
 
     def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
@@ -162,6 +140,44 @@ class BaseTransaction(Generic[ConnectionType]):
         if ex:
             raise ex
 
+        commands = self._get_rollback_commands()
+        yield from self._conn._exec_command(b"; ".join(commands))
+
+        if isinstance(exc_val, Rollback):
+            if not exc_val.transaction or exc_val.transaction is self:
+                return True  # Swallow the exception
+
+        return False
+
+    def _get_enter_commands(self) -> List[bytes]:
+        commands = []
+        if self._outer_transaction:
+            commands.append(self._conn._get_tx_start_command())
+
+        if self._savepoint_name:
+            commands.append(
+                sql.SQL("SAVEPOINT {}")
+                .format(sql.Identifier(self._savepoint_name))
+                .as_bytes(self._conn)
+            )
+        return commands
+
+    def _get_commit_commands(self) -> List[bytes]:
+        commands = []
+        if self._savepoint_name and not self._outer_transaction:
+            commands.append(
+                sql.SQL("RELEASE {}")
+                .format(sql.Identifier(self._savepoint_name))
+                .as_bytes(self._conn)
+            )
+
+        if self._outer_transaction:
+            assert not self._conn._savepoints
+            commands.append(b"COMMIT")
+
+        return commands
+
+    def _get_rollback_commands(self) -> List[bytes]:
         commands = []
         if self._savepoint_name and not self._outer_transaction:
             commands.append(
@@ -179,13 +195,7 @@ class BaseTransaction(Generic[ConnectionType]):
             for cmd in self._conn._prepared.get_maintenance_commands():
                 commands.append(cmd)
 
-        yield from self._conn._exec_command(b"; ".join(commands))
-
-        if isinstance(exc_val, Rollback):
-            if not exc_val.transaction or exc_val.transaction is self:
-                return True  # Swallow the exception
-
-        return False
+        return commands
 
     def _push_savepoint(self) -> None:
         """