From bd8b269a34153c29c7f05e4acacccc6b07b47fb5 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 29 Mar 2021 16:40:16 -0400 Subject: [PATCH] Run trans.close() at end of block if transaction already inactive Modified the context manager used by :class:`_engine.Transaction` so that an "already detached" warning is not emitted by the ending of the context manager itself, if the transaction were already manually rolled back inside the block. This applies to regular transactions, savepoint transactions, and legacy "marker" transactions. A warning is still emitted if the ``.rollback()`` method is called explicitly more than once. Fixes: #6155 Change-Id: Ib9f9d803bf377ec843d4a8a09da8ebef4b441665 --- doc/build/changelog/unreleased_14/6155.rst | 10 ++ lib/sqlalchemy/engine/base.py | 44 +++++- test/engine/test_transaction.py | 163 ++++++++++++++++++++- 3 files changed, 203 insertions(+), 14 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6155.rst diff --git a/doc/build/changelog/unreleased_14/6155.rst b/doc/build/changelog/unreleased_14/6155.rst new file mode 100644 index 0000000000..9debc17a81 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6155.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: usecase, engine + :tickets: 6155 + + Modified the context manager used by :class:`_engine.Transaction` so that + an "already detached" warning is not emitted by the ending of the context + manager itself, if the transaction were already manually rolled back inside + the block. This applies to regular transactions, savepoint transactions, + and legacy "marker" transactions. A warning is still emitted if the + ``.rollback()`` method is called explicitly more than once. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index c5fefa8d9a..0a3ce3bdc0 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2185,6 +2185,14 @@ class Transaction(object): """ raise NotImplementedError() + @property + def _deactivated_from_connection(self): + """True if this transaction is totally deactivated from the connection + and therefore can no longer affect its state. + + """ + raise NotImplementedError() + def _do_close(self): raise NotImplementedError() @@ -2269,7 +2277,10 @@ class Transaction(object): with util.safe_reraise(): self.rollback() else: - self.rollback() + if self._deactivated_from_connection: + self.close() + else: + self.rollback() class MarkerTransaction(Transaction): @@ -2305,6 +2316,10 @@ class MarkerTransaction(Transaction): self._transaction = connection._transaction self._is_active = True + @property + def _deactivated_from_connection(self): + return not self.is_active + @property def is_active(self): return self._is_active and self._transaction.is_active @@ -2361,6 +2376,10 @@ class RootTransaction(Transaction): elif self.connection._transaction is not self: util.warn("transaction already deassociated from connection") + @property + def _deactivated_from_connection(self): + return self.connection._transaction is not self + def _do_deactivate(self): # called from a MarkerTransaction to cancel this root transaction. # the transaction stays in place as connection._transaction, but @@ -2484,14 +2503,18 @@ class NestedTransaction(Transaction): self._previous_nested = connection._nested_transaction connection._nested_transaction = self - def _deactivate_from_connection(self): + def _deactivate_from_connection(self, warn=True): if self.connection._nested_transaction is self: self.connection._nested_transaction = self._previous_nested - else: + elif warn: util.warn( "nested transaction already deassociated from connection" ) + @property + def _deactivated_from_connection(self): + return self.connection._nested_transaction is not self + def _cancel(self): # called by RootTransaction when the outer transaction is # committed, rolled back, or closed to cancel all savepoints @@ -2501,23 +2524,28 @@ class NestedTransaction(Transaction): if self._previous_nested: self._previous_nested._cancel() - def _close_impl(self, deactivate_from_connection): + def _close_impl(self, deactivate_from_connection, warn_already_deactive): try: if self.is_active and self.connection._transaction.is_active: self.connection._rollback_to_savepoint_impl(self._savepoint) finally: self.is_active = False + if deactivate_from_connection: - self._deactivate_from_connection() + self._deactivate_from_connection(warn=warn_already_deactive) + + assert not self.is_active + if deactivate_from_connection: + assert self.connection._nested_transaction is not self def _do_deactivate(self): - self._close_impl(False) + self._close_impl(False, False) def _do_close(self): - self._close_impl(True) + self._close_impl(True, False) def _do_rollback(self): - self._close_impl(True) + self._close_impl(True, True) def _do_commit(self): if self.is_active: diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 07408e386e..d972bcb73a 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -20,6 +20,8 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ +from sqlalchemy.testing.assertions import expect_deprecated_20 +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -98,21 +100,127 @@ class TransactionTest(fixtures.TablesTest): result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 0 - def test_deactivated_warning_ctxmanager(self, local_connection): + def test_rollback_end_ctx_manager_autocommit(self, local_connection): + m1 = mock.Mock() + + event.listen(local_connection, "rollback", m1.rollback) + event.listen(local_connection, "commit", m1.commit) + + with local_connection.begin() as trans: + assert local_connection.in_transaction() + trans.rollback() + assert not local_connection.in_transaction() + + # would be subject to autocommit + local_connection.execute(select(1)) + + assert not local_connection.in_transaction() + + @testing.combinations((True,), (False,), argnames="roll_back_in_block") + def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block): + m1 = mock.Mock() + + event.listen(local_connection, "rollback", m1.rollback) + event.listen(local_connection, "commit", m1.commit) + + with expect_raises_message(Exception, "test"): + with local_connection.begin() as trans: + if roll_back_in_block: + trans.rollback() + + if 1 == 1: + raise Exception("test") + + assert not trans.is_active + assert not local_connection.in_transaction() + assert trans._deactivated_from_connection + + eq_(m1.mock_calls, [mock.call.rollback(local_connection)]) + + @testing.combinations((True,), (False,), argnames="roll_back_in_block") + def test_ctxmanager_rolls_back_legacy_marker( + self, local_connection, roll_back_in_block + ): + m1 = mock.Mock() + + event.listen(local_connection, "rollback", m1.rollback) + event.listen(local_connection, "commit", m1.commit) + + with expect_deprecated_20( + r"Calling .begin\(\) when a transaction is already begun" + ): + with local_connection.begin() as trans: + with expect_raises_message(Exception, "test"): + with local_connection.begin() as marker_trans: + if roll_back_in_block: + marker_trans.rollback() + if 1 == 1: + raise Exception("test") + + assert not marker_trans.is_active + assert marker_trans._deactivated_from_connection + + assert not trans._deactivated_from_connection + assert not trans.is_active + assert not local_connection.in_transaction() + + eq_(m1.mock_calls, [mock.call.rollback(local_connection)]) + + @testing.combinations((True,), (False,), argnames="roll_back_in_block") + @testing.requires.savepoints + def test_ctxmanager_rolls_back_savepoint( + self, local_connection, roll_back_in_block + ): + m1 = mock.Mock() + + event.listen( + local_connection, "rollback_savepoint", m1.rollback_savepoint + ) + event.listen(local_connection, "rollback", m1.rollback) + event.listen(local_connection, "commit", m1.commit) + + with local_connection.begin() as trans: + with expect_raises_message(Exception, "test"): + with local_connection.begin_nested() as nested_trans: + if roll_back_in_block: + nested_trans.rollback() + if 1 == 1: + raise Exception("test") + + assert not nested_trans.is_active + assert nested_trans._deactivated_from_connection + + assert trans.is_active + assert local_connection.in_transaction() + assert not trans._deactivated_from_connection + + eq_( + m1.mock_calls, + [ + mock.call.rollback_savepoint( + local_connection, mock.ANY, mock.ANY + ), + mock.call.commit(local_connection), + ], + ) + + def test_deactivated_warning_straight(self, local_connection): with expect_warnings( "transaction already deassociated from connection" ): - with local_connection.begin() as trans: - trans.rollback() + trans = local_connection.begin() + trans.rollback() + trans.rollback() @testing.requires.savepoints - def test_deactivated_savepoint_warning_ctxmanager(self, local_connection): + def test_deactivated_savepoint_warning_straight(self, local_connection): with expect_warnings( "nested transaction already deassociated from connection" ): with local_connection.begin(): - with local_connection.begin_nested() as savepoint: - savepoint.rollback() + savepoint = local_connection.begin_nested() + savepoint.rollback() + savepoint.rollback() def test_commit_fails_flat(self, local_connection): connection = local_connection @@ -1335,6 +1443,11 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest): test_needs_acid=True, ) + @testing.fixture + def local_connection(self): + with testing.db.connect() as conn: + yield conn + def test_autobegin_rollback(self): users = self.tables.users with testing.db.connect() as conn: @@ -1518,6 +1631,44 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest): with testing.db.begin() as conn: assert conn.in_transaction() conn.rollback() + assert not conn.in_transaction() + + def test_rollback_end_ctx_manager_autobegin(self, local_connection): + m1 = mock.Mock() + + event.listen(local_connection, "rollback", m1.rollback) + event.listen(local_connection, "commit", m1.commit) + + with local_connection.begin() as trans: + assert local_connection.in_transaction() + trans.rollback() + assert not local_connection.in_transaction() + + # autobegin + local_connection.execute(select(1)) + + assert local_connection.in_transaction() + + @testing.combinations((True,), (False,), argnames="roll_back_in_block") + def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block): + m1 = mock.Mock() + + event.listen(local_connection, "rollback", m1.rollback) + event.listen(local_connection, "commit", m1.commit) + + with expect_raises_message(Exception, "test"): + with local_connection.begin() as trans: + if roll_back_in_block: + trans.rollback() + + if 1 == 1: + raise Exception("test") + + assert not trans.is_active + assert not local_connection.in_transaction() + assert trans._deactivated_from_connection + + eq_(m1.mock_calls, [mock.call.rollback(local_connection)]) def test_explicit_begin(self): users = self.tables.users -- 2.47.2