From: Daniele Varrazzo Date: Mon, 16 Nov 2020 03:52:29 +0000 (+0000) Subject: Commands generation separated from execution in transactions X-Git-Tag: 3.0.dev0~351^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=93c388a12486329bb5db757cc405e5fe9be77eb8;p=thirdparty%2Fpsycopg.git Commands generation separated from execution in transactions Commands and state change are independent from sync/async. Only the execution and the interface is in the different sync/async subclasses now. --- diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index b159b65b0..a24d7874f 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -7,11 +7,11 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Generic, Optional, Type, Union, TYPE_CHECKING +from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING from . import sql from .pq import TransactionStatus -from .proto import ConnectionType +from .proto import ConnectionType, Query from .errors import ProgrammingError if TYPE_CHECKING: @@ -74,6 +74,59 @@ class BaseTransaction(Generic[ConnectionType]): "calling __exit__() manually and getting it wrong?" ) + def _enter_commands(self) -> List[Query]: + commands: List[Query] = [] + + if self._outer_transaction: + assert self._conn._savepoints is None, self._conn._savepoints + self._conn._savepoints = [] + commands.append(b"begin") + else: + if self._conn._savepoints is None: + self._conn._savepoints = [] + if not self._savepoint_name: + self._savepoint_name = f"s{len(self._conn._savepoints) + 1}" + + if self._savepoint_name: + commands.append( + sql.SQL("savepoint {}").format( + sql.Identifier(self._savepoint_name) + ) + ) + self._conn._savepoints.append(self._savepoint_name) + + return commands + + def _commit_commands(self) -> List[Query]: + commands: List[Query] = [] + + self._pop_savepoint() + if self._savepoint_name: + commands.append( + sql.SQL("release savepoint {}").format( + sql.Identifier(self._savepoint_name) + ) + ) + if self._outer_transaction: + commands.append(b"commit") + + return commands + + def _rollback_commands(self) -> List[Query]: + commands: List[Query] = [] + + self._pop_savepoint() + if self._savepoint_name: + commands.append( + sql.SQL( + "rollback to savepoint {n}; release savepoint {n}" + ).format(n=sql.Identifier(self._savepoint_name)) + ) + if self._outer_transaction: + commands.append(b"rollback") + + return commands + def _pop_savepoint(self) -> None: if self._savepoint_name: if self._conn._savepoints is None: @@ -90,25 +143,7 @@ class BaseTransaction(Generic[ConnectionType]): class Transaction(BaseTransaction["Connection"]): def __enter__(self) -> "Transaction": with self._conn.lock: - if self._outer_transaction: - assert self._conn._savepoints is None, self._conn._savepoints - self._conn._savepoints = [] - self._conn._exec_command(b"begin") - else: - if self._conn._savepoints is None: - self._conn._savepoints = [] - if not self._savepoint_name: - self._savepoint_name = ( - f"s{len(self._conn._savepoints) + 1}" - ) - - if self._savepoint_name: - self._conn._exec_command( - sql.SQL("savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) - ) - self._conn._savepoints.append(self._savepoint_name) + self._execute(self._enter_commands()) return self def __exit__( @@ -119,23 +154,14 @@ class Transaction(BaseTransaction["Connection"]): ) -> bool: with self._conn.lock: if not exc_val and not self.force_rollback: - return self._commit() + self._commit() + return False else: return self._rollback(exc_val) - def _commit(self) -> bool: + def _commit(self) -> None: """Commit changes made in the transaction context.""" - self._pop_savepoint() - if self._savepoint_name: - self._conn._exec_command( - sql.SQL("release savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) - ) - if self._outer_transaction: - self._conn._exec_command(b"commit") - - return False # discarded + self._execute(self._commit_commands()) def _rollback(self, exc_val: Optional[BaseException]) -> bool: # Rollback changes made in the transaction context @@ -144,15 +170,7 @@ class Transaction(BaseTransaction["Connection"]): f"{self._conn}: Explicit rollback from: ", exc_info=True ) - self._pop_savepoint() - if self._savepoint_name: - self._conn._exec_command( - sql.SQL( - "rollback to savepoint {n}; release savepoint {n}" - ).format(n=sql.Identifier(self._savepoint_name)) - ) - if self._outer_transaction: - self._conn._exec_command(b"rollback") + self._execute(self._rollback_commands()) if isinstance(exc_val, Rollback): if exc_val.transaction in (self, None): @@ -160,29 +178,16 @@ class Transaction(BaseTransaction["Connection"]): return False + def _execute(self, commands: List[Query]) -> None: + for command in commands: + self._conn._exec_command(command) + class AsyncTransaction(BaseTransaction["AsyncConnection"]): async def __aenter__(self) -> "AsyncTransaction": async with self._conn.lock: - if self._outer_transaction: - assert self._conn._savepoints is None, self._conn._savepoints - self._conn._savepoints = [] - await self._conn._exec_command(b"begin") - else: - if self._conn._savepoints is None: - self._conn._savepoints = [] - if not self._savepoint_name: - self._savepoint_name = ( - f"s{len(self._conn._savepoints) + 1}" - ) - - if self._savepoint_name: - await self._conn._exec_command( - sql.SQL("savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) - ) - self._conn._savepoints.append(self._savepoint_name) + await self._execute(self._enter_commands()) + return self async def __aexit__( @@ -193,23 +198,14 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): ) -> bool: async with self._conn.lock: if not exc_val and not self.force_rollback: - return await self._commit() + await self._commit() + return False else: return await self._rollback(exc_val) - async def _commit(self) -> bool: + async def _commit(self) -> None: """Commit changes made in the transaction context.""" - self._pop_savepoint() - if self._savepoint_name: - await self._conn._exec_command( - sql.SQL("release savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) - ) - if self._outer_transaction: - await self._conn._exec_command(b"commit") - - return False # discarded + await self._execute(self._commit_commands()) async def _rollback(self, exc_val: Optional[BaseException]) -> bool: # Rollback changes made in the transaction context @@ -218,18 +214,14 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): f"{self._conn}: Explicit rollback from: ", exc_info=True ) - self._pop_savepoint() - if self._savepoint_name: - await self._conn._exec_command( - sql.SQL( - "rollback to savepoint {n}; release savepoint {n}" - ).format(n=sql.Identifier(self._savepoint_name)) - ) - if self._outer_transaction: - await self._conn._exec_command(b"rollback") + await self._execute(self._rollback_commands()) if isinstance(exc_val, Rollback): if exc_val.transaction in (self, None): return True # Swallow the exception return False + + async def _execute(self, commands: List[Query]) -> None: + for command in commands: + await self._conn._exec_command(command)