From: Daniele Varrazzo Date: Mon, 16 Nov 2020 03:31:29 +0000 (+0000) Subject: Refactoring to minimize sync/async savepoint duplications X-Git-Tag: 3.0.dev0~351^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cbfed0d1f75e61c4b3a243111843b385d39b8f9c;p=thirdparty%2Fpsycopg.git Refactoring to minimize sync/async savepoint duplications --- diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 22364d7a0..8879d1582 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -74,11 +74,16 @@ class BaseTransaction(Generic[ConnectionType]): ) def _pop_savepoint(self) -> None: - if self._conn._savepoints is None: - raise self._out_of_order_err - actual = self._conn._savepoints.pop() - if actual != self._savepoint_name: - raise self._out_of_order_err + if self._savepoint_name: + if self._conn._savepoints is None: + raise self._out_of_order_err + actual = self._conn._savepoints.pop() + if actual != self._savepoint_name: + raise self._out_of_order_err + if self._outer_transaction: + if self._conn._savepoints is None or self._conn._savepoints: + raise self._out_of_order_err + self._conn._savepoints = None class Transaction(BaseTransaction["Connection"]): @@ -123,18 +128,15 @@ class Transaction(BaseTransaction["Connection"]): def _commit(self) -> bool: """Commit changes made in the transaction context.""" + self._pop_savepoint() if self._savepoint_name: - self._pop_savepoint() self._conn._exec_command( sql.SQL("release savepoint {}").format( sql.Identifier(self._savepoint_name) ) ) if self._outer_transaction: - if self._conn._savepoints is None or self._conn._savepoints: - raise self._out_of_order_err self._conn._exec_command(b"commit") - self._conn._savepoints = None return False # discarded @@ -145,18 +147,15 @@ class Transaction(BaseTransaction["Connection"]): f"{self._conn}: Explicit rollback from: ", exc_info=True ) + self._pop_savepoint() if self._savepoint_name: - self._pop_savepoint() self._conn._exec_command( sql.SQL( "rollback to savepoint {n}; release savepoint {n}" ).format(n=sql.Identifier(self._savepoint_name)) ) if self._outer_transaction: - if self._conn._savepoints is None or self._conn._savepoints: - raise self._out_of_order_err self._conn._exec_command(b"rollback") - self._conn._savepoints = None if isinstance(exc_val, Rollback): if exc_val.transaction in (self, None): @@ -207,18 +206,15 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): async def _commit(self) -> bool: """Commit changes made in the transaction context.""" + self._pop_savepoint() if self._savepoint_name: - self._pop_savepoint() await self._conn._exec_command( sql.SQL("release savepoint {}").format( sql.Identifier(self._savepoint_name) ) ) if self._outer_transaction: - if self._conn._savepoints is None or self._conn._savepoints: - raise self._out_of_order_err await self._conn._exec_command(b"commit") - self._conn._savepoints = None return False # discarded @@ -229,18 +225,15 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): f"{self._conn}: Explicit rollback from: ", exc_info=True ) + self._pop_savepoint() if self._savepoint_name: - self._pop_savepoint() await self._conn._exec_command( sql.SQL( "rollback to savepoint {n}; release savepoint {n}" ).format(n=sql.Identifier(self._savepoint_name)) ) if self._outer_transaction: - if self._conn._savepoints is None or self._conn._savepoints: - raise self._out_of_order_err await self._conn._exec_command(b"rollback") - self._conn._savepoints = None if isinstance(exc_val, Rollback): if exc_val.transaction in (self, None):