import logging
from types import TracebackType
-from typing import Generic, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
from . import pq
from . import sql
self._entered = True
self._push_savepoint()
-
- commands = []
- if self._outer_transaction:
- commands.append(self._conn._get_tx_start_command())
-
- if self._savepoint_name:
- commands.append(
- sql.SQL("SAVEPOINT {}")
- .format(sql.Identifier(self._savepoint_name))
- .as_bytes(self._conn)
- )
-
+ commands = self._get_enter_commands()
return self._conn._exec_command(b"; ".join(commands))
def _exit_gen(
if ex:
raise ex
- commands = []
- if self._savepoint_name and not self._outer_transaction:
- commands.append(
- sql.SQL("RELEASE {}")
- .format(sql.Identifier(self._savepoint_name))
- .as_bytes(self._conn)
- )
-
- if self._outer_transaction:
- assert not self._conn._savepoints
- commands.append(b"COMMIT")
-
+ commands = self._get_commit_commands()
return self._conn._exec_command(b"; ".join(commands))
def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
if ex:
raise ex
+ commands = self._get_rollback_commands()
+ yield from self._conn._exec_command(b"; ".join(commands))
+
+ if isinstance(exc_val, Rollback):
+ if not exc_val.transaction or exc_val.transaction is self:
+ return True # Swallow the exception
+
+ return False
+
+ def _get_enter_commands(self) -> List[bytes]:
+ commands = []
+ if self._outer_transaction:
+ commands.append(self._conn._get_tx_start_command())
+
+ if self._savepoint_name:
+ commands.append(
+ sql.SQL("SAVEPOINT {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+ return commands
+
+ def _get_commit_commands(self) -> List[bytes]:
+ commands = []
+ if self._savepoint_name and not self._outer_transaction:
+ commands.append(
+ sql.SQL("RELEASE {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+
+ if self._outer_transaction:
+ assert not self._conn._savepoints
+ commands.append(b"COMMIT")
+
+ return commands
+
+ def _get_rollback_commands(self) -> List[bytes]:
commands = []
if self._savepoint_name and not self._outer_transaction:
commands.append(
for cmd in self._conn._prepared.get_maintenance_commands():
commands.append(cmd)
- yield from self._conn._exec_command(b"; ".join(commands))
-
- if isinstance(exc_val, Rollback):
- if not exc_val.transaction or exc_val.transaction is self:
- return True # Swallow the exception
-
- return False
+ return commands
def _push_savepoint(self) -> None:
"""