self._conn = connection
self._savepoint_name = savepoint_name or ""
self.force_rollback = force_rollback
-
- self._outer_transaction: Optional[bool] = None
+ self._outer_transaction = (
+ connection.pgconn.transaction_status == TransactionStatus.IDLE
+ )
@property
def connection(self) -> ConnectionType:
class Transaction(BaseTransaction["Connection"]):
def __enter__(self) -> "Transaction":
with self._conn.lock:
- if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+ if self._outer_transaction:
assert self._conn._savepoints is None, self._conn._savepoints
self._conn._savepoints = []
- self._outer_transaction = True
self._conn._exec_command(b"begin")
else:
if self._conn._savepoints is None:
self._conn._savepoints = []
- self._outer_transaction = False
if not self._savepoint_name:
self._savepoint_name = (
f"s{len(self._conn._savepoints) + 1}"
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
- if self._outer_transaction is None:
- raise self._out_of_order_err
with self._conn.lock:
if not exc_val and not self.force_rollback:
return self._commit()
class AsyncTransaction(BaseTransaction["AsyncConnection"]):
async def __aenter__(self) -> "AsyncTransaction":
async with self._conn.lock:
- if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+ if self._outer_transaction:
assert self._conn._savepoints is None, self._conn._savepoints
self._conn._savepoints = []
- self._outer_transaction = True
await self._conn._exec_command(b"begin")
else:
if self._conn._savepoints is None:
self._conn._savepoints = []
- self._outer_transaction = False
if not self._savepoint_name:
self._savepoint_name = (
f"s{len(self._conn._savepoints) + 1}"
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
- if self._outer_transaction is None:
- raise self._out_of_order_err
async with self._conn.lock:
if not exc_val and not self.force_rollback:
return await self._commit()
assert inserted(svcconn) == {"outer-before", "outer-after"}
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info):
- """
- When user is calling __enter__() and __exit__() manually for some reason,
- provide a helpful error message if they call __exit__() in the wrong order
- for nested transactions.
- """
- tx1, tx2 = Transaction(conn, name), Transaction(conn)
- tx1.__enter__()
- tx2.__enter__()
- with pytest.raises(ProgrammingError, match="Out-of-order"):
- tx1.__exit__(*exc_info)
-
-
@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
@pytest.mark.parametrize("name", [None, "s1"])
def test_manual_exit_without_enter_asserts(conn, name, exc_info):
assert inserted(svcconn) == {"outer-before", "outer-after"}
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-async def test_manual_enter_and_exit_out_of_order_exit_asserts(
- aconn, name, exc_info
-):
- """
- When user is calling __enter__() and __exit__() manually for some reason,
- provide a helpful error message if they call __exit__() in the wrong order
- for nested transactions.
- """
- tx1, tx2 = AsyncTransaction(aconn, name), AsyncTransaction(aconn)
- await tx1.__aenter__()
- await tx2.__aenter__()
- with pytest.raises(ProgrammingError, match="Out-of-order"):
- await tx1.__aexit__(*exc_info)
-
-
@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
@pytest.mark.parametrize("name", [None, "s1"])
async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):