State management simplified too.
self._notice_handlers: List[NoticeHandler] = []
self._notify_handlers: List[NotifyHandler] = []
- # stack of savepoint names managed by active Transaction() blocks
- self._savepoints: Optional[List[str]] = None
- # (None when there no active Transaction blocks; [] when there is only
- # one Transaction block, with a top-level transaction and no savepoint)
+ # Stack of savepoint names managed by current transaction blocks.
+ # the first item is "" in case the outermost Transaction must manage
+ # only a begin/commit and not a savepoint.
+ self._savepoints: List[str] = []
wself = ref(self)
# subclasses must call it holding a lock
status = self.pgconn.transaction_status
if status != TransactionStatus.IDLE:
- if self._savepoints is not None:
+ if self._savepoints:
raise e.ProgrammingError(
"couldn't change autocommit state: "
"connection.transaction() context in progress"
def commit(self) -> None:
"""Commit any pending transaction to the database."""
with self.lock:
- if self._savepoints is not None:
+ if self._savepoints:
raise e.ProgrammingError(
"Explicit commit() forbidden within a Transaction "
"context. (Transaction will be automatically committed "
def rollback(self) -> None:
"""Roll back to the start of any pending transaction."""
with self.lock:
- if self._savepoints is not None:
+ if self._savepoints:
raise e.ProgrammingError(
"Explicit rollback() forbidden within a Transaction "
"context. (Either raise Transaction.Rollback() or allow "
async with self.lock:
if self.pgconn.transaction_status == TransactionStatus.IDLE:
return
- if self._savepoints is not None:
+ if self._savepoints:
raise e.ProgrammingError(
"Explicit commit() forbidden within a Transaction "
"context. (Transaction will be automatically committed "
async with self.lock:
if self.pgconn.transaction_status == TransactionStatus.IDLE:
return
- if self._savepoints is not None:
+ if self._savepoints:
raise e.ProgrammingError(
"Explicit rollback() forbidden within a Transaction "
"context. (Either raise Transaction.Rollback() or allow "
from . import sql
from .pq import TransactionStatus
from .proto import ConnectionType
-from .errors import ProgrammingError
if TYPE_CHECKING:
from .connection import Connection, AsyncConnection # noqa: F401
force_rollback: bool = False,
):
self._conn = connection
- self._savepoint_name = savepoint_name or ""
self.force_rollback = force_rollback
- self._outer_transaction = (
- connection.pgconn.transaction_status == TransactionStatus.IDLE
- )
+ self._yolo = True
+
+ if connection.pgconn.transaction_status == TransactionStatus.IDLE:
+ # outer transaction: if no name it's only a begin, else
+ # there will be an additional savepoint
+ self._outer_transaction = True
+ assert not connection._savepoints
+ self._savepoint_name = savepoint_name or ""
+ else:
+ # inner transaction: it always has a name
+ self._outer_transaction = False
+ self._savepoint_name = (
+ savepoint_name or f"s{len(self._conn._savepoints) + 1}"
+ )
@property
def connection(self) -> ConnectionType:
return self._conn
@property
- def savepoint_name(self) -> str:
+ def savepoint_name(self) -> Optional[str]:
return self._savepoint_name
def __repr__(self) -> str:
args.append("force_rollback=True")
return f"{self.__class__.__qualname__}({', '.join(args)})"
- _out_of_order_err = ProgrammingError(
- "Out-of-order Transaction context exits. Are you "
- "calling __exit__() manually and getting it wrong?"
- )
-
def _enter_commands(self) -> List[str]:
- commands = []
+ assert self._yolo
+ self._yolo = False
+ commands = []
if self._outer_transaction:
- assert self._conn._savepoints is None, self._conn._savepoints
- self._conn._savepoints = []
+ assert not self._conn._savepoints, self._conn._savepoints
commands.append("begin")
- else:
- if self._conn._savepoints is None:
- self._conn._savepoints = []
- if not self._savepoint_name:
- self._savepoint_name = f"s{len(self._conn._savepoints) + 1}"
if self._savepoint_name:
commands.append(
.format(sql.Identifier(self._savepoint_name))
.as_string(self._conn)
)
- self._conn._savepoints.append(self._savepoint_name)
+ self._conn._savepoints.append(self._savepoint_name)
return commands
def _commit_commands(self) -> List[str]:
- commands = []
+ assert self._conn._savepoints[-1] == self._savepoint_name
+ self._conn._savepoints.pop()
- self._pop_savepoint()
+ commands = []
if self._savepoint_name and not self._outer_transaction:
commands.append(
sql.SQL("release savepoint {}")
.format(sql.Identifier(self._savepoint_name))
.as_string(self._conn)
)
+
if self._outer_transaction:
+ assert not self._conn._savepoints
commands.append("commit")
return commands
def _rollback_commands(self) -> List[str]:
- commands = []
+ assert self._conn._savepoints[-1] == self._savepoint_name
+ self._conn._savepoints.pop()
- self._pop_savepoint()
+ commands = []
if self._savepoint_name and not self._outer_transaction:
commands.append(
sql.SQL("rollback to savepoint {n}; release savepoint {n}")
.format(n=sql.Identifier(self._savepoint_name))
.as_string(self._conn)
)
+
if self._outer_transaction:
+ assert not self._conn._savepoints
commands.append("rollback")
return commands
- def _pop_savepoint(self) -> None:
- if self._savepoint_name:
- if self._conn._savepoints is None:
- raise self._out_of_order_err
- actual = self._conn._savepoints.pop()
- if actual != self._savepoint_name:
- raise self._out_of_order_err
- if self._outer_transaction:
- if self._conn._savepoints is None or self._conn._savepoints:
- raise self._out_of_order_err
- self._conn._savepoints = None
-
class Transaction(BaseTransaction["Connection"]):
def __enter__(self) -> "Transaction":
assert not inserted(conn)
+def test_interaction_dbapi_transaction(conn):
+ insert_row(conn, "foo")
+
+ with conn.transaction():
+ insert_row(conn, "bar")
+ raise Rollback
+
+ with conn.transaction():
+ insert_row(conn, "baz")
+
+ assert in_transaction(conn)
+ conn.commit()
+ assert inserted(conn) == {"foo", "baz"}
+
+
def test_prohibits_use_of_commit_rollback_autocommit(conn):
"""
Within a Transaction block, it is forbidden to touch commit, rollback,
tx.__exit__(None, None, None)
assert commands.pop() == "commit"
+ # Case 1 (with a transaction already started)
+ conn.cursor().execute("select 1")
+ assert commands.pop() == "begin"
+ tx = Transaction(conn)
+ tx.__enter__()
+ assert commands.pop() == 'savepoint "s1"'
+ assert tx.savepoint_name == "s1"
+ tx.__exit__(None, None, None)
+ assert commands.pop() == 'release savepoint "s1"'
+ assert not commands
+ conn.rollback()
+ assert commands.pop() == "rollback"
+ assert not commands
+
# Case 2
tx = Transaction(conn, savepoint_name="foo")
tx.__enter__()
assert commands.pop() == "begin"
tx = Transaction(conn)
tx.__enter__()
- assert commands.pop() == 'savepoint "s1"'
- assert tx.savepoint_name == "s1"
+ assert commands.pop() == 'savepoint "s2"'
+ assert tx.savepoint_name == "s2"
tx.__exit__(None, None, None)
- assert commands.pop() == 'release savepoint "s1"'
+ assert commands.pop() == 'release savepoint "s2"'
assert not commands
assert commands.pop() == "commit"
assert commands.pop() == "begin"
tx = Transaction(conn)
tx.__enter__()
- assert commands.pop() == 'savepoint "s1"'
- assert tx.savepoint_name == "s1"
+ assert commands.pop() == 'savepoint "s2"'
+ assert tx.savepoint_name == "s2"
tx.__exit__(*some_exc_info())
assert (
commands.pop()
- == 'rollback to savepoint "s1"; release savepoint "s1"'
+ == 'rollback to savepoint "s2"; release savepoint "s2"'
)
assert not commands
assert commands.pop() == "commit"
assert not inserted(svcconn) # Not yet committed
# Changes committed
assert inserted(svcconn) == {"outer-before", "outer-after"}
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-def test_manual_exit_without_enter_asserts(conn, name, exc_info):
- """
- When user is calling __enter__() and __exit__() manually for some reason,
- provide a helpful error message if they call __exit__() without first
- having called __enter__()
- """
- tx = Transaction(conn, name)
- with pytest.raises(ProgrammingError, match="Out-of-order"):
- tx.__exit__(*exc_info)
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-def test_manual_exit_twice_asserts(conn, name, exc_info):
- """
- When user is calling __enter__() and __exit__() manually for some reason,
- provide a helpful error message if they accidentally call __exit__() twice.
- """
- tx = Transaction(conn, name)
- tx.__enter__()
- tx.__exit__(*exc_info)
- with pytest.raises(ProgrammingError, match="Out-of-order"):
- tx.__exit__(*exc_info)
assert not await inserted(aconn)
+async def test_interaction_dbapi_transaction(aconn):
+ await insert_row(aconn, "foo")
+
+ async with aconn.transaction():
+ await insert_row(aconn, "bar")
+ raise Rollback
+
+ async with aconn.transaction():
+ await insert_row(aconn, "baz")
+
+ assert in_transaction(aconn)
+ await aconn.commit()
+ assert await inserted(aconn) == {"foo", "baz"}
+
+
async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
"""
Within a Transaction block, it is forbidden to touch commit, rollback,
await tx.__aexit__(None, None, None)
assert commands.pop() == "commit"
+ # Case 1 (with a transaction already started)
+ await (await aconn.cursor()).execute("select 1")
+ assert commands.pop() == "begin"
+ tx = AsyncTransaction(aconn)
+ await tx.__aenter__()
+ assert commands.pop() == 'savepoint "s1"'
+ assert tx.savepoint_name == "s1"
+ await tx.__aexit__(None, None, None)
+ assert commands.pop() == 'release savepoint "s1"'
+ assert not commands
+ await aconn.rollback()
+ assert commands.pop() == "rollback"
+ assert not commands
+
# Case 2
tx = AsyncTransaction(aconn, savepoint_name="foo")
await tx.__aenter__()
assert commands.pop() == "begin"
tx = AsyncTransaction(aconn)
await tx.__aenter__()
- assert commands.pop() == 'savepoint "s1"'
- assert tx.savepoint_name == "s1"
+ assert commands.pop() == 'savepoint "s2"'
+ assert tx.savepoint_name == "s2"
await tx.__aexit__(None, None, None)
- assert commands.pop() == 'release savepoint "s1"'
+ assert commands.pop() == 'release savepoint "s2"'
assert not commands
assert commands.pop() == "commit"
assert commands.pop() == "begin"
tx = AsyncTransaction(aconn)
await tx.__aenter__()
- assert commands.pop() == 'savepoint "s1"'
- assert tx.savepoint_name == "s1"
+ assert commands.pop() == 'savepoint "s2"'
+ assert tx.savepoint_name == "s2"
await tx.__aexit__(*some_exc_info())
assert (
commands.pop()
- == 'rollback to savepoint "s1"; release savepoint "s1"'
+ == 'rollback to savepoint "s2"; release savepoint "s2"'
)
assert not commands
assert commands.pop() == "commit"
assert not inserted(svcconn) # Not yet committed
# Changes committed
assert inserted(svcconn) == {"outer-before", "outer-after"}
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):
- """
- When user is calling __enter__() and __exit__() manually for some reason,
- provide a helpful error message if they call __exit__() without first
- having called __enter__()
- """
- tx = AsyncTransaction(aconn, name)
- with pytest.raises(ProgrammingError, match="Out-of-order"):
- await tx.__aexit__(*exc_info)
-
-
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-async def test_manual_exit_twice_asserts(aconn, name, exc_info):
- """
- When user is calling __enter__() and __exit__() manually for some reason,
- provide a helpful error message if they accidentally call __exit__() twice.
- """
- tx = AsyncTransaction(aconn, name)
- await tx.__aenter__()
- await tx.__aexit__(*exc_info)
- with pytest.raises(ProgrammingError, match="Out-of-order"):
- await tx.__aexit__(*exc_info)