from . import sql
from .pq import TransactionStatus
-from .proto import ConnectionType, Query
+from .proto import ConnectionType
from .errors import ProgrammingError
if TYPE_CHECKING:
"calling __exit__() manually and getting it wrong?"
)
- def _enter_commands(self) -> List[Query]:
- commands: List[Query] = []
+ def _enter_commands(self) -> List[str]:
+ commands = []
if self._outer_transaction:
assert self._conn._savepoints is None, self._conn._savepoints
self._conn._savepoints = []
- commands.append(b"begin")
+ commands.append("begin")
else:
if self._conn._savepoints is None:
self._conn._savepoints = []
if self._savepoint_name:
commands.append(
- sql.SQL("savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
+ sql.SQL("savepoint {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_string(self._conn)
)
self._conn._savepoints.append(self._savepoint_name)
return commands
- def _commit_commands(self) -> List[Query]:
- commands: List[Query] = []
+ def _commit_commands(self) -> List[str]:
+ commands = []
self._pop_savepoint()
- if self._savepoint_name:
+ if self._savepoint_name and not self._outer_transaction:
commands.append(
- sql.SQL("release savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
+ sql.SQL("release savepoint {}")
+ .format(sql.Identifier(self._savepoint_name))
+ .as_string(self._conn)
)
if self._outer_transaction:
- commands.append(b"commit")
+ commands.append("commit")
return commands
- def _rollback_commands(self) -> List[Query]:
- commands: List[Query] = []
+ def _rollback_commands(self) -> List[str]:
+ commands = []
self._pop_savepoint()
- if self._savepoint_name:
+ if self._savepoint_name and not self._outer_transaction:
commands.append(
- sql.SQL(
- "rollback to savepoint {n}; release savepoint {n}"
- ).format(n=sql.Identifier(self._savepoint_name))
+ sql.SQL("rollback to savepoint {n}; release savepoint {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_string(self._conn)
)
if self._outer_transaction:
- commands.append(b"rollback")
+ commands.append("rollback")
return commands
return False
- def _execute(self, commands: List[Query]) -> None:
- for command in commands:
- self._conn._exec_command(command)
+ def _execute(self, commands: List[str]) -> None:
+ self._conn._exec_command("; ".join(commands))
class AsyncTransaction(BaseTransaction["AsyncConnection"]):
return False
- async def _execute(self, commands: List[Query]) -> None:
- for command in commands:
- await self._conn._exec_command(command)
+ async def _execute(self, commands: List[str]) -> None:
+ await self._conn._exec_command("; ".join(commands))
# Case 2
tx = Transaction(conn, savepoint_name="foo")
tx.__enter__()
- assert commands.pop() == "begin"
- assert commands.pop() == 'savepoint "foo"'
+ assert commands.pop() == 'begin; savepoint "foo"'
assert tx.savepoint_name == "foo"
tx.__exit__(None, None, None)
- assert commands.pop() == 'release savepoint "foo"'
assert commands.pop() == "commit"
# Case 3 (with savepoint name provided)
assert tx.savepoint_name == "bar"
tx.__exit__(None, None, None)
assert commands.pop() == 'release savepoint "bar"'
+ assert not commands
assert commands.pop() == "commit"
# Case 3 (with savepoint name auto-generated)
assert tx.savepoint_name == "s1"
tx.__exit__(None, None, None)
assert commands.pop() == 'release savepoint "s1"'
+ assert not commands
assert commands.pop() == "commit"
assert not commands
# Case 2
tx = Transaction(conn, savepoint_name="foo")
tx.__enter__()
- assert commands.pop() == "begin"
- assert commands.pop() == 'savepoint "foo"'
+ assert commands.pop() == 'begin; savepoint "foo"'
assert tx.savepoint_name == "foo"
tx.__exit__(*some_exc_info())
- assert (
- commands.pop()
- == 'rollback to savepoint "foo"; release savepoint "foo"'
- )
assert commands.pop() == "rollback"
# Case 3 (with savepoint name provided)
commands.pop()
== 'rollback to savepoint "bar"; release savepoint "bar"'
)
+ assert not commands
assert commands.pop() == "commit"
# Case 3 (with savepoint name auto-generated)
commands.pop()
== 'rollback to savepoint "s1"; release savepoint "s1"'
)
+ assert not commands
assert commands.pop() == "commit"
assert not commands
# Case 2
tx = AsyncTransaction(aconn, savepoint_name="foo")
await tx.__aenter__()
- assert commands.pop() == "begin"
- assert commands.pop() == 'savepoint "foo"'
+ assert commands.pop() == 'begin; savepoint "foo"'
assert tx.savepoint_name == "foo"
await tx.__aexit__(None, None, None)
- assert commands.pop() == 'release savepoint "foo"'
assert commands.pop() == "commit"
# Case 3 (with savepoint name provided)
assert tx.savepoint_name == "bar"
await tx.__aexit__(None, None, None)
assert commands.pop() == 'release savepoint "bar"'
+ assert not commands
assert commands.pop() == "commit"
# Case 3 (with savepoint name auto-generated)
assert tx.savepoint_name == "s1"
await tx.__aexit__(None, None, None)
assert commands.pop() == 'release savepoint "s1"'
+ assert not commands
assert commands.pop() == "commit"
assert not commands
# Case 2
tx = AsyncTransaction(aconn, savepoint_name="foo")
await tx.__aenter__()
- assert commands.pop() == "begin"
- assert commands.pop() == 'savepoint "foo"'
+ assert commands.pop() == 'begin; savepoint "foo"'
assert tx.savepoint_name == "foo"
await tx.__aexit__(*some_exc_info())
- assert (
- commands.pop()
- == 'rollback to savepoint "foo"; release savepoint "foo"'
- )
assert commands.pop() == "rollback"
# Case 3 (with savepoint name provided)
commands.pop()
== 'rollback to savepoint "bar"; release savepoint "bar"'
)
+ assert not commands
assert commands.pop() == "commit"
# Case 3 (with savepoint name auto-generated)
commands.pop()
== 'rollback to savepoint "s1"; release savepoint "s1"'
)
+ assert not commands
assert commands.pop() == "commit"
assert not commands