import logging
from types import TracebackType
-from typing import Generic, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
from . import sql
from .pq import TransactionStatus
-from .proto import ConnectionType
+from .proto import ConnectionType, Query
from .errors import ProgrammingError
if TYPE_CHECKING:
"calling __exit__() manually and getting it wrong?"
)
+ def _enter_commands(self) -> List[Query]:
+ commands: List[Query] = []
+
+ if self._outer_transaction:
+ assert self._conn._savepoints is None, self._conn._savepoints
+ self._conn._savepoints = []
+ commands.append(b"begin")
+ else:
+ if self._conn._savepoints is None:
+ self._conn._savepoints = []
+ if not self._savepoint_name:
+ self._savepoint_name = f"s{len(self._conn._savepoints) + 1}"
+
+ if self._savepoint_name:
+ commands.append(
+ sql.SQL("savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
+ )
+ )
+ self._conn._savepoints.append(self._savepoint_name)
+
+ return commands
+
+ def _commit_commands(self) -> List[Query]:
+ commands: List[Query] = []
+
+ self._pop_savepoint()
+ if self._savepoint_name:
+ commands.append(
+ sql.SQL("release savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
+ )
+ )
+ if self._outer_transaction:
+ commands.append(b"commit")
+
+ return commands
+
+ def _rollback_commands(self) -> List[Query]:
+ commands: List[Query] = []
+
+ self._pop_savepoint()
+ if self._savepoint_name:
+ commands.append(
+ sql.SQL(
+ "rollback to savepoint {n}; release savepoint {n}"
+ ).format(n=sql.Identifier(self._savepoint_name))
+ )
+ if self._outer_transaction:
+ commands.append(b"rollback")
+
+ return commands
+
def _pop_savepoint(self) -> None:
if self._savepoint_name:
if self._conn._savepoints is None:
class Transaction(BaseTransaction["Connection"]):
def __enter__(self) -> "Transaction":
with self._conn.lock:
- if self._outer_transaction:
- assert self._conn._savepoints is None, self._conn._savepoints
- self._conn._savepoints = []
- self._conn._exec_command(b"begin")
- else:
- if self._conn._savepoints is None:
- self._conn._savepoints = []
- if not self._savepoint_name:
- self._savepoint_name = (
- f"s{len(self._conn._savepoints) + 1}"
- )
-
- if self._savepoint_name:
- self._conn._exec_command(
- sql.SQL("savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
- )
- self._conn._savepoints.append(self._savepoint_name)
+ self._execute(self._enter_commands())
return self
def __exit__(
) -> bool:
with self._conn.lock:
if not exc_val and not self.force_rollback:
- return self._commit()
+ self._commit()
+ return False
else:
return self._rollback(exc_val)
- def _commit(self) -> bool:
+ def _commit(self) -> None:
"""Commit changes made in the transaction context."""
- self._pop_savepoint()
- if self._savepoint_name:
- self._conn._exec_command(
- sql.SQL("release savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
- )
- if self._outer_transaction:
- self._conn._exec_command(b"commit")
-
- return False # discarded
+ self._execute(self._commit_commands())
def _rollback(self, exc_val: Optional[BaseException]) -> bool:
# Rollback changes made in the transaction context
f"{self._conn}: Explicit rollback from: ", exc_info=True
)
- self._pop_savepoint()
- if self._savepoint_name:
- self._conn._exec_command(
- sql.SQL(
- "rollback to savepoint {n}; release savepoint {n}"
- ).format(n=sql.Identifier(self._savepoint_name))
- )
- if self._outer_transaction:
- self._conn._exec_command(b"rollback")
+ self._execute(self._rollback_commands())
if isinstance(exc_val, Rollback):
if exc_val.transaction in (self, None):
return False
+ def _execute(self, commands: List[Query]) -> None:
+ for command in commands:
+ self._conn._exec_command(command)
+
class AsyncTransaction(BaseTransaction["AsyncConnection"]):
async def __aenter__(self) -> "AsyncTransaction":
async with self._conn.lock:
- if self._outer_transaction:
- assert self._conn._savepoints is None, self._conn._savepoints
- self._conn._savepoints = []
- await self._conn._exec_command(b"begin")
- else:
- if self._conn._savepoints is None:
- self._conn._savepoints = []
- if not self._savepoint_name:
- self._savepoint_name = (
- f"s{len(self._conn._savepoints) + 1}"
- )
-
- if self._savepoint_name:
- await self._conn._exec_command(
- sql.SQL("savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
- )
- self._conn._savepoints.append(self._savepoint_name)
+ await self._execute(self._enter_commands())
+
return self
async def __aexit__(
) -> bool:
async with self._conn.lock:
if not exc_val and not self.force_rollback:
- return await self._commit()
+ await self._commit()
+ return False
else:
return await self._rollback(exc_val)
- async def _commit(self) -> bool:
+ async def _commit(self) -> None:
"""Commit changes made in the transaction context."""
- self._pop_savepoint()
- if self._savepoint_name:
- await self._conn._exec_command(
- sql.SQL("release savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
- )
- if self._outer_transaction:
- await self._conn._exec_command(b"commit")
-
- return False # discarded
+ await self._execute(self._commit_commands())
async def _rollback(self, exc_val: Optional[BaseException]) -> bool:
# Rollback changes made in the transaction context
f"{self._conn}: Explicit rollback from: ", exc_info=True
)
- self._pop_savepoint()
- if self._savepoint_name:
- await self._conn._exec_command(
- sql.SQL(
- "rollback to savepoint {n}; release savepoint {n}"
- ).format(n=sql.Identifier(self._savepoint_name))
- )
- if self._outer_transaction:
- await self._conn._exec_command(b"rollback")
+ await self._execute(self._rollback_commands())
if isinstance(exc_val, Rollback):
if exc_val.transaction in (self, None):
return True # Swallow the exception
return False
+
+ async def _execute(self, commands: List[Query]) -> None:
+ for command in commands:
+ await self._conn._exec_command(command)