]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
unify transactional context managers
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 May 2021 22:31:03 +0000 (18:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 May 2021 02:21:07 +0000 (22:21 -0400)
Applied consistent behavior to the use case of
calling ``.commit()`` or ``.rollback()`` inside of an existing
``.begin()`` context manager, with the addition of potentially
emitting SQL within the block subsequent to the commit or rollback.
This change continues upon the change first added in
:ticket:`6155` where the use case of calling "rollback" inside of
a ``.begin()`` contextmanager block was proposed:

* calling ``.commit()`` or ``.rollback()`` will now be allowed
without error or warning within all scopes, including
that of legacy and future :class:`_engine.Engine`, ORM
:class:`_orm.Session`, asyncio :class:`.AsyncEngine`.  Previously,
the :class:`_orm.Session` disallowed this.

* The remaining scope of the context manager is then closed;
when the block ends, a check is emitted to see if the transaction
was already ended, and if so the block returns without action.

* It will now raise **an error** if subsequent SQL of any kind
is emitted within the block, **after** ``.commit()`` or
``.rollback()`` is called.   The block should be closed as
the state of the executable object would otherwise be undefined
in this state.

Fixes: #6288
Change-Id: I8b21766ae430f0fa1ac5ef689f4c0fb19fc84336

14 files changed:
doc/build/changelog/unreleased_14/6288.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/util.py
lib/sqlalchemy/ext/asyncio/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/future/engine.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/util/concurrency.py
test/engine/test_transaction.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/orm/test_transaction.py

diff --git a/doc/build/changelog/unreleased_14/6288.rst b/doc/build/changelog/unreleased_14/6288.rst
new file mode 100644 (file)
index 0000000..63b0f6d
--- /dev/null
@@ -0,0 +1,27 @@
+.. change::
+    :tags: usecase, engine, orm
+    :tickets: 6288
+
+    Applied consistent behavior to the use case of
+    calling ``.commit()`` or ``.rollback()`` inside of an existing
+    ``.begin()`` context manager, with the addition of potentially
+    emitting SQL within the block subsequent to the commit or rollback.
+    This change continues upon the change first added in
+    :ticket:`6155` where the use case of calling "rollback" inside of
+    a ``.begin()`` contextmanager block was proposed:
+
+    * calling ``.commit()`` or ``.rollback()`` will now be allowed
+      without error or warning within all scopes, including
+      that of legacy and future :class:`_engine.Engine`, ORM
+      :class:`_orm.Session`, asyncio :class:`.AsyncEngine`.  Previously,
+      the :class:`_orm.Session` disallowed this.
+
+    * The remaining scope of the context manager is then closed;
+      when the block ends, a check is emitted to see if the transaction
+      was already ended, and if so the block returns without action.
+
+    * It will now raise **an error** if subsequent SQL of any kind
+      is emitted within the block, **after** ``.commit()`` or
+      ``.rollback()`` is called.   The block should be closed as
+      the state of the executable object would otherwise be undefined
+      in this state.
index 293dc21b4ea6c09f04e8cf4c256602afc915a1f5..663482b1f5a1962b3eee5dc29fb414cb41e1eb6c 100644 (file)
@@ -13,6 +13,7 @@ from .interfaces import Connectable
 from .interfaces import ExceptionContext
 from .util import _distill_params
 from .util import _distill_params_20
+from .util import TransactionalContext
 from .. import exc
 from .. import inspection
 from .. import log
@@ -60,6 +61,9 @@ class Connection(Connectable):
     _is_future = False
     _sqla_logger_namespace = "sqlalchemy.engine.Connection"
 
+    # used by sqlalchemy.engine.util.TransactionalContext
+    _trans_context_manager = None
+
     def __init__(
         self,
         engine,
@@ -1683,6 +1687,9 @@ class Connection(Connectable):
         ):
             self._invalid_transaction()
 
+        elif self._trans_context_manager:
+            TransactionalContext._trans_ctx_check(self)
+
         if self._is_future and self._transaction is None:
             self._autobegin()
 
@@ -2182,7 +2189,7 @@ class ExceptionContextImpl(ExceptionContext):
         self.invalidate_pool_on_disconnect = invalidate_pool_on_disconnect
 
 
-class Transaction(object):
+class Transaction(TransactionalContext):
     """Represent a database transaction in progress.
 
     The :class:`.Transaction` object is procured by
@@ -2324,21 +2331,14 @@ class Transaction(object):
         finally:
             assert not self.is_active
 
-    def __enter__(self):
-        return self
+    def _get_subject(self):
+        return self.connection
 
-    def __exit__(self, type_, value, traceback):
-        if type_ is None and self.is_active:
-            try:
-                self.commit()
-            except:
-                with util.safe_reraise():
-                    self.rollback()
-        else:
-            if self._deactivated_from_connection:
-                self.close()
-            else:
-                self.rollback()
+    def _transaction_is_active(self):
+        return self.is_active
+
+    def _transaction_is_closed(self):
+        return not self._deactivated_from_connection
 
 
 class MarkerTransaction(Transaction):
@@ -2368,6 +2368,10 @@ class MarkerTransaction(Transaction):
         )
 
         self.connection = connection
+
+        if connection._trans_context_manager:
+            TransactionalContext._trans_ctx_check(connection)
+
         if connection._nested_transaction is not None:
             self._transaction = connection._nested_transaction
         else:
@@ -2429,6 +2433,8 @@ class RootTransaction(Transaction):
 
     def __init__(self, connection):
         assert connection._transaction is None
+        if connection._trans_context_manager:
+            TransactionalContext._trans_ctx_check(connection)
         self.connection = connection
         self._connection_begin_impl()
         connection._transaction = self
@@ -2564,6 +2570,8 @@ class NestedTransaction(Transaction):
 
     def __init__(self, connection):
         assert connection._transaction is not None
+        if connection._trans_context_manager:
+            TransactionalContext._trans_ctx_check(connection)
         self.connection = connection
         self._savepoint = self.connection._savepoint_impl()
         self.is_active = True
@@ -2935,16 +2943,12 @@ class Engine(Connectable, log.Identified):
             self.close_with_result = close_with_result
 
         def __enter__(self):
+            self.transaction.__enter__()
             return self.conn
 
         def __exit__(self, type_, value, traceback):
             try:
-                if type_ is not None:
-                    if self.transaction.is_active:
-                        self.transaction.rollback()
-                else:
-                    if self.transaction.is_active:
-                        self.transaction.commit()
+                self.transaction.__exit__(type_, value, traceback)
             finally:
                 if not self.close_with_result:
                     self.conn.close()
index ede2631985ae17ab2429a4904f425226126b4310..17e3510aad0c096aed51388340708299606bfcff 100644 (file)
@@ -153,3 +153,82 @@ def _distill_params_20(params):
         return (params,), _no_kw
     else:
         raise exc.ArgumentError("mapping or sequence expected for parameters")
+
+
+class TransactionalContext(object):
+    """Apply Python context manager behavior to transaction objects.
+
+    Performs validation to ensure the subject of the transaction is not
+    used if the transaction were ended prematurely.
+
+    """
+
+    _trans_subject = None
+
+    def _transaction_is_active(self):
+        raise NotImplementedError()
+
+    def _transaction_is_closed(self):
+        raise NotImplementedError()
+
+    def _get_subject(self):
+        raise NotImplementedError()
+
+    @classmethod
+    def _trans_ctx_check(cls, subject):
+        trans_context = subject._trans_context_manager
+        if trans_context:
+            if not trans_context._transaction_is_active():
+                raise exc.InvalidRequestError(
+                    "Can't operate on closed transaction inside context "
+                    "manager.  Please complete the context manager "
+                    "before emitting further commands."
+                )
+
+    def __enter__(self):
+        subject = self._get_subject()
+
+        # none for outer transaction, may be non-None for nested
+        # savepoint, legacy nesting cases
+        trans_context = subject._trans_context_manager
+        self._outer_trans_ctx = trans_context
+
+        self._trans_subject = subject
+        subject._trans_context_manager = self
+        return self
+
+    def __exit__(self, type_, value, traceback):
+        subject = self._trans_subject
+
+        # simplistically we could assume that
+        # "subject._trans_context_manager is self".  However, any calling
+        # code that is manipulating __exit__ directly would break this
+        # assumption.  alembic context manager
+        # is an example of partial use that just calls __exit__ and
+        # not __enter__ at the moment.  it's safe to assume this is being done
+        # in the wild also
+        out_of_band_exit = (
+            subject is None or subject._trans_context_manager is not self
+        )
+
+        if type_ is None and self._transaction_is_active():
+            try:
+                self.commit()
+            except:
+                with util.safe_reraise():
+                    self.rollback()
+            finally:
+                if not out_of_band_exit:
+                    subject._trans_context_manager = self._outer_trans_ctx
+                self._trans_subject = self._outer_trans_ctx = None
+        else:
+            try:
+                if not self._transaction_is_active():
+                    if not self._transaction_is_closed():
+                        self.close()
+                else:
+                    self.rollback()
+            finally:
+                if not out_of_band_exit:
+                    subject._trans_context_manager = self._outer_trans_ctx
+                self._trans_subject = self._outer_trans_ctx = None
index fa8c5006ee26caceda36611ed65ea8bc347df9f8..d11b059fd852c3a04f07e2b27035889f9ba9a691 100644 (file)
@@ -5,14 +5,14 @@ from . import exc as async_exc
 
 class StartableContext(abc.ABC):
     @abc.abstractmethod
-    async def start(self) -> "StartableContext":
+    async def start(self, is_ctxmanager=False) -> "StartableContext":
         pass
 
     def __await__(self):
         return self.start().__await__()
 
     async def __aenter__(self):
-        return await self.start()
+        return await self.start(is_ctxmanager=True)
 
     @abc.abstractmethod
     async def __aexit__(self, type_, value, traceback):
index c637b3d9026b2745695a89af1b7828660b809c5f..17ddb614abd6de77526d8ee20c30cf1e8759ec98 100644 (file)
@@ -101,7 +101,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         self.sync_engine = async_engine.sync_engine
         self.sync_connection = sync_connection
 
-    async def start(self):
+    async def start(self, is_ctxmanager=False):
         """Start this :class:`_asyncio.AsyncConnection` object's context
         outside of using a Python ``with:`` block.
 
@@ -518,19 +518,15 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
         def __init__(self, conn):
             self.conn = conn
 
-        async def start(self):
-            await self.conn.start()
+        async def start(self, is_ctxmanager=False):
+            await self.conn.start(is_ctxmanager=is_ctxmanager)
             self.transaction = self.conn.begin()
             await self.transaction.__aenter__()
 
             return self.conn
 
         async def __aexit__(self, type_, value, traceback):
-            if type_ is not None:
-                await self.transaction.rollback()
-            else:
-                if self.transaction.is_active:
-                    await self.transaction.commit()
+            await self.transaction.__aexit__(type_, value, traceback)
             await self.conn.close()
 
     def __init__(self, sync_engine: Engine):
@@ -678,7 +674,7 @@ class AsyncTransaction(ProxyComparable, StartableContext):
 
         await greenlet_spawn(self._sync_transaction().commit)
 
-    async def start(self):
+    async def start(self, is_ctxmanager=False):
         """Start this :class:`_asyncio.AsyncTransaction` object's context
         outside of using a Python ``with:`` block.
 
@@ -689,17 +685,14 @@ class AsyncTransaction(ProxyComparable, StartableContext):
             if self.nested
             else self.connection._sync_connection().begin
         )
+        if is_ctxmanager:
+            self.sync_transaction.__enter__()
         return self
 
     async def __aexit__(self, type_, value, traceback):
-        if type_ is None and self.is_active:
-            try:
-                await self.commit()
-            except:
-                with util.safe_reraise():
-                    await self.rollback()
-        else:
-            await self.rollback()
+        await greenlet_spawn(
+            self._sync_transaction().__exit__, type_, value, traceback
+        )
 
 
 def _get_sync_engine_or_connection(async_engine):
index d8a5673eb97894c92224cc4c0c6393e789081a17..1b61d6ee305294fb58bf58f619c23664849aa62f 100644 (file)
@@ -399,15 +399,17 @@ class AsyncSessionTransaction(StartableContext):
 
         await greenlet_spawn(self._sync_transaction().commit)
 
-    async def start(self):
+    async def start(self, is_ctxmanager=False):
         self.sync_transaction = await greenlet_spawn(
             self.session.sync_session.begin_nested
             if self.nested
             else self.session.sync_session.begin
         )
+        if is_ctxmanager:
+            self.sync_transaction.__enter__()
         return self
 
     async def __aexit__(self, type_, value, traceback):
-        return await greenlet_spawn(
+        await greenlet_spawn(
             self._sync_transaction().__exit__, type_, value, traceback
         )
index cee17a432d649e85b24e7c81de7673af2e819b56..ab890ca4f4c283ded2d5e582ccb33df09b1364bf 100644 (file)
@@ -359,16 +359,12 @@ class Engine(_LegacyEngine):
 
         def __enter__(self):
             self.transaction = self.conn.begin()
+            self.transaction.__enter__()
             return self.conn
 
         def __exit__(self, type_, value, traceback):
             try:
-                if type_ is not None:
-                    if self.transaction.is_active:
-                        self.transaction.rollback()
-                else:
-                    if self.transaction.is_active:
-                        self.transaction.commit()
+                self.transaction.__exit__(type_, value, traceback)
             finally:
                 self.conn.close()
 
index a3ec360d00e0f4585bf8b8da1a4cc9b14a7a9e6a..cdf3a158565b1310e9715cc61495c2a340298d34 100644 (file)
@@ -31,6 +31,7 @@ from .. import engine
 from .. import exc as sa_exc
 from .. import sql
 from .. import util
+from ..engine.util import TransactionalContext
 from ..inspection import inspect
 from ..sql import coercions
 from ..sql import dml
@@ -475,7 +476,7 @@ class ORMExecuteState(util.MemoizedSlots):
         ]
 
 
-class SessionTransaction(object):
+class SessionTransaction(TransactionalContext):
     """A :class:`.Session`-level transaction.
 
     :class:`.SessionTransaction` is produced from the
@@ -523,6 +524,8 @@ class SessionTransaction(object):
         nested=False,
         autobegin=False,
     ):
+        TransactionalContext._trans_ctx_check(session)
+
         self.session = session
         self._connections = {}
         self._parent = parent
@@ -927,21 +930,14 @@ class SessionTransaction(object):
         self.session = None
         self._connections = None
 
-    def __enter__(self):
-        return self
+    def _get_subject(self):
+        return self.session
 
-    def __exit__(self, type_, value, traceback):
-        self._assert_active(deactive_ok=True, prepared_ok=True)
-        if self.session._transaction is None:
-            return
-        if type_ is None:
-            try:
-                self.commit()
-            except:
-                with util.safe_reraise():
-                    self.rollback()
-        else:
-            self.rollback()
+    def _transaction_is_active(self):
+        return self._state is ACTIVE
+
+    def _transaction_is_closed(self):
+        return self._state is CLOSED
 
 
 class Session(_SessionClassMethods):
@@ -1154,6 +1150,9 @@ class Session(_SessionClassMethods):
 
         _sessions[self.hash_key] = self
 
+    # used by sqlalchemy.engine.util.TransactionalContext
+    _trans_context_manager = None
+
     connection_callable = None
 
     def __enter__(self):
@@ -1252,6 +1251,7 @@ class Session(_SessionClassMethods):
 
     def _autobegin(self):
         if not self.autocommit and self._transaction is None:
+
             trans = SessionTransaction(self, autobegin=True)
             assert self._transaction is trans
             return True
@@ -1520,6 +1520,8 @@ class Session(_SessionClassMethods):
         )
 
     def _connection_for_bind(self, engine, execution_options=None, **kw):
+        TransactionalContext._trans_ctx_check(self)
+
         if self._transaction is not None or self._autobegin():
             return self._transaction._connection_for_bind(
                 engine, execution_options
index c3eb1b36392bc69c54cf578f562bec22b482ae16..581bc8becfee3e284d3b2f948f5cefc67a1501ec 100644 (file)
@@ -147,6 +147,150 @@ class TestBase(object):
         else:
             drop_all_tables_from_metadata(metadata, config.db)
 
+    @config.fixture(
+        params=[
+            (rollback, second_operation, begin_nested)
+            for rollback in (True, False)
+            for second_operation in ("none", "execute", "begin")
+            for begin_nested in (
+                True,
+                False,
+            )
+        ]
+    )
+    def trans_ctx_manager_fixture(self, request, metadata):
+        rollback, second_operation, begin_nested = request.param
+
+        from sqlalchemy import Table, Column, Integer, func, select
+        from . import eq_
+
+        t = Table("test", metadata, Column("data", Integer))
+        eng = getattr(self, "bind", None) or config.db
+
+        t.create(eng)
+
+        def run_test(subject, trans_on_subject, execute_on_subject):
+            with subject.begin() as trans:
+
+                if begin_nested:
+                    if not config.requirements.savepoints.enabled:
+                        config.skip_test("savepoints not enabled")
+                    if execute_on_subject:
+                        nested_trans = subject.begin_nested()
+                    else:
+                        nested_trans = trans.begin_nested()
+
+                    with nested_trans:
+                        if execute_on_subject:
+                            subject.execute(t.insert(), {"data": 10})
+                        else:
+                            trans.execute(t.insert(), {"data": 10})
+
+                        # for nested trans, we always commit/rollback on the
+                        # "nested trans" object itself.
+                        # only Session(future=False) will affect savepoint
+                        # transaction for session.commit/rollback
+
+                        if rollback:
+                            nested_trans.rollback()
+                        else:
+                            nested_trans.commit()
+
+                        if second_operation != "none":
+                            with assertions.expect_raises_message(
+                                sa.exc.InvalidRequestError,
+                                "Can't operate on closed transaction "
+                                "inside context "
+                                "manager.  Please complete the context "
+                                "manager "
+                                "before emitting further commands.",
+                            ):
+                                if second_operation == "execute":
+                                    if execute_on_subject:
+                                        subject.execute(
+                                            t.insert(), {"data": 12}
+                                        )
+                                    else:
+                                        trans.execute(t.insert(), {"data": 12})
+                                elif second_operation == "begin":
+                                    if execute_on_subject:
+                                        subject.begin_nested()
+                                    else:
+                                        trans.begin_nested()
+
+                    # outside the nested trans block, but still inside the
+                    # transaction block, we can run SQL, and it will be
+                    # committed
+                    if execute_on_subject:
+                        subject.execute(t.insert(), {"data": 14})
+                    else:
+                        trans.execute(t.insert(), {"data": 14})
+
+                else:
+                    if execute_on_subject:
+                        subject.execute(t.insert(), {"data": 10})
+                    else:
+                        trans.execute(t.insert(), {"data": 10})
+
+                    if trans_on_subject:
+                        if rollback:
+                            subject.rollback()
+                        else:
+                            subject.commit()
+                    else:
+                        if rollback:
+                            trans.rollback()
+                        else:
+                            trans.commit()
+
+                    if second_operation != "none":
+                        with assertions.expect_raises_message(
+                            sa.exc.InvalidRequestError,
+                            "Can't operate on closed transaction inside "
+                            "context "
+                            "manager.  Please complete the context manager "
+                            "before emitting further commands.",
+                        ):
+                            if second_operation == "execute":
+                                if execute_on_subject:
+                                    subject.execute(t.insert(), {"data": 12})
+                                else:
+                                    trans.execute(t.insert(), {"data": 12})
+                            elif second_operation == "begin":
+                                if hasattr(trans, "begin"):
+                                    trans.begin()
+                                else:
+                                    subject.begin()
+                            elif second_operation == "begin_nested":
+                                if execute_on_subject:
+                                    subject.begin_nested()
+                                else:
+                                    trans.begin_nested()
+
+            expected_committed = 0
+            if begin_nested:
+                # begin_nested variant, we inserted a row after the nested
+                # block
+                expected_committed += 1
+            if not rollback:
+                # not rollback variant, our row inserted in the target
+                # block itself would be committed
+                expected_committed += 1
+
+            if execute_on_subject:
+                eq_(
+                    subject.scalar(select(func.count()).select_from(t)),
+                    expected_committed,
+                )
+            else:
+                with subject.connect() as conn:
+                    eq_(
+                        conn.scalar(select(func.count()).select_from(t)),
+                        expected_committed,
+                    )
+
+        return run_test
+
 
 _connection_fixture_connection = None
 
index e26f305d940f2ce98a5bdd655554c22a4f3a6ceb..60db9cfff534d11a80bcaa7006f7fe8177d64d12 100644 (file)
@@ -24,6 +24,11 @@ if not have_greenlet:
     asyncio = None  # noqa F811
 
     def _not_implemented():
+        # this conditional is to prevent pylance from considering
+        # greenlet_spawn() etc as "no return" and dimming out code below it
+        if have_greenlet:
+            return None
+
         if not compat.py3k:
             raise ValueError("Cannot use this function in py2.")
         else:
index d78ff7beeb632553a8e3dfbe40d7bd80a18e54c7..b8e7edc6522e1ea74895d268053b7cc21eff31e3 100644 (file)
@@ -45,6 +45,30 @@ class TransactionTest(fixtures.TablesTest):
         with testing.db.connect() as conn:
             yield conn
 
+    def test_interrupt_ctxmanager_engine(self, trans_ctx_manager_fixture):
+        fn = trans_ctx_manager_fixture
+
+        # add commit/rollback to the legacy Connection object so that
+        # we can test this less-likely case in use with the legacy
+        # Engine.begin() context manager
+        class ConnWCommitRollback(testing.db._connection_cls):
+            def commit(self):
+                self.get_transaction().commit()
+
+            def rollback(self):
+                self.get_transaction().rollback()
+
+        with mock.patch.object(
+            testing.db, "_connection_cls", ConnWCommitRollback
+        ):
+            fn(testing.db, trans_on_subject=False, execute_on_subject=False)
+
+    def test_interrupt_ctxmanager_connection(self, trans_ctx_manager_fixture):
+        fn = trans_ctx_manager_fixture
+
+        with testing.db.connect() as conn:
+            fn(conn, trans_on_subject=False, execute_on_subject=True)
+
     def test_commits(self, local_connection):
         users = self.tables.users
         connection = local_connection
@@ -111,8 +135,15 @@ class TransactionTest(fixtures.TablesTest):
             trans.rollback()
             assert not local_connection.in_transaction()
 
-            # would be subject to autocommit
-            local_connection.execute(select(1))
+            # previously, would be subject to autocommit.
+            # now it raises
+            with expect_raises_message(
+                exc.InvalidRequestError,
+                "Can't operate on closed transaction inside context manager.  "
+                "Please complete the context manager before emitting "
+                "further commands.",
+            ):
+                local_connection.execute(select(1))
 
             assert not local_connection.in_transaction()
 
@@ -400,6 +431,7 @@ class TransactionTest(fixtures.TablesTest):
         connection = local_connection
         users = self.tables.users
         trans = connection.begin()
+        trans.__enter__()
         connection.execute(users.insert(), dict(user_id=1, user_name="user1"))
         connection.execute(users.insert(), dict(user_id=2, user_name="user2"))
         try:
@@ -418,6 +450,7 @@ class TransactionTest(fixtures.TablesTest):
         )
 
         trans = connection.begin()
+        trans.__enter__()
         connection.execute(users.insert(), dict(user_id=1, user_name="user1"))
         trans.__exit__(None, None, None)
         assert not trans.is_active
@@ -1487,6 +1520,24 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
         with testing.db.connect() as conn:
             yield conn
 
+    def test_interrupt_ctxmanager_engine(self, trans_ctx_manager_fixture):
+        fn = trans_ctx_manager_fixture
+
+        fn(testing.db, trans_on_subject=False, execute_on_subject=False)
+
+    @testing.combinations((True,), (False,), argnames="trans_on_subject")
+    def test_interrupt_ctxmanager_connection(
+        self, trans_ctx_manager_fixture, trans_on_subject
+    ):
+        fn = trans_ctx_manager_fixture
+
+        with testing.db.connect() as conn:
+            fn(
+                conn,
+                trans_on_subject=trans_on_subject,
+                execute_on_subject=True,
+            )
+
     def test_autobegin_rollback(self):
         users = self.tables.users
         with testing.db.connect() as conn:
@@ -1683,10 +1734,17 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
             trans.rollback()
             assert not local_connection.in_transaction()
 
-            # autobegin
-            local_connection.execute(select(1))
+            # previously, would be subject to autocommit.
+            # now it raises
+            with expect_raises_message(
+                exc.InvalidRequestError,
+                "Can't operate on closed transaction inside context manager.  "
+                "Please complete the context manager before emitting "
+                "further commands.",
+            ):
+                local_connection.execute(select(1))
 
-            assert local_connection.in_transaction()
+            assert not local_connection.in_transaction()
 
     @testing.combinations((True,), (False,), argnames="roll_back_in_block")
     def test_ctxmanager_rolls_back(self, local_connection, roll_back_in_block):
index 820c82bca6d1ff7d016e6f5822a7bb35a04a0e69..18e55ff92c7286f21f9ec5123ada81ae3a858bb0 100644 (file)
@@ -17,8 +17,10 @@ from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.asyncio import engine as _async_engine
 from sqlalchemy.ext.asyncio import exc as asyncio_exc
 from sqlalchemy.pool import AsyncAdaptedQueuePool
+from sqlalchemy.testing import assertions
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
@@ -34,7 +36,133 @@ from sqlalchemy.testing import ne_
 from sqlalchemy.util.concurrency import greenlet_spawn
 
 
-class EngineFixture(fixtures.TablesTest):
+class AsyncFixture:
+    @config.fixture(
+        params=[
+            (rollback, run_second_execute, begin_nested)
+            for rollback in (True, False)
+            for run_second_execute in (True, False)
+            for begin_nested in (True, False)
+        ]
+    )
+    def async_trans_ctx_manager_fixture(self, request, metadata):
+        rollback, run_second_execute, begin_nested = request.param
+
+        from sqlalchemy import Table, Column, Integer, func, select
+
+        t = Table("test", metadata, Column("data", Integer))
+        eng = getattr(self, "bind", None) or config.db
+
+        t.create(eng)
+
+        async def run_test(subject, trans_on_subject, execute_on_subject):
+            async with subject.begin() as trans:
+
+                if begin_nested:
+                    if not config.requirements.savepoints.enabled:
+                        config.skip_test("savepoints not enabled")
+                    if execute_on_subject:
+                        nested_trans = subject.begin_nested()
+                    else:
+                        nested_trans = trans.begin_nested()
+
+                    async with nested_trans:
+                        if execute_on_subject:
+                            await subject.execute(t.insert(), {"data": 10})
+                        else:
+                            await trans.execute(t.insert(), {"data": 10})
+
+                        # for nested trans, we always commit/rollback on the
+                        # "nested trans" object itself.
+                        # only Session(future=False) will affect savepoint
+                        # transaction for session.commit/rollback
+
+                        if rollback:
+                            await nested_trans.rollback()
+                        else:
+                            await nested_trans.commit()
+
+                        if run_second_execute:
+                            with assertions.expect_raises_message(
+                                exc.InvalidRequestError,
+                                "Can't operate on closed transaction "
+                                "inside context manager.  Please complete the "
+                                "context manager "
+                                "before emitting further commands.",
+                            ):
+                                if execute_on_subject:
+                                    await subject.execute(
+                                        t.insert(), {"data": 12}
+                                    )
+                                else:
+                                    await trans.execute(
+                                        t.insert(), {"data": 12}
+                                    )
+
+                    # outside the nested trans block, but still inside the
+                    # transaction block, we can run SQL, and it will be
+                    # committed
+                    if execute_on_subject:
+                        await subject.execute(t.insert(), {"data": 14})
+                    else:
+                        await trans.execute(t.insert(), {"data": 14})
+
+                else:
+                    if execute_on_subject:
+                        await subject.execute(t.insert(), {"data": 10})
+                    else:
+                        await trans.execute(t.insert(), {"data": 10})
+
+                    if trans_on_subject:
+                        if rollback:
+                            await subject.rollback()
+                        else:
+                            await subject.commit()
+                    else:
+                        if rollback:
+                            await trans.rollback()
+                        else:
+                            await trans.commit()
+
+                    if run_second_execute:
+                        with assertions.expect_raises_message(
+                            exc.InvalidRequestError,
+                            "Can't operate on closed transaction inside "
+                            "context "
+                            "manager.  Please complete the context manager "
+                            "before emitting further commands.",
+                        ):
+                            if execute_on_subject:
+                                await subject.execute(t.insert(), {"data": 12})
+                            else:
+                                await trans.execute(t.insert(), {"data": 12})
+
+            expected_committed = 0
+            if begin_nested:
+                # begin_nested variant, we inserted a row after the nested
+                # block
+                expected_committed += 1
+            if not rollback:
+                # not rollback variant, our row inserted in the target
+                # block itself would be committed
+                expected_committed += 1
+
+            if execute_on_subject:
+                eq_(
+                    await subject.scalar(select(func.count()).select_from(t)),
+                    expected_committed,
+                )
+            else:
+                with subject.connect() as conn:
+                    eq_(
+                        await conn.scalar(select(func.count()).select_from(t)),
+                        expected_committed,
+                    )
+
+        return run_test
+
+
+class EngineFixture(AsyncFixture, fixtures.TablesTest):
     __requires__ = ("async_dialect",)
 
     @testing.fixture
@@ -68,6 +196,15 @@ class AsyncEngineTest(EngineFixture):
         async with async_engine.connect() as conn:
             eq_(await conn.scalar(text("select 1")), 2)
 
+    @async_test
+    async def test_interrupt_ctxmanager_connection(
+        self, async_engine, async_trans_ctx_manager_fixture
+    ):
+        fn = async_trans_ctx_manager_fixture
+
+        async with async_engine.connect() as conn:
+            await fn(conn, trans_on_subject=False, execute_on_subject=True)
+
     def test_proxied_attrs_engine(self, async_engine):
         sync_engine = async_engine.sync_engine
 
index feb5574711fd1902088c4a366c3a280378b827a4..e97e2563ab33eccdeacac6480c4a423437d11c53 100644 (file)
@@ -14,10 +14,11 @@ from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
+from .test_engine_py3k import AsyncFixture as _AsyncFixture
 from ...orm import _fixtures
 
 
-class AsyncFixture(_fixtures.FixtureTest):
+class AsyncFixture(_AsyncFixture, _fixtures.FixtureTest):
     __requires__ = ("async_dialect",)
 
     @classmethod
@@ -123,6 +124,14 @@ class AsyncSessionQueryTest(AsyncFixture):
 class AsyncSessionTransactionTest(AsyncFixture):
     run_inserts = None
 
+    @async_test
+    async def test_interrupt_ctxmanager_connection(
+        self, async_trans_ctx_manager_fixture, async_session
+    ):
+        fn = async_trans_ctx_manager_fixture
+
+        await fn(async_session, trans_on_subject=True, execute_on_subject=True)
+
     @async_test
     async def test_sessionmaker_block_one(self, async_engine):
 
index 0e49ff2c349e64e5e4e8e6b196f81a7fae5bcdcf..fbd89616afd3aacc58f31fedf1c8241d49363745 100644 (file)
@@ -27,6 +27,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import assert_warnings
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -2409,6 +2410,67 @@ class ContextManagerPlusFutureTest(FixtureTest):
         eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
         sess.close()
 
+    @testing.combinations((True,), (False,), argnames="future")
+    def test_interrupt_ctxmanager(self, trans_ctx_manager_fixture, future):
+        fn = trans_ctx_manager_fixture
+
+        session = fixture_session(future=future)
+
+        fn(session, trans_on_subject=True, execute_on_subject=True)
+
+    @testing.combinations((True,), (False,), argnames="future")
+    @testing.combinations((True,), (False,), argnames="rollback")
+    @testing.combinations((True,), (False,), argnames="expire_on_commit")
+    @testing.combinations(
+        ("add",),
+        ("modify",),
+        ("delete",),
+        ("begin",),
+        argnames="check_operation",
+    )
+    def test_interrupt_ctxmanager_ops(
+        self, future, rollback, expire_on_commit, check_operation
+    ):
+        users, User = self.tables.users, self.classes.User
+
+        mapper(User, users)
+
+        session = fixture_session(
+            future=future, expire_on_commit=expire_on_commit
+        )
+
+        with session.begin():
+            u1 = User(id=7, name="u1")
+            session.add(u1)
+
+        with session.begin():
+            u1.name  # unexpire
+            u2 = User(id=8, name="u1")
+            session.add(u2)
+
+            session.flush()
+
+            if rollback:
+                session.rollback()
+            else:
+                session.commit()
+
+            with expect_raises_message(
+                sa_exc.InvalidRequestError,
+                "Can't operate on closed transaction "
+                "inside context manager.  Please complete the context "
+                "manager before emitting further commands.",
+            ):
+                if check_operation == "add":
+                    u3 = User(id=9, name="u2")
+                    session.add(u3)
+                elif check_operation == "begin":
+                    session.begin()
+                elif check_operation == "modify":
+                    u1.name = "newname"
+                elif check_operation == "delete":
+                    session.delete(u1)
+
 
 class TransactionFlagsTest(fixtures.TestBase):
     def test_in_transaction(self):
@@ -2730,32 +2792,6 @@ class JoinIntoAnExternalTransactionFixture(object):
 
         self._assert_count(1)
 
-    @testing.requires.savepoints
-    def test_something_with_rollback(self):
-        A = self.A
-
-        a1 = A()
-        self.session.add(a1)
-        self.session.flush()
-
-        self._assert_count(1)
-        self.session.rollback()
-        self._assert_count(0)
-
-        a1 = A()
-        self.session.add(a1)
-        self.session.commit()
-        self._assert_count(1)
-
-        a2 = A()
-
-        self.session.add(a2)
-        self.session.flush()
-        self._assert_count(2)
-
-        self.session.rollback()
-        self._assert_count(1)
-
     def _assert_count(self, count):
         result = self.connection.scalar(
             select(func.count()).select_from(self.table)
@@ -2801,6 +2837,37 @@ class NewStyleJoinIntoAnExternalTransactionTest(
         if self.trans.is_active:
             self.trans.rollback()
 
+    @testing.requires.savepoints
+    def test_something_with_context_managers(self):
+        A = self.A
+
+        a1 = A()
+
+        with self.session.begin():
+            self.session.add(a1)
+            self.session.flush()
+
+            self._assert_count(1)
+            self.session.rollback()
+
+        self._assert_count(0)
+
+        a1 = A()
+        with self.session.begin():
+            self.session.add(a1)
+
+        self._assert_count(1)
+
+        a2 = A()
+
+        with self.session.begin():
+            self.session.add(a2)
+            self.session.flush()
+            self._assert_count(2)
+
+            self.session.rollback()
+        self._assert_count(1)
+
 
 class FutureJoinIntoAnExternalTransactionTest(
     NewStyleJoinIntoAnExternalTransactionTest,