savepoint_name: Optional[str] = None,
force_rollback: bool = False,
) -> Iterator[Transaction]:
- tx = Transaction(self, savepoint_name, force_rollback)
- with tx:
+ with Transaction(self, savepoint_name or "", force_rollback) as tx:
yield tx
@classmethod
def __init__(
self,
connection: "Connection",
- savepoint_name: Optional[str] = None,
+ savepoint_name: str = "",
force_rollback: bool = False,
):
self._conn = connection
- self._savepoint_name: Optional[str] = None
- if savepoint_name is not None:
- if not savepoint_name:
- raise ValueError("savepoint_name must be a non-empty string")
- self._savepoint_name = savepoint_name
+ self._savepoint_name = savepoint_name
self.force_rollback = force_rollback
self._outer_transaction: Optional[bool] = None
return self._conn
@property
- def savepoint_name(self) -> Optional[str]:
+ def savepoint_name(self) -> str:
return self._savepoint_name
def __enter__(self) -> "Transaction":
if self._conn._savepoints is None:
self._conn._savepoints = []
self._outer_transaction = False
- if self._savepoint_name is None:
+ if not self._savepoint_name:
self._savepoint_name = (
f"s{len(self._conn._savepoints) + 1}"
)
- if self._savepoint_name is not None:
+ if self._savepoint_name:
self._conn._exec_command(
sql.SQL("savepoint {}").format(
sql.Identifier(self._savepoint_name)
def __repr__(self) -> str:
args = [f"connection={self.connection}"]
- if self.savepoint_name is not None:
+ if not self.savepoint_name:
args.append(f"savepoint_name={self.savepoint_name!r}")
if self.force_rollback:
args.append("force_rollback=True")
assert_rows(svcconn, {"one", "two", "three"})
-def test_named_savepoint_empty_string_invalid(conn):
- """
- Raise validate savepoint_name up-front (rather than later constructing an
- invalid SQL command and having that fail with an OperationalError).
- """
- with pytest.raises(ValueError):
- with conn.transaction(savepoint_name=""):
- pass
-
-
def test_named_savepoint_escapes_savepoint_name(conn):
with conn.transaction("s-1"):
pass
tx = Transaction(conn)
with assert_commands_issued(conn, "begin"):
tx.__enter__()
- assert tx.savepoint_name is None
+ assert not tx.savepoint_name
with assert_commands_issued(conn, "commit"):
tx.__exit__(None, None, None)
tx = Transaction(conn)
with assert_commands_issued(conn, "begin"):
tx.__enter__()
- assert tx.savepoint_name is None
+ assert not tx.savepoint_name
with assert_commands_issued(conn, "rollback"):
tx.__exit__(*some_exc_info())