From 6a574f91814878169a6cee46025869caf725d7f2 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 9 Dec 2021 14:55:28 +0100 Subject: [PATCH] Refactor Transaction queries generation into internal methods --- psycopg/psycopg/transaction.py | 74 +++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index fd27a7507..90cc69a22 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -7,7 +7,7 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Generic, Optional, Type, Union, TYPE_CHECKING +from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING from . import pq from . import sql @@ -90,18 +90,7 @@ class BaseTransaction(Generic[ConnectionType]): self._entered = True self._push_savepoint() - - commands = [] - if self._outer_transaction: - commands.append(self._conn._get_tx_start_command()) - - if self._savepoint_name: - commands.append( - sql.SQL("SAVEPOINT {}") - .format(sql.Identifier(self._savepoint_name)) - .as_bytes(self._conn) - ) - + commands = self._get_enter_commands() return self._conn._exec_command(b"; ".join(commands)) def _exit_gen( @@ -137,18 +126,7 @@ class BaseTransaction(Generic[ConnectionType]): if ex: raise ex - commands = [] - if self._savepoint_name and not self._outer_transaction: - commands.append( - sql.SQL("RELEASE {}") - .format(sql.Identifier(self._savepoint_name)) - .as_bytes(self._conn) - ) - - if self._outer_transaction: - assert not self._conn._savepoints - commands.append(b"COMMIT") - + commands = self._get_commit_commands() return self._conn._exec_command(b"; ".join(commands)) def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: @@ -162,6 +140,44 @@ class BaseTransaction(Generic[ConnectionType]): if ex: raise ex + commands = self._get_rollback_commands() + yield from self._conn._exec_command(b"; ".join(commands)) + + if isinstance(exc_val, Rollback): + if not exc_val.transaction or exc_val.transaction is self: + return True # Swallow the exception + + return False + + def _get_enter_commands(self) -> List[bytes]: + commands = [] + if self._outer_transaction: + commands.append(self._conn._get_tx_start_command()) + + if self._savepoint_name: + commands.append( + sql.SQL("SAVEPOINT {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + return commands + + def _get_commit_commands(self) -> List[bytes]: + commands = [] + if self._savepoint_name and not self._outer_transaction: + commands.append( + sql.SQL("RELEASE {}") + .format(sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + + if self._outer_transaction: + assert not self._conn._savepoints + commands.append(b"COMMIT") + + return commands + + def _get_rollback_commands(self) -> List[bytes]: commands = [] if self._savepoint_name and not self._outer_transaction: commands.append( @@ -179,13 +195,7 @@ class BaseTransaction(Generic[ConnectionType]): for cmd in self._conn._prepared.get_maintenance_commands(): commands.append(cmd) - yield from self._conn._exec_command(b"; ".join(commands)) - - if isinstance(exc_val, Rollback): - if not exc_val.transaction or exc_val.transaction is self: - return True # Swallow the exception - - return False + return commands def _push_savepoint(self) -> None: """ -- 2.47.2