From: Daniele Varrazzo Date: Thu, 9 Dec 2021 13:55:28 +0000 (+0100) Subject: Refactor Transaction queries generation into internal methods X-Git-Tag: pool-3.1~79^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a574f91814878169a6cee46025869caf725d7f2;p=thirdparty%2Fpsycopg.git Refactor Transaction queries generation into internal methods --- diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index fd27a7507..90cc69a22 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -7,7 +7,7 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Generic, Optional, Type, Union, TYPE_CHECKING +from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING from . import pq from . import sql @@ -90,18 +90,7 @@ class BaseTransaction(Generic[ConnectionType]): self._entered = True self._push_savepoint() - - commands = [] - if self._outer_transaction: - commands.append(self._conn._get_tx_start_command()) - - if self._savepoint_name: - commands.append( - sql.SQL("SAVEPOINT {}") - .format(sql.Identifier(self._savepoint_name)) - .as_bytes(self._conn) - ) - + commands = self._get_enter_commands() return self._conn._exec_command(b"; ".join(commands)) def _exit_gen( @@ -137,18 +126,7 @@ class BaseTransaction(Generic[ConnectionType]): if ex: raise ex - commands = [] - if self._savepoint_name and not self._outer_transaction: - commands.append( - sql.SQL("RELEASE {}") - .format(sql.Identifier(self._savepoint_name)) - .as_bytes(self._conn) - ) - - if self._outer_transaction: - assert not self._conn._savepoints - commands.append(b"COMMIT") - + commands = self._get_commit_commands() return self._conn._exec_command(b"; ".join(commands)) def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: @@ -162,6 +140,44 @@ class BaseTransaction(Generic[ConnectionType]): if ex: raise ex + commands = self._get_rollback_commands() + yield from self._conn._exec_command(b"; ".join(commands)) + + if isinstance(exc_val, Rollback): + if not exc_val.transaction or exc_val.transaction is self: + return True # Swallow the exception + + return False + + def _get_enter_commands(self) -> List[bytes]: + commands = [] + if self._outer_transaction: + commands.append(self._conn._get_tx_start_command()) + + if self._savepoint_name: + commands.append( + sql.SQL("SAVEPOINT {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + return commands + + def _get_commit_commands(self) -> List[bytes]: + commands = [] + if self._savepoint_name and not self._outer_transaction: + commands.append( + sql.SQL("RELEASE {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._savepoints + commands.append(b"COMMIT") + + return commands + + def _get_rollback_commands(self) -> List[bytes]: commands = [] if self._savepoint_name and not self._outer_transaction: commands.append( @@ -179,13 +195,7 @@ class BaseTransaction(Generic[ConnectionType]): for cmd in self._conn._prepared.get_maintenance_commands(): commands.append(cmd) - yield from self._conn._exec_command(b"; ".join(commands)) - - if isinstance(exc_val, Rollback): - if not exc_val.transaction or exc_val.transaction is self: - return True # Swallow the exception - - return False + return commands def _push_savepoint(self) -> None: """