]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Refactoring to minimize sync/async savepoint duplications
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 03:31:29 +0000 (03:31 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 04:02:30 +0000 (04:02 +0000)
psycopg3/psycopg3/transaction.py

index 22364d7a0a8bd498c1f7e28bd4d1c22b94b094b0..8879d158250714b1095433ac8da2b660a8d15604 100644 (file)
@@ -74,11 +74,16 @@ class BaseTransaction(Generic[ConnectionType]):
     )
 
     def _pop_savepoint(self) -> None:
-        if self._conn._savepoints is None:
-            raise self._out_of_order_err
-        actual = self._conn._savepoints.pop()
-        if actual != self._savepoint_name:
-            raise self._out_of_order_err
+        if self._savepoint_name:
+            if self._conn._savepoints is None:
+                raise self._out_of_order_err
+            actual = self._conn._savepoints.pop()
+            if actual != self._savepoint_name:
+                raise self._out_of_order_err
+        if self._outer_transaction:
+            if self._conn._savepoints is None or self._conn._savepoints:
+                raise self._out_of_order_err
+            self._conn._savepoints = None
 
 
 class Transaction(BaseTransaction["Connection"]):
@@ -123,18 +128,15 @@ class Transaction(BaseTransaction["Connection"]):
 
     def _commit(self) -> bool:
         """Commit changes made in the transaction context."""
+        self._pop_savepoint()
         if self._savepoint_name:
-            self._pop_savepoint()
             self._conn._exec_command(
                 sql.SQL("release savepoint {}").format(
                     sql.Identifier(self._savepoint_name)
                 )
             )
         if self._outer_transaction:
-            if self._conn._savepoints is None or self._conn._savepoints:
-                raise self._out_of_order_err
             self._conn._exec_command(b"commit")
-            self._conn._savepoints = None
 
         return False  # discarded
 
@@ -145,18 +147,15 @@ class Transaction(BaseTransaction["Connection"]):
                 f"{self._conn}: Explicit rollback from: ", exc_info=True
             )
 
+        self._pop_savepoint()
         if self._savepoint_name:
-            self._pop_savepoint()
             self._conn._exec_command(
                 sql.SQL(
                     "rollback to savepoint {n}; release savepoint {n}"
                 ).format(n=sql.Identifier(self._savepoint_name))
             )
         if self._outer_transaction:
-            if self._conn._savepoints is None or self._conn._savepoints:
-                raise self._out_of_order_err
             self._conn._exec_command(b"rollback")
-            self._conn._savepoints = None
 
         if isinstance(exc_val, Rollback):
             if exc_val.transaction in (self, None):
@@ -207,18 +206,15 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
 
     async def _commit(self) -> bool:
         """Commit changes made in the transaction context."""
+        self._pop_savepoint()
         if self._savepoint_name:
-            self._pop_savepoint()
             await self._conn._exec_command(
                 sql.SQL("release savepoint {}").format(
                     sql.Identifier(self._savepoint_name)
                 )
             )
         if self._outer_transaction:
-            if self._conn._savepoints is None or self._conn._savepoints:
-                raise self._out_of_order_err
             await self._conn._exec_command(b"commit")
-            self._conn._savepoints = None
 
         return False  # discarded
 
@@ -229,18 +225,15 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
                 f"{self._conn}: Explicit rollback from: ", exc_info=True
             )
 
+        self._pop_savepoint()
         if self._savepoint_name:
-            self._pop_savepoint()
             await self._conn._exec_command(
                 sql.SQL(
                     "rollback to savepoint {n}; release savepoint {n}"
                 ).format(n=sql.Identifier(self._savepoint_name))
             )
         if self._outer_transaction:
-            if self._conn._savepoints is None or self._conn._savepoints:
-                raise self._out_of_order_err
             await self._conn._exec_command(b"rollback")
-            self._conn._savepoints = None
 
         if isinstance(exc_val, Rollback):
             if exc_val.transaction in (self, None):