"""
raise NotImplementedError()
+ @property
+ def _deactivated_from_connection(self):
+ """True if this transaction is totally deactivated from the connection
+ and therefore can no longer affect its state.
+
+ """
+ raise NotImplementedError()
+
def _do_close(self):
raise NotImplementedError()
with util.safe_reraise():
self.rollback()
else:
- self.rollback()
+ if self._deactivated_from_connection:
+ self.close()
+ else:
+ self.rollback()
class MarkerTransaction(Transaction):
self._transaction = connection._transaction
self._is_active = True
+ @property
+ def _deactivated_from_connection(self):
+ return not self.is_active
+
@property
def is_active(self):
return self._is_active and self._transaction.is_active
elif self.connection._transaction is not self:
util.warn("transaction already deassociated from connection")
+ @property
+ def _deactivated_from_connection(self):
+ return self.connection._transaction is not self
+
def _do_deactivate(self):
# called from a MarkerTransaction to cancel this root transaction.
# the transaction stays in place as connection._transaction, but
self._previous_nested = connection._nested_transaction
connection._nested_transaction = self
- def _deactivate_from_connection(self):
+ def _deactivate_from_connection(self, warn=True):
if self.connection._nested_transaction is self:
self.connection._nested_transaction = self._previous_nested
- else:
+ elif warn:
util.warn(
"nested transaction already deassociated from connection"
)
+ @property
+ def _deactivated_from_connection(self):
+ return self.connection._nested_transaction is not self
+
def _cancel(self):
# called by RootTransaction when the outer transaction is
# committed, rolled back, or closed to cancel all savepoints
if self._previous_nested:
self._previous_nested._cancel()
- def _close_impl(self, deactivate_from_connection):
+ def _close_impl(self, deactivate_from_connection, warn_already_deactive):
try:
if self.is_active and self.connection._transaction.is_active:
self.connection._rollback_to_savepoint_impl(self._savepoint)
finally:
self.is_active = False
+
if deactivate_from_connection:
- self._deactivate_from_connection()
+ self._deactivate_from_connection(warn=warn_already_deactive)
+
+ assert not self.is_active
+ if deactivate_from_connection:
+ assert self.connection._nested_transaction is not self
def _do_deactivate(self):
- self._close_impl(False)
+ self._close_impl(False, False)
def _do_close(self):
- self._close_impl(True)
+ self._close_impl(True, False)
def _do_rollback(self):
- self._close_impl(True)
+ self._close_impl(True, True)
def _do_commit(self):
if self.is_active:
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
+from sqlalchemy.testing.assertions import expect_deprecated_20
+from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
result = connection.exec_driver_sql("select * from users")
assert len(result.fetchall()) == 0
- def test_deactivated_warning_ctxmanager(self, local_connection):
+ def test_rollback_end_ctx_manager_autocommit(self, local_connection):
+ m1 = mock.Mock()
+
+ event.listen(local_connection, "rollback", m1.rollback)
+ event.listen(local_connection, "commit", m1.commit)
+
+ with local_connection.begin() as trans:
+ assert local_connection.in_transaction()
+ trans.rollback()
+ assert not local_connection.in_transaction()
+
+ # would be subject to autocommit
+ local_connection.execute(select(1))
+
+ assert not local_connection.in_transaction()
+
+ @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+ def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block):
+ m1 = mock.Mock()
+
+ event.listen(local_connection, "rollback", m1.rollback)
+ event.listen(local_connection, "commit", m1.commit)
+
+ with expect_raises_message(Exception, "test"):
+ with local_connection.begin() as trans:
+ if roll_back_in_block:
+ trans.rollback()
+
+ if 1 == 1:
+ raise Exception("test")
+
+ assert not trans.is_active
+ assert not local_connection.in_transaction()
+ assert trans._deactivated_from_connection
+
+ eq_(m1.mock_calls, [mock.call.rollback(local_connection)])
+
+ @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+ def test_ctxmanager_rolls_back_legacy_marker(
+ self, local_connection, roll_back_in_block
+ ):
+ m1 = mock.Mock()
+
+ event.listen(local_connection, "rollback", m1.rollback)
+ event.listen(local_connection, "commit", m1.commit)
+
+ with expect_deprecated_20(
+ r"Calling .begin\(\) when a transaction is already begun"
+ ):
+ with local_connection.begin() as trans:
+ with expect_raises_message(Exception, "test"):
+ with local_connection.begin() as marker_trans:
+ if roll_back_in_block:
+ marker_trans.rollback()
+ if 1 == 1:
+ raise Exception("test")
+
+ assert not marker_trans.is_active
+ assert marker_trans._deactivated_from_connection
+
+ assert not trans._deactivated_from_connection
+ assert not trans.is_active
+ assert not local_connection.in_transaction()
+
+ eq_(m1.mock_calls, [mock.call.rollback(local_connection)])
+
+ @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+ @testing.requires.savepoints
+ def test_ctxmanager_rolls_back_savepoint(
+ self, local_connection, roll_back_in_block
+ ):
+ m1 = mock.Mock()
+
+ event.listen(
+ local_connection, "rollback_savepoint", m1.rollback_savepoint
+ )
+ event.listen(local_connection, "rollback", m1.rollback)
+ event.listen(local_connection, "commit", m1.commit)
+
+ with local_connection.begin() as trans:
+ with expect_raises_message(Exception, "test"):
+ with local_connection.begin_nested() as nested_trans:
+ if roll_back_in_block:
+ nested_trans.rollback()
+ if 1 == 1:
+ raise Exception("test")
+
+ assert not nested_trans.is_active
+ assert nested_trans._deactivated_from_connection
+
+ assert trans.is_active
+ assert local_connection.in_transaction()
+ assert not trans._deactivated_from_connection
+
+ eq_(
+ m1.mock_calls,
+ [
+ mock.call.rollback_savepoint(
+ local_connection, mock.ANY, mock.ANY
+ ),
+ mock.call.commit(local_connection),
+ ],
+ )
+
+ def test_deactivated_warning_straight(self, local_connection):
with expect_warnings(
"transaction already deassociated from connection"
):
- with local_connection.begin() as trans:
- trans.rollback()
+ trans = local_connection.begin()
+ trans.rollback()
+ trans.rollback()
@testing.requires.savepoints
- def test_deactivated_savepoint_warning_ctxmanager(self, local_connection):
+ def test_deactivated_savepoint_warning_straight(self, local_connection):
with expect_warnings(
"nested transaction already deassociated from connection"
):
with local_connection.begin():
- with local_connection.begin_nested() as savepoint:
- savepoint.rollback()
+ savepoint = local_connection.begin_nested()
+ savepoint.rollback()
+ savepoint.rollback()
def test_commit_fails_flat(self, local_connection):
connection = local_connection
test_needs_acid=True,
)
+ @testing.fixture
+ def local_connection(self):
+ with testing.db.connect() as conn:
+ yield conn
+
def test_autobegin_rollback(self):
users = self.tables.users
with testing.db.connect() as conn:
with testing.db.begin() as conn:
assert conn.in_transaction()
conn.rollback()
+ assert not conn.in_transaction()
+
+ def test_rollback_end_ctx_manager_autobegin(self, local_connection):
+ m1 = mock.Mock()
+
+ event.listen(local_connection, "rollback", m1.rollback)
+ event.listen(local_connection, "commit", m1.commit)
+
+ with local_connection.begin() as trans:
+ assert local_connection.in_transaction()
+ trans.rollback()
+ assert not local_connection.in_transaction()
+
+ # autobegin
+ local_connection.execute(select(1))
+
+ assert local_connection.in_transaction()
+
+ @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+ def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block):
+ m1 = mock.Mock()
+
+ event.listen(local_connection, "rollback", m1.rollback)
+ event.listen(local_connection, "commit", m1.commit)
+
+ with expect_raises_message(Exception, "test"):
+ with local_connection.begin() as trans:
+ if roll_back_in_block:
+ trans.rollback()
+
+ if 1 == 1:
+ raise Exception("test")
+
+ assert not trans.is_active
+ assert not local_connection.in_transaction()
+ assert trans._deactivated_from_connection
+
+ eq_(m1.mock_calls, [mock.call.rollback(local_connection)])
def test_explicit_begin(self):
users = self.tables.users