]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Run trans.close() at end of block if transaction already inactive
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Mar 2021 20:40:16 +0000 (16:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 Mar 2021 14:00:40 +0000 (10:00 -0400)
Modified the context manager used by :class:`_engine.Transaction` so that
an "already detached" warning is not emitted by the ending of the context
manager itself, if the transaction were already manually rolled back inside
the block. This applies to regular transactions, savepoint transactions,
and legacy "marker" transactions. A warning is still emitted if the
``.rollback()`` method is called explicitly more than once.

Fixes: #6155
Change-Id: Ib9f9d803bf377ec843d4a8a09da8ebef4b441665

doc/build/changelog/unreleased_14/6155.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
test/engine/test_transaction.py

diff --git a/doc/build/changelog/unreleased_14/6155.rst b/doc/build/changelog/unreleased_14/6155.rst
new file mode 100644 (file)
index 0000000..9debc17
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, engine
+    :tickets: 6155
+
+    Modified the context manager used by :class:`_engine.Transaction` so that
+    an "already detached" warning is not emitted by the ending of the context
+    manager itself, if the transaction were already manually rolled back inside
+    the block. This applies to regular transactions, savepoint transactions,
+    and legacy "marker" transactions. A warning is still emitted if the
+    ``.rollback()`` method is called explicitly more than once.
index c5fefa8d9ae9b415bb19c6fa9c798b150df1a17b..0a3ce3bdc05f1baaa892e6abd75ac841ded86b03 100644 (file)
@@ -2185,6 +2185,14 @@ class Transaction(object):
         """
         raise NotImplementedError()
 
+    @property
+    def _deactivated_from_connection(self):
+        """True if this transaction is totally deactivated from the connection
+        and therefore can no longer affect its state.
+
+        """
+        raise NotImplementedError()
+
     def _do_close(self):
         raise NotImplementedError()
 
@@ -2269,7 +2277,10 @@ class Transaction(object):
                 with util.safe_reraise():
                     self.rollback()
         else:
-            self.rollback()
+            if self._deactivated_from_connection:
+                self.close()
+            else:
+                self.rollback()
 
 
 class MarkerTransaction(Transaction):
@@ -2305,6 +2316,10 @@ class MarkerTransaction(Transaction):
             self._transaction = connection._transaction
         self._is_active = True
 
+    @property
+    def _deactivated_from_connection(self):
+        return not self.is_active
+
     @property
     def is_active(self):
         return self._is_active and self._transaction.is_active
@@ -2361,6 +2376,10 @@ class RootTransaction(Transaction):
         elif self.connection._transaction is not self:
             util.warn("transaction already deassociated from connection")
 
+    @property
+    def _deactivated_from_connection(self):
+        return self.connection._transaction is not self
+
     def _do_deactivate(self):
         # called from a MarkerTransaction to cancel this root transaction.
         # the transaction stays in place as connection._transaction, but
@@ -2484,14 +2503,18 @@ class NestedTransaction(Transaction):
         self._previous_nested = connection._nested_transaction
         connection._nested_transaction = self
 
-    def _deactivate_from_connection(self):
+    def _deactivate_from_connection(self, warn=True):
         if self.connection._nested_transaction is self:
             self.connection._nested_transaction = self._previous_nested
-        else:
+        elif warn:
             util.warn(
                 "nested transaction already deassociated from connection"
             )
 
+    @property
+    def _deactivated_from_connection(self):
+        return self.connection._nested_transaction is not self
+
     def _cancel(self):
         # called by RootTransaction when the outer transaction is
         # committed, rolled back, or closed to cancel all savepoints
@@ -2501,23 +2524,28 @@ class NestedTransaction(Transaction):
         if self._previous_nested:
             self._previous_nested._cancel()
 
-    def _close_impl(self, deactivate_from_connection):
+    def _close_impl(self, deactivate_from_connection, warn_already_deactive):
         try:
             if self.is_active and self.connection._transaction.is_active:
                 self.connection._rollback_to_savepoint_impl(self._savepoint)
         finally:
             self.is_active = False
+
             if deactivate_from_connection:
-                self._deactivate_from_connection()
+                self._deactivate_from_connection(warn=warn_already_deactive)
+
+        assert not self.is_active
+        if deactivate_from_connection:
+            assert self.connection._nested_transaction is not self
 
     def _do_deactivate(self):
-        self._close_impl(False)
+        self._close_impl(False, False)
 
     def _do_close(self):
-        self._close_impl(True)
+        self._close_impl(True, False)
 
     def _do_rollback(self):
-        self._close_impl(True)
+        self._close_impl(True, True)
 
     def _do_commit(self):
         if self.is_active:
index 07408e386eb5540ba26af85bb093e02344cc22cb..d972bcb73a529287c3c525ef697523afc97bfabc 100644 (file)
@@ -20,6 +20,8 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.assertions import expect_deprecated_20
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -98,21 +100,127 @@ class TransactionTest(fixtures.TablesTest):
         result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 0
 
-    def test_deactivated_warning_ctxmanager(self, local_connection):
+    def test_rollback_end_ctx_manager_autocommit(self, local_connection):
+        m1 = mock.Mock()
+
+        event.listen(local_connection, "rollback", m1.rollback)
+        event.listen(local_connection, "commit", m1.commit)
+
+        with local_connection.begin() as trans:
+            assert local_connection.in_transaction()
+            trans.rollback()
+            assert not local_connection.in_transaction()
+
+            # would be subject to autocommit
+            local_connection.execute(select(1))
+
+            assert not local_connection.in_transaction()
+
+    @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+    def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block):
+        m1 = mock.Mock()
+
+        event.listen(local_connection, "rollback", m1.rollback)
+        event.listen(local_connection, "commit", m1.commit)
+
+        with expect_raises_message(Exception, "test"):
+            with local_connection.begin() as trans:
+                if roll_back_in_block:
+                    trans.rollback()
+
+                if 1 == 1:
+                    raise Exception("test")
+
+        assert not trans.is_active
+        assert not local_connection.in_transaction()
+        assert trans._deactivated_from_connection
+
+        eq_(m1.mock_calls, [mock.call.rollback(local_connection)])
+
+    @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+    def test_ctxmanager_rolls_back_legacy_marker(
+        self, local_connection, roll_back_in_block
+    ):
+        m1 = mock.Mock()
+
+        event.listen(local_connection, "rollback", m1.rollback)
+        event.listen(local_connection, "commit", m1.commit)
+
+        with expect_deprecated_20(
+            r"Calling .begin\(\) when a transaction is already begun"
+        ):
+            with local_connection.begin() as trans:
+                with expect_raises_message(Exception, "test"):
+                    with local_connection.begin() as marker_trans:
+                        if roll_back_in_block:
+                            marker_trans.rollback()
+                        if 1 == 1:
+                            raise Exception("test")
+
+                    assert not marker_trans.is_active
+                    assert marker_trans._deactivated_from_connection
+
+                assert not trans._deactivated_from_connection
+                assert not trans.is_active
+                assert not local_connection.in_transaction()
+
+        eq_(m1.mock_calls, [mock.call.rollback(local_connection)])
+
+    @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+    @testing.requires.savepoints
+    def test_ctxmanager_rolls_back_savepoint(
+        self, local_connection, roll_back_in_block
+    ):
+        m1 = mock.Mock()
+
+        event.listen(
+            local_connection, "rollback_savepoint", m1.rollback_savepoint
+        )
+        event.listen(local_connection, "rollback", m1.rollback)
+        event.listen(local_connection, "commit", m1.commit)
+
+        with local_connection.begin() as trans:
+            with expect_raises_message(Exception, "test"):
+                with local_connection.begin_nested() as nested_trans:
+                    if roll_back_in_block:
+                        nested_trans.rollback()
+                    if 1 == 1:
+                        raise Exception("test")
+
+                assert not nested_trans.is_active
+                assert nested_trans._deactivated_from_connection
+
+            assert trans.is_active
+            assert local_connection.in_transaction()
+            assert not trans._deactivated_from_connection
+
+        eq_(
+            m1.mock_calls,
+            [
+                mock.call.rollback_savepoint(
+                    local_connection, mock.ANY, mock.ANY
+                ),
+                mock.call.commit(local_connection),
+            ],
+        )
+
+    def test_deactivated_warning_straight(self, local_connection):
         with expect_warnings(
             "transaction already deassociated from connection"
         ):
-            with local_connection.begin() as trans:
-                trans.rollback()
+            trans = local_connection.begin()
+            trans.rollback()
+            trans.rollback()
 
     @testing.requires.savepoints
-    def test_deactivated_savepoint_warning_ctxmanager(self, local_connection):
+    def test_deactivated_savepoint_warning_straight(self, local_connection):
         with expect_warnings(
             "nested transaction already deassociated from connection"
         ):
             with local_connection.begin():
-                with local_connection.begin_nested() as savepoint:
-                    savepoint.rollback()
+                savepoint = local_connection.begin_nested()
+                savepoint.rollback()
+                savepoint.rollback()
 
     def test_commit_fails_flat(self, local_connection):
         connection = local_connection
@@ -1335,6 +1443,11 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
             test_needs_acid=True,
         )
 
+    @testing.fixture
+    def local_connection(self):
+        with testing.db.connect() as conn:
+            yield conn
+
     def test_autobegin_rollback(self):
         users = self.tables.users
         with testing.db.connect() as conn:
@@ -1518,6 +1631,44 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
         with testing.db.begin() as conn:
             assert conn.in_transaction()
             conn.rollback()
+            assert not conn.in_transaction()
+
+    def test_rollback_end_ctx_manager_autobegin(self, local_connection):
+        m1 = mock.Mock()
+
+        event.listen(local_connection, "rollback", m1.rollback)
+        event.listen(local_connection, "commit", m1.commit)
+
+        with local_connection.begin() as trans:
+            assert local_connection.in_transaction()
+            trans.rollback()
+            assert not local_connection.in_transaction()
+
+            # autobegin
+            local_connection.execute(select(1))
+
+            assert local_connection.in_transaction()
+
+    @testing.combinations((True,), (False,), argnames="roll_back_in_block")
+    def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block):
+        m1 = mock.Mock()
+
+        event.listen(local_connection, "rollback", m1.rollback)
+        event.listen(local_connection, "commit", m1.commit)
+
+        with expect_raises_message(Exception, "test"):
+            with local_connection.begin() as trans:
+                if roll_back_in_block:
+                    trans.rollback()
+
+                if 1 == 1:
+                    raise Exception("test")
+
+        assert not trans.is_active
+        assert not local_connection.in_transaction()
+        assert trans._deactivated_from_connection
+
+        eq_(m1.mock_calls, [mock.call.rollback(local_connection)])
 
     def test_explicit_begin(self):
         users = self.tables.users