]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Commands generation separated from execution in transactions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 03:52:29 +0000 (03:52 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 04:02:30 +0000 (04:02 +0000)
Commands and state change are independent from sync/async. Only the
execution and the interface is in the different sync/async subclasses
now.

psycopg3/psycopg3/transaction.py

index b159b65b04e2005b76521e49770e6220e6cda458..a24d7874fdfb3e7ca24ba57dc8b058e6207bf9e0 100644 (file)
@@ -7,11 +7,11 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Generic, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
 
 from . import sql
 from .pq import TransactionStatus
-from .proto import ConnectionType
+from .proto import ConnectionType, Query
 from .errors import ProgrammingError
 
 if TYPE_CHECKING:
@@ -74,6 +74,59 @@ class BaseTransaction(Generic[ConnectionType]):
         "calling __exit__() manually and getting it wrong?"
     )
 
+    def _enter_commands(self) -> List[Query]:
+        commands: List[Query] = []
+
+        if self._outer_transaction:
+            assert self._conn._savepoints is None, self._conn._savepoints
+            self._conn._savepoints = []
+            commands.append(b"begin")
+        else:
+            if self._conn._savepoints is None:
+                self._conn._savepoints = []
+            if not self._savepoint_name:
+                self._savepoint_name = f"s{len(self._conn._savepoints) + 1}"
+
+        if self._savepoint_name:
+            commands.append(
+                sql.SQL("savepoint {}").format(
+                    sql.Identifier(self._savepoint_name)
+                )
+            )
+            self._conn._savepoints.append(self._savepoint_name)
+
+        return commands
+
+    def _commit_commands(self) -> List[Query]:
+        commands: List[Query] = []
+
+        self._pop_savepoint()
+        if self._savepoint_name:
+            commands.append(
+                sql.SQL("release savepoint {}").format(
+                    sql.Identifier(self._savepoint_name)
+                )
+            )
+        if self._outer_transaction:
+            commands.append(b"commit")
+
+        return commands
+
+    def _rollback_commands(self) -> List[Query]:
+        commands: List[Query] = []
+
+        self._pop_savepoint()
+        if self._savepoint_name:
+            commands.append(
+                sql.SQL(
+                    "rollback to savepoint {n}; release savepoint {n}"
+                ).format(n=sql.Identifier(self._savepoint_name))
+            )
+        if self._outer_transaction:
+            commands.append(b"rollback")
+
+        return commands
+
     def _pop_savepoint(self) -> None:
         if self._savepoint_name:
             if self._conn._savepoints is None:
@@ -90,25 +143,7 @@ class BaseTransaction(Generic[ConnectionType]):
 class Transaction(BaseTransaction["Connection"]):
     def __enter__(self) -> "Transaction":
         with self._conn.lock:
-            if self._outer_transaction:
-                assert self._conn._savepoints is None, self._conn._savepoints
-                self._conn._savepoints = []
-                self._conn._exec_command(b"begin")
-            else:
-                if self._conn._savepoints is None:
-                    self._conn._savepoints = []
-                if not self._savepoint_name:
-                    self._savepoint_name = (
-                        f"s{len(self._conn._savepoints) + 1}"
-                    )
-
-            if self._savepoint_name:
-                self._conn._exec_command(
-                    sql.SQL("savepoint {}").format(
-                        sql.Identifier(self._savepoint_name)
-                    )
-                )
-                self._conn._savepoints.append(self._savepoint_name)
+            self._execute(self._enter_commands())
         return self
 
     def __exit__(
@@ -119,23 +154,14 @@ class Transaction(BaseTransaction["Connection"]):
     ) -> bool:
         with self._conn.lock:
             if not exc_val and not self.force_rollback:
-                return self._commit()
+                self._commit()
+                return False
             else:
                 return self._rollback(exc_val)
 
-    def _commit(self) -> bool:
+    def _commit(self) -> None:
         """Commit changes made in the transaction context."""
-        self._pop_savepoint()
-        if self._savepoint_name:
-            self._conn._exec_command(
-                sql.SQL("release savepoint {}").format(
-                    sql.Identifier(self._savepoint_name)
-                )
-            )
-        if self._outer_transaction:
-            self._conn._exec_command(b"commit")
-
-        return False  # discarded
+        self._execute(self._commit_commands())
 
     def _rollback(self, exc_val: Optional[BaseException]) -> bool:
         # Rollback changes made in the transaction context
@@ -144,15 +170,7 @@ class Transaction(BaseTransaction["Connection"]):
                 f"{self._conn}: Explicit rollback from: ", exc_info=True
             )
 
-        self._pop_savepoint()
-        if self._savepoint_name:
-            self._conn._exec_command(
-                sql.SQL(
-                    "rollback to savepoint {n}; release savepoint {n}"
-                ).format(n=sql.Identifier(self._savepoint_name))
-            )
-        if self._outer_transaction:
-            self._conn._exec_command(b"rollback")
+        self._execute(self._rollback_commands())
 
         if isinstance(exc_val, Rollback):
             if exc_val.transaction in (self, None):
@@ -160,29 +178,16 @@ class Transaction(BaseTransaction["Connection"]):
 
         return False
 
+    def _execute(self, commands: List[Query]) -> None:
+        for command in commands:
+            self._conn._exec_command(command)
+
 
 class AsyncTransaction(BaseTransaction["AsyncConnection"]):
     async def __aenter__(self) -> "AsyncTransaction":
         async with self._conn.lock:
-            if self._outer_transaction:
-                assert self._conn._savepoints is None, self._conn._savepoints
-                self._conn._savepoints = []
-                await self._conn._exec_command(b"begin")
-            else:
-                if self._conn._savepoints is None:
-                    self._conn._savepoints = []
-                if not self._savepoint_name:
-                    self._savepoint_name = (
-                        f"s{len(self._conn._savepoints) + 1}"
-                    )
-
-            if self._savepoint_name:
-                await self._conn._exec_command(
-                    sql.SQL("savepoint {}").format(
-                        sql.Identifier(self._savepoint_name)
-                    )
-                )
-                self._conn._savepoints.append(self._savepoint_name)
+            await self._execute(self._enter_commands())
+
         return self
 
     async def __aexit__(
@@ -193,23 +198,14 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
     ) -> bool:
         async with self._conn.lock:
             if not exc_val and not self.force_rollback:
-                return await self._commit()
+                await self._commit()
+                return False
             else:
                 return await self._rollback(exc_val)
 
-    async def _commit(self) -> bool:
+    async def _commit(self) -> None:
         """Commit changes made in the transaction context."""
-        self._pop_savepoint()
-        if self._savepoint_name:
-            await self._conn._exec_command(
-                sql.SQL("release savepoint {}").format(
-                    sql.Identifier(self._savepoint_name)
-                )
-            )
-        if self._outer_transaction:
-            await self._conn._exec_command(b"commit")
-
-        return False  # discarded
+        await self._execute(self._commit_commands())
 
     async def _rollback(self, exc_val: Optional[BaseException]) -> bool:
         # Rollback changes made in the transaction context
@@ -218,18 +214,14 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
                 f"{self._conn}: Explicit rollback from: ", exc_info=True
             )
 
-        self._pop_savepoint()
-        if self._savepoint_name:
-            await self._conn._exec_command(
-                sql.SQL(
-                    "rollback to savepoint {n}; release savepoint {n}"
-                ).format(n=sql.Identifier(self._savepoint_name))
-            )
-        if self._outer_transaction:
-            await self._conn._exec_command(b"rollback")
+        await self._execute(self._rollback_commands())
 
         if isinstance(exc_val, Rollback):
             if exc_val.transaction in (self, None):
                 return True  # Swallow the exception
 
         return False
+
+    async def _execute(self, commands: List[Query]) -> None:
+        for command in commands:
+            await self._conn._exec_command(command)