import logging
from types import TracebackType
-from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING
from . import pq
from . import sql
from . import errors as e
from .pq import TransactionStatus, ConnStatus
from .abc import ConnectionType, PQGen
-from .pq.abc import PGresult
if TYPE_CHECKING:
from typing import Any
sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
- def _enter_gen(self) -> PQGen[PGresult]:
+ def _enter_gen(self) -> PQGen[None]:
if self._entered:
raise TypeError("transaction blocks can be used only once")
self._entered = True
self._push_savepoint()
- commands = self._get_enter_commands()
- return self._conn._exec_command(b"; ".join(commands))
+ for command in self._get_enter_commands():
+ yield from self._conn._exec_command(command)
def _exit_gen(
self,
logger.warning("error ignored in rollback of %s: %s", self, exc2)
return False
- def _commit_gen(self) -> PQGen[PGresult]:
+ def _commit_gen(self) -> PQGen[None]:
ex = self._pop_savepoint("commit")
self._exited = True
if ex:
raise ex
- commands = self._get_commit_commands()
- return self._conn._exec_command(b"; ".join(commands))
+ for command in self._get_commit_commands():
+ yield from self._conn._exec_command(command)
def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
if isinstance(exc_val, Rollback):
if ex:
raise ex
- commands = self._get_rollback_commands()
- yield from self._conn._exec_command(b"; ".join(commands))
+ for command in self._get_rollback_commands():
+ yield from self._conn._exec_command(command)
if isinstance(exc_val, Rollback):
if not exc_val.transaction or exc_val.transaction is self:
return False
- def _get_enter_commands(self) -> List[bytes]:
- commands = []
+ def _get_enter_commands(self) -> Iterator[bytes]:
if self._outer_transaction:
- commands.append(self._conn._get_tx_start_command())
+ yield self._conn._get_tx_start_command()
if self._savepoint_name:
- commands.append(
+ yield (
sql.SQL("SAVEPOINT {}")
.format(sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
)
- return commands
- def _get_commit_commands(self) -> List[bytes]:
- commands = []
+ def _get_commit_commands(self) -> Iterator[bytes]:
if self._savepoint_name and not self._outer_transaction:
- commands.append(
+ yield (
sql.SQL("RELEASE {}")
.format(sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
if self._outer_transaction:
assert not self._conn._num_transactions
- commands.append(b"COMMIT")
+ yield b"COMMIT"
- return commands
-
- def _get_rollback_commands(self) -> List[bytes]:
- commands = []
+ def _get_rollback_commands(self) -> Iterator[bytes]:
if self._savepoint_name and not self._outer_transaction:
- commands.append(
- sql.SQL("ROLLBACK TO {n}; RELEASE {n}")
+ yield (
+ sql.SQL("ROLLBACK TO {n}")
+ .format(n=sql.Identifier(self._savepoint_name))
+ .as_bytes(self._conn)
+ )
+ yield (
+ sql.SQL("RELEASE {n}")
.format(n=sql.Identifier(self._savepoint_name))
.as_bytes(self._conn)
)
if self._outer_transaction:
assert not self._conn._num_transactions
- commands.append(b"ROLLBACK")
+ yield b"ROLLBACK"
# Also clear the prepared statements cache.
if self._conn._prepared.clear():
- for cmd in self._conn._prepared.get_maintenance_commands():
- commands.append(cmd)
-
- return commands
+ yield from self._conn._prepared.get_maintenance_commands()
def _push_savepoint(self) -> None:
"""
# Case 2
with conn.transaction(savepoint_name="foo") as tx:
- assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
assert tx.savepoint_name == "foo"
assert commands.popall() == ["COMMIT"]
# Case 2
with pytest.raises(ExpectedException):
with conn.transaction(savepoint_name="foo") as tx:
- assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
assert tx.savepoint_name == "foo"
raise ExpectedException
assert commands.popall() == ["ROLLBACK"]
assert commands.popall() == ['SAVEPOINT "bar"']
assert tx.savepoint_name == "bar"
raise ExpectedException
- assert commands.popall() == ['ROLLBACK TO "bar"; RELEASE "bar"']
+ assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
assert commands.popall() == ["COMMIT"]
# Case 3 (with savepoint name auto-generated)
assert commands.popall() == ['SAVEPOINT "_pg3_2"']
assert tx.savepoint_name == "_pg3_2"
raise ExpectedException
- assert commands.popall() == ['ROLLBACK TO "_pg3_2"; RELEASE "_pg3_2"']
+ assert commands.popall() == [
+ 'ROLLBACK TO "_pg3_2"',
+ 'RELEASE "_pg3_2"',
+ ]
assert commands.popall() == ["COMMIT"]
# Case 2
async with aconn.transaction(savepoint_name="foo") as tx:
- assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
assert tx.savepoint_name == "foo"
assert commands.popall() == ["COMMIT"]
# Case 2
with pytest.raises(ExpectedException):
async with aconn.transaction(savepoint_name="foo") as tx:
- assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+ assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
assert tx.savepoint_name == "foo"
raise ExpectedException
assert commands.popall() == ["ROLLBACK"]
assert commands.popall() == ['SAVEPOINT "bar"']
assert tx.savepoint_name == "bar"
raise ExpectedException
- assert commands.popall() == ['ROLLBACK TO "bar"; RELEASE "bar"']
+ assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
assert commands.popall() == ["COMMIT"]
# Case 3 (with savepoint name auto-generated)
assert commands.popall() == ['SAVEPOINT "_pg3_2"']
assert tx.savepoint_name == "_pg3_2"
raise ExpectedException
- assert commands.popall() == ['ROLLBACK TO "_pg3_2"; RELEASE "_pg3_2"']
+ assert commands.popall() == [
+ 'ROLLBACK TO "_pg3_2"',
+ 'RELEASE "_pg3_2"',
+ ]
assert commands.popall() == ["COMMIT"]