]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update transaction / connection handling
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 May 2020 16:50:11 +0000 (12:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2020 20:32:22 +0000 (16:32 -0400)
step one, do away with __connection attribute and using
awkward AttributeError logic

step two, move all management of "connection._transaction"
into the transaction objects themselves where it's easier
to follow.

build MarkerTransaction that takes the role of
"do-nothing block"

new connection datamodel is: connection._transaction, always
a root, connection._nested_transaction, always a nested.

nested transactions still chain to each other as this
is still sort of necessary but they consider the root
transaction separately, and the marker transactions
not at all.

introduce new InvalidRequestError subclass
PendingRollbackError.  Apply to connection and session
for all cases where a transaction needs to be rolled
back before continuing.   Within Connection,
both PendingRollbackError as well as ResourceClosedError
are now raised directly without being handled by
handle_dbapi_error();  this removes these two exception
cases from the handle_error event handler as well as
from StatementError wrapping, as these two exceptions are
not statement oriented and are instead programmatic
issues, that the application is failing to handle database
errors properly.

Revise savepoints so that when a release fails, they set
themselves as inactive so that their rollback() method
does not throw another exception.

Give savepoints another go on MySQL, can't get release working
however get support for basic round trip going

Fixes: #5327
Change-Id: Ia3cbbf56d4882fcc7980f90519412f1711fae74d

13 files changed:
doc/build/errors.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/events.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/future/engine.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/plugin/plugin_base.py
test/engine/test_reconnect.py
test/engine/test_transaction.py
test/orm/test_transaction.py
test/requirements.py

index 599b91e2630b2cd99d1531f9de5c74d2f489e7ec..9af96f9b9a1c8812924188074215b7dc5491ae07 100644 (file)
@@ -225,47 +225,69 @@ sooner.
  :ref:`connections_toplevel`
 
 
+.. _error_8s2b:
+
+Can't reconnect until invalid transaction is rolled back
+----------------------------------------------------------
+
+This error condition refers to the case where a :class:`_engine.Connection` was
+invalidated, either due to a database disconnect detection or due to an
+explicit call to :meth:`_engine.Connection.invalidate`, but there is still a
+transaction present that was initiated by the :meth:`_engine.Connection.begin`
+method.  When a connection is invalidated, any :class:`_engine.Transaction`
+that was in progress is now in an invalid state, and must be explicitly rolled
+back in order to remove it from the :class:`_engine.Connection`.
+
 .. _error_8s2a:
 
 This connection is on an inactive transaction.  Please rollback() fully before proceeding
 ------------------------------------------------------------------------------------------
 
 This error condition was added to SQLAlchemy as of version 1.4.    The error
-refers to the state where a :class:`_engine.Connection` is placed into a transaction
-using a method like :meth:`_engine.Connection.begin`, and then a further "sub" transaction
-is created within that scope; the "sub" transaction is then rolled back using
-:meth:`.Transaction.rollback`, however the outer transaction is not rolled back.
+refers to the state where a :class:`_engine.Connection` is placed into a
+transaction using a method like :meth:`_engine.Connection.begin`, and then a
+further "marker" transaction is created within that scope; the "marker"
+transaction is then rolled back using :meth:`.Transaction.rollback` or closed
+using :meth:`.Transaction.close`, however the outer transaction is still
+present in an "inactive" state and must be rolled back.
 
 The pattern looks like::
 
     engine = create_engine(...)
 
     connection = engine.connect()
-    transaction = connection.begin()
+    transaction1 = connection.begin()
 
+    # this is a "sub" or "marker" transaction, a logical nesting
+    # structure based on "real" transaction transaction1
     transaction2 = connection.begin()
     transaction2.rollback()
 
-    connection.execute(text("select 1"))  # we are rolled back; will now raise
+    # transaction1 is still present and needs explicit rollback,
+    # so this will raise
+    connection.execute(text("select 1"))
 
-    transaction.rollback()
+Above, ``transaction2`` is a "marker" transaction, which indicates a logical
+nesting of transactions within an outer one; while the inner transaction
+can roll back the whole transaction via its rollback() method, its commit()
+method has no effect except to close the scope of the "marker" transaction
+itself.   The call to ``transaction2.rollback()`` has the effect of
+**deactivating** transaction1 which means it is essentially rolled back
+at the database level, however is still present in order to accommodate
+a consistent nesting pattern of transactions.
 
+The correct resolution is to ensure the outer transaction is also
+rolled back::
 
-Above, ``transaction2`` is a "sub" transaction, which indicates a logical
-nesting of transactions within an outer one.   SQLAlchemy makes great use of
-this pattern more commonly in the ORM :class:`.Session`, where the FAQ entry
-:ref:`faq_session_rollback` describes the rationale within the ORM.
+    transaction1.rollback()
 
-The "subtransaction" pattern in Core comes into play often when using the ORM
-pattern described at :ref:`session_external_transaction`.   As this pattern
-involves a behavior called "connection branching", where a :class:`_engine.Connection`
-serves a "branched" :class:`_engine.Connection` object to the :class:`.Session` via
-its :meth:`_engine.Connection.connect` method, the same transaction behavior comes
-into play; if the :class:`.Session` rolls back the transaction, and savepoints
-have not been used to prevent a rollback of the entire transaction, the
-outermost transaction started on the :class:`_engine.Connection` is now in an inactive
-state.
+This pattern is not commonly used in Core.  Within the ORM, a similar issue can
+occur which is the product of the ORM's "logical" transaction structure; this
+is described in the FAQ entry at :ref:`faq_session_rollback`.
 
+The "subtransaction" pattern is to be removed in SQLAlchemy 2.0 so that this
+particular programming pattern will no longer be available and this
+error message will no longer occur in Core.
 
 .. _error_dbapi:
 
index e617f0fadfb03325c99f8cdce21a4f3eb54e2b4f..f169655e09dbd40e3925c230c3a0a5512c703ec0 100644 (file)
@@ -72,10 +72,11 @@ class Connection(Connectable):
         self.engine = engine
         self.dialect = engine.dialect
         self.__branch_from = _branch_from
-        self.__branch = _branch_from is not None
 
         if _branch_from:
-            self.__connection = connection
+            # branching is always "from" the root connection
+            assert _branch_from.__branch_from is None
+            self._dbapi_connection = connection
             self._execution_options = _execution_options
             self._echo = _branch_from._echo
             self.should_close_with_result = False
@@ -83,16 +84,16 @@ class Connection(Connectable):
             self._has_events = _branch_from._has_events
             self._schema_translate_map = _branch_from._schema_translate_map
         else:
-            self.__connection = (
+            self._dbapi_connection = (
                 connection
                 if connection is not None
                 else engine.raw_connection()
             )
-            self._transaction = None
+            self._transaction = self._nested_transaction = None
             self.__savepoint_seq = 0
+            self.__in_begin = False
             self.should_close_with_result = close_with_result
 
-            self.__invalid = False
             self.__can_reconnect = True
             self._echo = self.engine._should_log_info()
 
@@ -109,7 +110,7 @@ class Connection(Connectable):
             self._execution_options = engine._execution_options
 
         if self._has_events or self.engine._has_events:
-            self.dispatch.engine_connect(self, self.__branch)
+            self.dispatch.engine_connect(self, _branch_from is not None)
 
     def schema_for_object(self, obj):
         """return the schema name for the given schema item taking into
@@ -134,6 +135,10 @@ class Connection(Connectable):
         engine and connection; but does not have close_with_result enabled,
         and also whose close() method does nothing.
 
+        .. deprecated:: 1.4 the "branching" concept will be removed in
+           SQLAlchemy 2.0 as well as the "Connection.connect()" method which
+           is the only consumer for this.
+
         The Core uses this very sparingly, only in the case of
         custom SQL default functions that are to be INSERTed as the
         primary key of a row where we need to get the value back, so we have
@@ -145,31 +150,14 @@ class Connection(Connectable):
         connected when a close() event occurs.
 
         """
-        if self.__branch_from:
-            return self.__branch_from._branch()
-        else:
-            return self.engine._connection_cls(
-                self.engine,
-                self.__connection,
-                _branch_from=self,
-                _execution_options=self._execution_options,
-                _has_events=self._has_events,
-                _dispatch=self.dispatch,
-            )
-
-    @property
-    def _root(self):
-        """return the 'root' connection.
-
-        Returns 'self' if this connection is not a branch, else
-        returns the root connection from which we ultimately branched.
-
-        """
-
-        if self.__branch_from:
-            return self.__branch_from
-        else:
-            return self
+        return self.engine._connection_cls(
+            self.engine,
+            self._dbapi_connection,
+            _branch_from=self.__branch_from if self.__branch_from else self,
+            _execution_options=self._execution_options,
+            _has_events=self._has_events,
+            _dispatch=self.dispatch,
+        )
 
     def _generate_for_options(self):
         """define connection method chaining behavior for execution_options"""
@@ -367,16 +355,28 @@ class Connection(Connectable):
     def closed(self):
         """Return True if this connection is closed."""
 
-        return (
-            "_Connection__connection" not in self.__dict__
-            and not self.__can_reconnect
-        )
+        # note this is independent for a "branched" connection vs.
+        # the base
+
+        return self._dbapi_connection is None and not self.__can_reconnect
 
     @property
     def invalidated(self):
         """Return True if this connection was invalidated."""
 
-        return self._root.__invalid
+        # prior to 1.4, "invalid" was stored as a state independent of
+        # "closed", meaning an invalidated connection could be "closed",
+        # the _dbapi_connection would be None and closed=True, yet the
+        # "invalid" flag would stay True.  This meant that there were
+        # three separate states (open/valid, closed/valid, closed/invalid)
+        # when there is really no reason for that; a connection that's
+        # "closed" does not need to be "invalid".  So the state is now
+        # represented by the two facts alone.
+
+        if self.__branch_from:
+            return self.__branch_from.invalidated
+
+        return self._dbapi_connection is None and not self.closed
 
     @property
     def connection(self):
@@ -389,16 +389,15 @@ class Connection(Connectable):
 
         """
 
-        try:
-            return self.__connection
-        except AttributeError:
-            # escape "except AttributeError" before revalidating
-            # to prevent misleading stacktraces in Py3K
-            pass
-        try:
-            return self._revalidate_connection()
-        except BaseException as e:
-            self._handle_dbapi_exception(e, None, None, None, None)
+        if self._dbapi_connection is None:
+            try:
+                return self._revalidate_connection()
+            except (exc.PendingRollbackError, exc.ResourceClosedError):
+                raise
+            except BaseException as e:
+                self._handle_dbapi_exception(e, None, None, None, None)
+        else:
+            return self._dbapi_connection
 
     def get_isolation_level(self):
         """Return the current isolation level assigned to this
@@ -470,34 +469,46 @@ class Connection(Connectable):
         """
         return self.dialect.default_isolation_level
 
+    def _invalid_transaction(self):
+        if self.invalidated:
+            raise exc.PendingRollbackError(
+                "Can't reconnect until invalid %stransaction is rolled "
+                "back."
+                % (
+                    "savepoint "
+                    if self._nested_transaction is not None
+                    else ""
+                ),
+                code="8s2b",
+            )
+        else:
+            raise exc.PendingRollbackError(
+                "This connection is on an inactive %stransaction.  "
+                "Please rollback() fully before proceeding."
+                % (
+                    "savepoint "
+                    if self._nested_transaction is not None
+                    else ""
+                ),
+                code="8s2a",
+            )
+
     def _revalidate_connection(self):
         if self.__branch_from:
             return self.__branch_from._revalidate_connection()
-        if self.__can_reconnect and self.__invalid:
+        if self.__can_reconnect and self.invalidated:
             if self._transaction is not None:
-                raise exc.InvalidRequestError(
-                    "Can't reconnect until invalid "
-                    "transaction is rolled back"
-                )
-            self.__connection = self.engine.raw_connection(_connection=self)
-            self.__invalid = False
-            return self.__connection
+                self._invalid_transaction()
+            self._dbapi_connection = self.engine.raw_connection(
+                _connection=self
+            )
+            return self._dbapi_connection
         raise exc.ResourceClosedError("This Connection is closed")
 
     @property
-    def _connection_is_valid(self):
-        # use getattr() for is_valid to support exceptions raised in
-        # dialect initializer, where the connection is not wrapped in
-        # _ConnectionFairy
-
-        return getattr(self.__connection, "is_valid", False)
-
-    @property
-    def _still_open_and_connection_is_valid(self):
-        return (
-            not self.closed
-            and not self.invalidated
-            and getattr(self.__connection, "is_valid", False)
+    def _still_open_and_dbapi_connection_is_valid(self):
+        return self._dbapi_connection is not None and getattr(
+            self._dbapi_connection, "is_valid", False
         )
 
     @property
@@ -571,16 +582,18 @@ class Connection(Connectable):
 
         """
 
+        if self.__branch_from:
+            return self.__branch_from.invalidate(exception=exception)
+
         if self.invalidated:
             return
 
         if self.closed:
             raise exc.ResourceClosedError("This Connection is closed")
 
-        if self._root._connection_is_valid:
-            self._root.__connection.invalidate(exception)
-        del self._root.__connection
-        self._root.__invalid = True
+        if self._still_open_and_dbapi_connection_is_valid:
+            self._dbapi_connection.invalidate(exception)
+        self._dbapi_connection = None
 
     def detach(self):
         """Detach the underlying DB-API connection from its connection pool.
@@ -608,7 +621,7 @@ class Connection(Connectable):
 
         """
 
-        self.__connection.detach()
+        self._dbapi_connection.detach()
 
     def begin(self):
         """Begin a transaction and return a transaction handle.
@@ -650,7 +663,14 @@ class Connection(Connectable):
         elif self.__branch_from:
             return self.__branch_from.begin()
 
-        if self._transaction is None:
+        if self.__in_begin:
+            # for dialects that emit SQL within the process of
+            # dialect.do_begin() or dialect.do_begin_twophase(), this
+            # flag prevents "autobegin" from being emitted within that
+            # process, while allowing self._transaction to remain at None
+            # until it's complete.
+            return
+        elif self._transaction is None:
             self._transaction = RootTransaction(self)
             return self._transaction
         else:
@@ -659,7 +679,7 @@ class Connection(Connectable):
                     "a transaction is already begun for this connection"
                 )
             else:
-                return Transaction(self, self._transaction)
+                return MarkerTransaction(self)
 
     def begin_nested(self):
         """Begin a nested transaction and return a transaction handle.
@@ -685,17 +705,9 @@ class Connection(Connectable):
             return self.__branch_from.begin_nested()
 
         if self._transaction is None:
-            if self._is_future:
-                self._autobegin()
-            else:
-                self._transaction = RootTransaction(self)
-                self.connection._reset_agent = self._transaction
-                return self._transaction
+            self.begin()
 
-        trans = NestedTransaction(self, self._transaction)
-        if not self._is_future:
-            self._transaction = trans
-        return trans
+        return NestedTransaction(self)
 
     def begin_twophase(self, xid=None):
         """Begin a two-phase or XA transaction and return a transaction
@@ -727,8 +739,7 @@ class Connection(Connectable):
             )
         if xid is None:
             xid = self.engine.dialect.create_xid()
-        self._transaction = TwoPhaseTransaction(self, xid)
-        return self._transaction
+        return TwoPhaseTransaction(self, xid)
 
     def recover_twophase(self):
         return self.engine.dialect.do_recover_twophase(self)
@@ -741,10 +752,10 @@ class Connection(Connectable):
 
     def in_transaction(self):
         """Return True if a transaction is in progress."""
-        return (
-            self._root._transaction is not None
-            and self._root._transaction.is_active
-        )
+        if self.__branch_from is not None:
+            return self.__branch_from.in_transaction()
+
+        return self._transaction is not None and self._transaction.is_active
 
     def _begin_impl(self, transaction):
         assert not self.__branch_from
@@ -755,32 +766,27 @@ class Connection(Connectable):
         if self._has_events or self.engine._has_events:
             self.dispatch.begin(self)
 
+        self.__in_begin = True
         try:
             self.engine.dialect.do_begin(self.connection)
-            if not self._is_future and self.connection._reset_agent is None:
-                self.connection._reset_agent = transaction
         except BaseException as e:
             self._handle_dbapi_exception(e, None, None, None, None)
+        finally:
+            self.__in_begin = False
 
-    def _rollback_impl(self, deactivate_only=False):
+    def _rollback_impl(self):
         assert not self.__branch_from
 
         if self._has_events or self.engine._has_events:
             self.dispatch.rollback(self)
 
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             if self._echo:
                 self.engine.logger.info("ROLLBACK")
             try:
                 self.engine.dialect.do_rollback(self.connection)
             except BaseException as e:
                 self._handle_dbapi_exception(e, None, None, None, None)
-            finally:
-                if (
-                    not self.__invalid
-                    and self.connection._reset_agent is self._transaction
-                ):
-                    self.connection._reset_agent = None
 
     def _commit_impl(self, autocommit=False):
         assert not self.__branch_from
@@ -794,13 +800,6 @@ class Connection(Connectable):
             self.engine.dialect.do_commit(self.connection)
         except BaseException as e:
             self._handle_dbapi_exception(e, None, None, None, None)
-        finally:
-            if (
-                not self.__invalid
-                and self.connection._reset_agent is self._transaction
-            ):
-                self.connection._reset_agent = None
-            self._transaction = None
 
     def _savepoint_impl(self, name=None):
         assert not self.__branch_from
@@ -811,44 +810,27 @@ class Connection(Connectable):
         if name is None:
             self.__savepoint_seq += 1
             name = "sa_savepoint_%s" % self.__savepoint_seq
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             self.engine.dialect.do_savepoint(self, name)
             return name
 
-    def _discard_transaction(self, trans):
-        if trans is self._transaction:
-            if trans._is_root:
-                assert trans._parent is trans
-                self._transaction = None
-
-            else:
-                assert trans._parent is not trans
-                self._transaction = trans._parent
-
-        if not self._is_future and self._still_open_and_connection_is_valid:
-            if self.__connection._reset_agent is trans:
-                self.__connection._reset_agent = None
-
-    def _rollback_to_savepoint_impl(
-        self, name, context, deactivate_only=False
-    ):
+    def _rollback_to_savepoint_impl(self, name):
         assert not self.__branch_from
 
         if self._has_events or self.engine._has_events:
-            self.dispatch.rollback_savepoint(self, name, context)
+            self.dispatch.rollback_savepoint(self, name, None)
 
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             self.engine.dialect.do_rollback_to_savepoint(self, name)
 
-    def _release_savepoint_impl(self, name, context):
+    def _release_savepoint_impl(self, name):
         assert not self.__branch_from
 
         if self._has_events or self.engine._has_events:
-            self.dispatch.release_savepoint(self, name, context)
+            self.dispatch.release_savepoint(self, name, None)
 
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             self.engine.dialect.do_release_savepoint(self, name)
-        self._transaction = context
 
     def _begin_twophase_impl(self, transaction):
         assert not self.__branch_from
@@ -858,11 +840,14 @@ class Connection(Connectable):
         if self._has_events or self.engine._has_events:
             self.dispatch.begin_twophase(self, transaction.xid)
 
-        if self._still_open_and_connection_is_valid:
-            self.engine.dialect.do_begin_twophase(self, transaction.xid)
-
-            if not self._is_future and self.connection._reset_agent is None:
-                self.connection._reset_agent = transaction
+        if self._still_open_and_dbapi_connection_is_valid:
+            self.__in_begin = True
+            try:
+                self.engine.dialect.do_begin_twophase(self, transaction.xid)
+            except BaseException as e:
+                self._handle_dbapi_exception(e, None, None, None, None)
+            finally:
+                self.__in_begin = False
 
     def _prepare_twophase_impl(self, xid):
         assert not self.__branch_from
@@ -870,9 +855,12 @@ class Connection(Connectable):
         if self._has_events or self.engine._has_events:
             self.dispatch.prepare_twophase(self, xid)
 
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             assert isinstance(self._transaction, TwoPhaseTransaction)
-            self.engine.dialect.do_prepare_twophase(self, xid)
+            try:
+                self.engine.dialect.do_prepare_twophase(self, xid)
+            except BaseException as e:
+                self._handle_dbapi_exception(e, None, None, None, None)
 
     def _rollback_twophase_impl(self, xid, is_prepared):
         assert not self.__branch_from
@@ -880,18 +868,14 @@ class Connection(Connectable):
         if self._has_events or self.engine._has_events:
             self.dispatch.rollback_twophase(self, xid, is_prepared)
 
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             assert isinstance(self._transaction, TwoPhaseTransaction)
             try:
                 self.engine.dialect.do_rollback_twophase(
                     self, xid, is_prepared
                 )
-            finally:
-                if self.connection._reset_agent is self._transaction:
-                    self.connection._reset_agent = None
-                self._transaction = None
-        else:
-            self._transaction = None
+            except BaseException as e:
+                self._handle_dbapi_exception(e, None, None, None, None)
 
     def _commit_twophase_impl(self, xid, is_prepared):
         assert not self.__branch_from
@@ -899,25 +883,19 @@ class Connection(Connectable):
         if self._has_events or self.engine._has_events:
             self.dispatch.commit_twophase(self, xid, is_prepared)
 
-        if self._still_open_and_connection_is_valid:
+        if self._still_open_and_dbapi_connection_is_valid:
             assert isinstance(self._transaction, TwoPhaseTransaction)
             try:
                 self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
-            finally:
-                if self.connection._reset_agent is self._transaction:
-                    self.connection._reset_agent = None
-                self._transaction = None
-        else:
-            self._transaction = None
-
-    def _autobegin(self):
-        assert self._is_future
-
-        return self.begin()
+            except BaseException as e:
+                self._handle_dbapi_exception(e, None, None, None, None)
 
     def _autorollback(self):
-        if not self._root.in_transaction():
-            self._root._rollback_impl()
+        if self.__branch_from:
+            self.__branch_from._autorollback()
+
+        if not self.in_transaction():
+            self._rollback_impl()
 
     def close(self):
         """Close this :class:`_engine.Connection`.
@@ -938,40 +916,34 @@ class Connection(Connectable):
         and will allow no further operations.
 
         """
-        assert not self._is_future
 
         if self.__branch_from:
+            assert not self._is_future
             util.warn_deprecated_20(
                 "The .close() method on a so-called 'branched' connection is "
                 "deprecated as of 1.4, as are 'branched' connections overall, "
                 "and will be removed in a future release.  If this is a "
                 "default-handling function, don't close the connection."
             )
+            self._dbapi_connection = None
+            self.__can_reconnect = False
+            return
 
-            try:
-                del self.__connection
-            except AttributeError:
-                pass
-            finally:
-                self.__can_reconnect = False
-                return
-        try:
-            conn = self.__connection
-        except AttributeError:
-            pass
-        else:
+        if self._transaction:
+            self._transaction.close()
 
+        if self._dbapi_connection is not None:
+            conn = self._dbapi_connection
             conn.close()
             if conn._reset_agent is self._transaction:
                 conn._reset_agent = None
 
-            # the close() process can end up invalidating us,
-            # as the pool will call our transaction as the "reset_agent"
-            # for rollback(), which can then cause an invalidation
-            if not self.__invalid:
-                del self.__connection
+            # There is a slight chance that conn.close() may have
+            # triggered an invalidation here in which case
+            # _dbapi_connection would already be None, however usually
+            # it will be non-None here and in a "closed" state.
+            self._dbapi_connection = None
         self.__can_reconnect = False
-        self._transaction = None
 
     def scalar(self, object_, *multiparams, **params):
         """Executes and returns the first column of the first row.
@@ -1100,12 +1072,7 @@ class Connection(Connectable):
                 )
 
         try:
-            try:
-                conn = self.__connection
-            except AttributeError:
-                # escape "except AttributeError" before revalidating
-                # to prevent misleading stacktraces in Py3K
-                conn = None
+            conn = self._dbapi_connection
             if conn is None:
                 conn = self._revalidate_connection()
 
@@ -1113,6 +1080,8 @@ class Connection(Connectable):
             ctx = dialect.execution_ctx_cls._init_default(
                 dialect, self, conn, execution_options
             )
+        except (exc.PendingRollbackError, exc.ResourceClosedError):
+            raise
         except BaseException as e:
             self._handle_dbapi_exception(e, None, None, None, None)
 
@@ -1388,41 +1357,43 @@ class Connection(Connectable):
         """Create an :class:`.ExecutionContext` and execute, returning
         a :class:`_engine.CursorResult`."""
 
+        branched = self
+        if self.__branch_from:
+            # if this is a "branched" connection, do everything in terms
+            # of the "root" connection, *except* for .close(), which is
+            # the only feature that branching provides
+            self = self.__branch_from
+
         if execution_options:
             dialect.set_exec_execution_options(self, execution_options)
 
         try:
-            try:
-                conn = self.__connection
-            except AttributeError:
-                # escape "except AttributeError" before revalidating
-                # to prevent misleading stacktraces in Py3K
-                conn = None
+            conn = self._dbapi_connection
             if conn is None:
                 conn = self._revalidate_connection()
 
             context = constructor(
                 dialect, self, conn, execution_options, *args
             )
+        except (exc.PendingRollbackError, exc.ResourceClosedError):
+            raise
         except BaseException as e:
             self._handle_dbapi_exception(
                 e, util.text_type(statement), parameters, None, None
             )
 
-        if self._root._transaction and not self._root._transaction.is_active:
-            raise exc.InvalidRequestError(
-                "This connection is on an inactive %stransaction.  "
-                "Please rollback() fully before proceeding."
-                % (
-                    "savepoint "
-                    if isinstance(self._transaction, NestedTransaction)
-                    else ""
-                ),
-                code="8s2a",
+        if (
+            self._transaction
+            and not self._transaction.is_active
+            or (
+                self._nested_transaction
+                and not self._nested_transaction.is_active
             )
+        ):
+            self._invalid_transaction()
 
-        if self._is_future and self._root._transaction is None:
-            self._autobegin()
+        if self._is_future and self._transaction is None:
+            self.begin()
 
         if context.compiled:
             context.pre_exec()
@@ -1512,20 +1483,21 @@ class Connection(Connectable):
             if (
                 not self._is_future
                 and context.should_autocommit
-                and self._root._transaction is None
+                and self._transaction is None
             ):
-                self._root._commit_impl(autocommit=True)
+                self._commit_impl(autocommit=True)
 
             # for "connectionless" execution, we have to close this
             # Connection after the statement is complete.
-            if self.should_close_with_result:
+            if branched.should_close_with_result:
                 assert not self._is_future
                 assert not context._is_future_result
 
                 # CursorResult already exhausted rows / has no rows.
-                # close us now
+                # close us now.  note this is where we call .close()
+                # on the "branched" connection if we're doing that.
                 if result._soft_closed:
-                    self.close()
+                    branched.close()
                 else:
                     # CursorResult will close this Connection when no more
                     # rows to fetch.
@@ -1606,7 +1578,7 @@ class Connection(Connectable):
                 and not self.closed
                 and self.dialect.is_disconnect(
                     e,
-                    self.__connection if not self.invalidated else None,
+                    self._dbapi_connection if not self.invalidated else None,
                     cursor,
                 )
             ) or (is_exit_exception and not self.closed)
@@ -1723,7 +1695,7 @@ class Connection(Connectable):
             if self._is_disconnect:
                 del self._is_disconnect
                 if not self.invalidated:
-                    dbapi_conn_wrapper = self.__connection
+                    dbapi_conn_wrapper = self._dbapi_connection
                     if invalidate_pool_on_disconnect:
                         self.engine.pool._invalidate(dbapi_conn_wrapper, e)
                     self.invalidate(e)
@@ -1946,19 +1918,42 @@ class Transaction(object):
       single: thread safety; Transaction
     """
 
+    __slots__ = ()
+
     _is_root = False
 
-    def __init__(self, connection, parent):
-        self.connection = connection
-        self._actual_parent = parent
-        self.is_active = True
+    def __init__(self, connection):
+        raise NotImplementedError()
 
-    def _deactivate(self):
-        self.is_active = False
+    def _do_deactivate(self):
+        """do whatever steps are necessary to set this transaction as
+        "deactive", however leave this transaction object in place as far
+        as the connection's state.
+
+        for a "real" transaction this should roll back the transction
+        and ensure this transaction is no longer a reset agent.
+
+        this is used for nesting of marker transactions where the marker
+        can set the "real" transaction as rolled back, however it stays
+        in place.
+
+        for 2.0 we hope to remove this nesting feature.
+
+        """
+        raise NotImplementedError()
+
+    def _do_close(self):
+        raise NotImplementedError()
+
+    def _do_rollback(self):
+        raise NotImplementedError()
+
+    def _do_commit(self):
+        raise NotImplementedError()
 
     @property
-    def _parent(self):
-        return self._actual_parent or self
+    def is_valid(self):
+        return self.is_active and not self.connection.invalidated
 
     def close(self):
         """Close this :class:`.Transaction`.
@@ -1971,34 +1966,27 @@ class Transaction(object):
         an enclosing transaction.
 
         """
-
-        if self._parent.is_active and self._parent is self:
-            self.rollback()
-        self.connection._discard_transaction(self)
+        try:
+            self._do_close()
+        finally:
+            assert not self.is_active
 
     def rollback(self):
         """Roll back this :class:`.Transaction`.
 
         """
-
-        if self._parent.is_active:
+        try:
             self._do_rollback()
-            self.is_active = False
-        self.connection._discard_transaction(self)
-
-    def _do_rollback(self):
-        self._parent._deactivate()
+        finally:
+            assert not self.is_active
 
     def commit(self):
         """Commit this :class:`.Transaction`."""
 
-        if not self._parent.is_active:
-            raise exc.InvalidRequestError("This transaction is inactive")
-        self._do_commit()
-        self.is_active = False
-
-    def _do_commit(self):
-        pass
+        try:
+            self._do_commit()
+        finally:
+            assert not self.is_active
 
     def __enter__(self):
         return self
@@ -2014,24 +2002,172 @@ class Transaction(object):
             self.rollback()
 
 
+class MarkerTransaction(Transaction):
+    """A 'marker' transaction that is used for nested begin() calls.
+
+    .. deprecated:: 1.4 future connection for 2.0 won't support this pattern.
+
+    """
+
+    __slots__ = ("connection", "_is_active", "_transaction")
+
+    def __init__(self, connection):
+        assert connection._transaction is not None
+        if not connection._transaction.is_active:
+            raise exc.InvalidRequestError(
+                "the current transaction on this connection is inactive.  "
+                "Please issue a rollback first."
+            )
+
+        self.connection = connection
+        if connection._nested_transaction is not None:
+            self._transaction = connection._nested_transaction
+        else:
+            self._transaction = connection._transaction
+        self._is_active = True
+
+    @property
+    def is_active(self):
+        return self._is_active and self._transaction.is_active
+
+    def _deactivate(self):
+        self._is_active = False
+
+    def _do_close(self):
+        # does not actually roll back the root
+        self._deactivate()
+
+    def _do_rollback(self):
+        # does roll back the root
+        if self._is_active:
+            try:
+                self._transaction._do_deactivate()
+            finally:
+                self._deactivate()
+
+    def _do_commit(self):
+        self._deactivate()
+
+
 class RootTransaction(Transaction):
     _is_root = True
 
+    __slots__ = ("connection", "is_active")
+
     def __init__(self, connection):
-        super(RootTransaction, self).__init__(connection, None)
-        self.connection._begin_impl(self)
+        assert connection._transaction is None
+        self.connection = connection
+        self._connection_begin_impl()
+        connection._transaction = self
 
-    def _deactivate(self):
-        self._do_rollback(deactivate_only=True)
-        self.is_active = False
+        self.is_active = True
+
+        # the SingletonThreadPool used with sqlite memory can share the same
+        # DBAPI connection / fairy among multiple Connection objects.  while
+        # this is not ideal, it is a still-supported use case which at the
+        # moment occurs in the test suite due to how some of pytest fixtures
+        # work out
+        if connection._dbapi_connection._reset_agent is None:
+            connection._dbapi_connection._reset_agent = self
 
-    def _do_rollback(self, deactivate_only=False):
+    def _deactivate_from_connection(self):
         if self.is_active:
-            self.connection._rollback_impl(deactivate_only=deactivate_only)
+            assert self.connection._transaction is self
+            self.is_active = False
+
+            if (
+                self.connection._dbapi_connection is not None
+                and self.connection._dbapi_connection._reset_agent is self
+            ):
+                self.connection._dbapi_connection._reset_agent = None
+
+        # we have tests that want to make sure the pool handles this
+        # correctly.  TODO: how to disable internal assertions cleanly?
+        # else:
+        #    if self.connection._dbapi_connection is not None:
+        #        assert (
+        #            self.connection._dbapi_connection._reset_agent is not self
+        #        )
+
+    def _do_deactivate(self):
+        # called from a MarkerTransaction to cancel this root transaction.
+        # the transaction stays in place as connection._transaction, but
+        # is no longer active and is no longer the reset agent for the
+        # pooled connection.   the connection won't support a new begin()
+        # until this transaction is explicitly closed, rolled back,
+        # or committed.
+
+        assert self.connection._transaction is self
+
+        if self.is_active:
+            self._connection_rollback_impl()
+
+        # handle case where a savepoint was created inside of a marker
+        # transaction that refers to a root.  nested has to be cancelled
+        # also.
+        if self.connection._nested_transaction:
+            self.connection._nested_transaction._cancel()
+
+        self._deactivate_from_connection()
+
+    def _connection_begin_impl(self):
+        self.connection._begin_impl(self)
+
+    def _connection_rollback_impl(self):
+        self.connection._rollback_impl()
+
+    def _connection_commit_impl(self):
+        self.connection._commit_impl()
+
+    def _close_impl(self):
+        try:
+            if self.is_active:
+                self._connection_rollback_impl()
+
+            if self.connection._nested_transaction:
+                self.connection._nested_transaction._cancel()
+        finally:
+            if self.is_active:
+                self._deactivate_from_connection()
+            if self.connection._transaction is self:
+                self.connection._transaction = None
+
+        assert not self.is_active
+        assert self.connection._transaction is not self
+
+    def _do_close(self):
+        self._close_impl()
+
+    def _do_rollback(self):
+        self._close_impl()
 
     def _do_commit(self):
         if self.is_active:
-            self.connection._commit_impl()
+            assert self.connection._transaction is self
+
+            try:
+                self._connection_commit_impl()
+            finally:
+                # whether or not commit succeeds, cancel any
+                # nested transactions, make this transaction "inactive"
+                # and remove it as a reset agent
+                if self.connection._nested_transaction:
+                    self.connection._nested_transaction._cancel()
+
+                self._deactivate_from_connection()
+
+            # ...however only remove as the connection's current transaction
+            # if commit succeeded.  otherwise it stays on so that a rollback
+            # needs to occur.
+            self.connection._transaction = None
+        else:
+            if self.connection._transaction is self:
+                self.connection._invalid_transaction()
+            else:
+                raise exc.InvalidRequestError("This transaction is inactive")
+
+        assert not self.is_active
+        assert self.connection._transaction is not self
 
 
 class NestedTransaction(Transaction):
@@ -2044,28 +2180,73 @@ class NestedTransaction(Transaction):
 
     """
 
-    def __init__(self, connection, parent):
-        super(NestedTransaction, self).__init__(connection, parent)
+    __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested")
+
+    def __init__(self, connection):
+        assert connection._transaction is not None
+        self.connection = connection
         self._savepoint = self.connection._savepoint_impl()
+        self.is_active = True
+        self._previous_nested = connection._nested_transaction
+        connection._nested_transaction = self
 
-    def _deactivate(self):
-        self._do_rollback(deactivate_only=True)
+    def _deactivate_from_connection(self):
+        if self.connection._nested_transaction is self:
+            self.connection._nested_transaction = self._previous_nested
+        else:
+            util.warn(
+                "nested transaction already deassociated from connection"
+            )
+
+    def _cancel(self):
+        # called by RootTransaction when the outer transaction is
+        # committed, rolled back, or closed to cancel all savepoints
+        # without any action being taken
         self.is_active = False
+        self._deactivate_from_connection()
+        if self._previous_nested:
+            self._previous_nested._cancel()
 
-    def _do_rollback(self, deactivate_only=False):
-        if self.is_active:
-            self.connection._rollback_to_savepoint_impl(
-                self._savepoint, self._parent
-            )
+    def _close_impl(self, deactivate_from_connection):
+        try:
+            if self.is_active and self.connection._transaction.is_active:
+                self.connection._rollback_to_savepoint_impl(self._savepoint)
+        finally:
+            self.is_active = False
+            if deactivate_from_connection:
+                self._deactivate_from_connection()
+
+    def _do_deactivate(self):
+        self._close_impl(False)
+
+    def _do_close(self):
+        self._close_impl(True)
+
+    def _do_rollback(self):
+        self._close_impl(True)
 
     def _do_commit(self):
         if self.is_active:
-            self.connection._release_savepoint_impl(
-                self._savepoint, self._parent
-            )
+            try:
+                self.connection._release_savepoint_impl(self._savepoint)
+            finally:
+                # nested trans becomes inactive on failed release
+                # unconditionally.  this prevents it from trying to
+                # emit SQL when it rolls back.
+                self.is_active = False
+
+            # but only de-associate from connection if it succeeded
+            self._deactivate_from_connection()
+        else:
+            if self.connection._nested_transaction is self:
+                self.connection._invalid_transaction()
+            else:
+                raise exc.InvalidRequestError(
+                    "This nested transaction is inactive"
+                )
 
 
-class TwoPhaseTransaction(Transaction):
+class TwoPhaseTransaction(RootTransaction):
     """Represent a two-phase transaction.
 
     A new :class:`.TwoPhaseTransaction` object may be procured
@@ -2076,11 +2257,12 @@ class TwoPhaseTransaction(Transaction):
 
     """
 
+    __slots__ = ("connection", "is_active", "xid", "_is_prepared")
+
     def __init__(self, connection, xid):
-        super(TwoPhaseTransaction, self).__init__(connection, None)
         self._is_prepared = False
         self.xid = xid
-        self.connection._begin_twophase_impl(self)
+        super(TwoPhaseTransaction, self).__init__(connection)
 
     def prepare(self):
         """Prepare this :class:`.TwoPhaseTransaction`.
@@ -2088,15 +2270,18 @@ class TwoPhaseTransaction(Transaction):
         After a PREPARE, the transaction can be committed.
 
         """
-        if not self._parent.is_active:
+        if not self.is_active:
             raise exc.InvalidRequestError("This transaction is inactive")
         self.connection._prepare_twophase_impl(self.xid)
         self._is_prepared = True
 
-    def _do_rollback(self):
+    def _connection_begin_impl(self):
+        self.connection._begin_twophase_impl(self)
+
+    def _connection_rollback_impl(self):
         self.connection._rollback_twophase_impl(self.xid, self._is_prepared)
 
-    def _do_commit(self):
+    def _connection_commit_impl(self):
         self.connection._commit_twophase_impl(self.xid, self._is_prepared)
 
 
index 55462f0bfd81d4fe666bc67b202b523e8ca7c3d7..a886d2025856b7d8ebc7e9376db24d5b140b2685 100644 (file)
@@ -96,8 +96,15 @@ class CursorResultMetaData(ResultMetaData):
                 }
             )
 
+        # TODO: need unit test for:
+        # result = connection.execute("raw sql, no columns").scalars()
+        # without the "or ()" it's failing because MD_OBJECTS is None
         new_metadata._keymap.update(
-            {e: new_rec for new_rec in new_recs for e in new_rec[MD_OBJECTS]}
+            {
+                e: new_rec
+                for new_rec in new_recs
+                for e in new_rec[MD_OBJECTS] or ()
+            }
         )
 
         return new_metadata
index af271c56c06f2180d901f0e8711f902625c1b937..759f7f2bd910369d59073d73781dfd41f7626454 100644 (file)
@@ -640,18 +640,20 @@ class ConnectionEvents(event.Events):
 
         :param conn: :class:`_engine.Connection` object
         :param name: specified name used for the savepoint.
-        :param context: :class:`.ExecutionContext` in use.  May be ``None``.
+        :param context: not used
 
         """
+        # TODO: deprecate "context"
 
     def release_savepoint(self, conn, name, context):
         """Intercept release_savepoint() events.
 
         :param conn: :class:`_engine.Connection` object
         :param name: specified name used for the savepoint.
-        :param context: :class:`.ExecutionContext` in use.  May be ``None``.
+        :param context: not used
 
         """
+        # TODO: deprecate "context"
 
     def begin_twophase(self, conn, xid):
         """Intercept begin_twophase() events.
index 94cc25eabfb795cbfa5b162dbeb72b329151b894..92322fb903862f6e8a0d66f6df79d67def94d5ec 100644 (file)
@@ -225,6 +225,15 @@ class NoInspectionAvailable(InvalidRequestError):
     no context for inspection."""
 
 
+class PendingRollbackError(InvalidRequestError):
+    """A transaction has failed and needs to be rolled back before
+    continuing.
+
+    .. versionadded:: 1.4
+
+    """
+
+
 class ResourceClosedError(InvalidRequestError):
     """An operation was requested from a connection, cursor, or other
     object that's in a closed state."""
index b96716978e8a55707796564c543672f6370402be..d3b13b51077df1abe20a17d1a5a1f7c0636abdef 100644 (file)
@@ -249,25 +249,7 @@ class Connection(_LegacyConnection):
         if any transaction is in place.
 
         """
-
-        try:
-            conn = self.__connection
-        except AttributeError:
-            pass
-        else:
-            # TODO: can we do away with "_reset_agent" stuff now?
-            if self._transaction:
-                self._transaction.rollback()
-
-            conn.close()
-
-            # the close() process can end up invalidating us,
-            # as the pool will call our transaction as the "reset_agent"
-            # for rollback(), which can then cause an invalidation
-            if not self.__invalid:
-                del self.__connection
-        self.__can_reconnect = False
-        self._transaction = None
+        super(Connection, self).close()
 
     def execute(self, statement, parameters=None, execution_options=None):
         r"""Executes a SQL statement construct and returns a
index b053b8d9602f5fa887a075146554499d483e89e5..450e5d02394d6483ba0c332924a20e5cc767c856 100644 (file)
@@ -291,7 +291,7 @@ class SessionTransaction(object):
         elif self._state is DEACTIVE:
             if not deactive_ok and not rollback_ok:
                 if self._rollback_exception:
-                    raise sa_exc.InvalidRequestError(
+                    raise sa_exc.PendingRollbackError(
                         "This Session's transaction has been rolled back "
                         "due to a previous exception during flush."
                         " To begin a new transaction with this Session, "
index 05dcf230b243acac0b51ce5b90f0061bc2a6f24b..87e5ba0d2912396d72854e44797bf2a3a2568a3d 100644 (file)
@@ -285,13 +285,11 @@ def _assert_proper_exception_context(exception):
 
 
 def assert_raises(except_cls, callable_, *args, **kw):
-    _assert_raises(except_cls, callable_, args, kw, check_context=True)
+    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
 
 
 def assert_raises_context_ok(except_cls, callable_, *args, **kw):
-    _assert_raises(
-        except_cls, callable_, args, kw,
-    )
+    return _assert_raises(except_cls, callable_, args, kw,)
 
 
 def assert_raises_return(except_cls, callable_, *args, **kw):
@@ -299,7 +297,7 @@ def assert_raises_return(except_cls, callable_, *args, **kw):
 
 
 def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
-    _assert_raises(
+    return _assert_raises(
         except_cls, callable_, args, kwargs, msg=msg, check_context=True
     )
 
@@ -307,7 +305,7 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
 def assert_raises_message_context_ok(
     except_cls, msg, callable_, *args, **kwargs
 ):
-    _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+    return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
 
 
 def _assert_raises(
index 9b2f6911da85a482e498da7cb6c1989b095d9b1b..f7d0dd3ea45370f00523a74640992f323d11b490 100644 (file)
@@ -412,6 +412,17 @@ def _prep_testing_database(options, file_config):
     if options.dropfirst:
         for cfg in config.Config.all_configs():
             e = cfg.db
+
+            # TODO: this has to be part of provision.py in postgresql
+            if against(cfg, "postgresql"):
+                with e.connect().execution_options(
+                    isolation_level="AUTOCOMMIT"
+                ) as conn:
+                    for xid in conn.execute(
+                        "select gid from pg_prepared_xacts"
+                    ).scalars():
+                        conn.execute("ROLLBACK PREPARED '%s'" % xid)
+
             inspector = inspect(e)
             try:
                 view_names = inspector.get_view_names()
@@ -447,6 +458,7 @@ def _prep_testing_database(options, file_config):
             if config.requirements.schemas.enabled_for_config(cfg):
                 util.drop_all_tables(e, inspector, schema=cfg.test_schema)
 
+            # TODO: this has to be part of provision.py in postgresql
             if against(cfg, "postgresql"):
                 from sqlalchemy.dialects import postgresql
 
index a09b047481ca1710e8548ffa345d0fbce100092d..f0d0a9b2fba7f09196bdc61793bb9ef5a384f90d 100644 (file)
@@ -103,8 +103,21 @@ def mock_connection():
         else:
             return
 
+    def commit():
+        if conn.explode == "commit":
+            raise MockDisconnect("Lost the DB connection on commit")
+        elif conn.explode == "commit_no_disconnect":
+            raise MockError(
+                "something broke on commit but we didn't lose the "
+                "connection"
+            )
+        else:
+            return
+
     conn = Mock(
-        rollback=Mock(side_effect=rollback), cursor=Mock(side_effect=cursor())
+        rollback=Mock(side_effect=rollback),
+        commit=Mock(side_effect=commit),
+        cursor=Mock(side_effect=cursor()),
     )
     return conn
 
@@ -420,7 +433,7 @@ class MockReconnectTest(fixtures.TestBase):
             [[call()], [call()], []],
         )
 
-    def test_invalidate_trans(self):
+    def test_invalidate_on_execute_trans(self):
         conn = self.db.connect()
         trans = conn.begin()
         self.dbapi.shutdown()
@@ -432,7 +445,7 @@ class MockReconnectTest(fixtures.TestBase):
         assert conn.invalidated
         assert trans.is_active
         assert_raises_message(
-            tsa.exc.StatementError,
+            tsa.exc.PendingRollbackError,
             "Can't reconnect until invalid transaction is rolled back",
             conn.execute,
             select([1]),
@@ -440,12 +453,30 @@ class MockReconnectTest(fixtures.TestBase):
         assert trans.is_active
 
         assert_raises_message(
-            tsa.exc.InvalidRequestError,
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            trans.commit,
+        )
+
+        # now it's inactive...
+        assert not trans.is_active
+
+        # but still associated with the connection
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            conn.execute,
+            select([1]),
+        )
+        assert not trans.is_active
+
+        # still can't commit... error stays the same
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
             "Can't reconnect until invalid transaction is rolled back",
             trans.commit,
         )
 
-        assert trans.is_active
         trans.rollback()
         assert not trans.is_active
         conn.execute(select([1]))
@@ -455,6 +486,104 @@ class MockReconnectTest(fixtures.TestBase):
             [[call()], []],
         )
 
+    def test_invalidate_on_commit_trans(self):
+        conn = self.db.connect()
+        trans = conn.begin()
+        self.dbapi.shutdown("commit")
+
+        assert_raises(tsa.exc.DBAPIError, trans.commit)
+
+        assert not conn.closed
+        assert conn.invalidated
+        assert not trans.is_active
+
+        # error stays consistent
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            conn.execute,
+            select([1]),
+        )
+        assert not trans.is_active
+
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            trans.commit,
+        )
+
+        assert not trans.is_active
+
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            conn.execute,
+            select([1]),
+        )
+        assert not trans.is_active
+
+        trans.rollback()
+        assert not trans.is_active
+        conn.execute(select([1]))
+        assert not conn.invalidated
+
+    def test_commit_fails_contextmanager(self):
+        # this test is also performed in test/engine/test_transaction.py
+        # using real connections
+        conn = self.db.connect()
+
+        def go():
+            with conn.begin():
+                self.dbapi.shutdown("commit_no_disconnect")
+
+        assert_raises(tsa.exc.DBAPIError, go)
+
+        assert not conn.in_transaction()
+
+    def test_commit_fails_trans(self):
+        # this test is also performed in test/engine/test_transaction.py
+        # using real connections
+
+        conn = self.db.connect()
+        trans = conn.begin()
+        self.dbapi.shutdown("commit_no_disconnect")
+
+        assert_raises(tsa.exc.DBAPIError, trans.commit)
+
+        assert not conn.closed
+        assert not conn.invalidated
+        assert not trans.is_active
+
+        # error stays consistent
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "This connection is on an inactive transaction.  Please rollback",
+            conn.execute,
+            select([1]),
+        )
+        assert not trans.is_active
+
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "This connection is on an inactive transaction.  Please rollback",
+            trans.commit,
+        )
+
+        assert not trans.is_active
+
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "This connection is on an inactive transaction.  Please rollback",
+            conn.execute,
+            select([1]),
+        )
+        assert not trans.is_active
+
+        trans.rollback()
+        assert not trans.is_active
+        conn.execute(select([1]))
+        assert not conn.invalidated
+
     def test_invalidate_dont_call_finalizer(self):
         conn = self.db.connect()
         finalizer = mock.Mock()
@@ -497,9 +626,9 @@ class MockReconnectTest(fixtures.TestBase):
 
         conn.close()
         assert conn.closed
-        assert conn.invalidated
+        assert not conn.invalidated
         assert_raises_message(
-            tsa.exc.StatementError,
+            tsa.exc.ResourceClosedError,
             "This Connection is closed",
             conn.execute,
             select([1]),
@@ -544,7 +673,7 @@ class MockReconnectTest(fixtures.TestBase):
         assert not conn.invalidated
 
         assert_raises_message(
-            tsa.exc.StatementError,
+            tsa.exc.ResourceClosedError,
             "This Connection is closed",
             conn.execute,
             select([1]),
@@ -594,10 +723,10 @@ class MockReconnectTest(fixtures.TestBase):
             )
 
         assert conn.closed
-        assert conn.invalidated
+        assert not conn.invalidated
 
         assert_raises_message(
-            tsa.exc.StatementError,
+            tsa.exc.ResourceClosedError,
             "This Connection is closed",
             conn.execute,
             select([1]),
@@ -955,7 +1084,7 @@ class RealReconnectTest(fixtures.TestBase):
 
         _assert_invalidated(c1_branch.execute, select([1]))
         assert not c1_branch.closed
-        assert not c1_branch._connection_is_valid
+        assert not c1_branch._still_open_and_dbapi_connection_is_valid
 
     def test_ensure_is_disconnect_gets_connection(self):
         def is_disconnect(e, conn, cursor):
@@ -1062,6 +1191,7 @@ class RealReconnectTest(fixtures.TestBase):
     def test_with_transaction(self):
         conn = self.engine.connect()
         trans = conn.begin()
+        assert trans.is_valid
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
         self.engine.test_shutdown()
@@ -1069,21 +1199,56 @@ class RealReconnectTest(fixtures.TestBase):
         assert not conn.closed
         assert conn.invalidated
         assert trans.is_active
+        assert not trans.is_valid
+
         assert_raises_message(
-            tsa.exc.StatementError,
+            tsa.exc.PendingRollbackError,
             "Can't reconnect until invalid transaction is rolled back",
             conn.execute,
             select([1]),
         )
         assert trans.is_active
+        assert not trans.is_valid
+
         assert_raises_message(
-            tsa.exc.InvalidRequestError,
+            tsa.exc.PendingRollbackError,
             "Can't reconnect until invalid transaction is rolled back",
             trans.commit,
         )
-        assert trans.is_active
+
+        # becomes inactive
+        assert not trans.is_active
+        assert not trans.is_valid
+
+        # still asks us to rollback
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            conn.execute,
+            select([1]),
+        )
+
+        # still asks us..
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            trans.commit,
+        )
+
+        # still...it's being consistent in what it is asking.
+        assert_raises_message(
+            tsa.exc.PendingRollbackError,
+            "Can't reconnect until invalid transaction is rolled back",
+            conn.execute,
+            select([1]),
+        )
+
+        #  OK!
         trans.rollback()
         assert not trans.is_active
+        assert not trans.is_valid
+
+        # conn still invalid but we can reconnect
         assert conn.invalidated
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
index fbc1ffd8399c368bcc5a6ef0522f9567be62e183..164604cd650f458dad88060708bb015531a66d30 100644 (file)
@@ -11,8 +11,10 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import util
 from sqlalchemy import VARCHAR
 from sqlalchemy.future import select as future_select
+from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
@@ -49,8 +51,13 @@ class TransactionTest(fixtures.TestBase):
     def teardown_class(cls):
         users.drop(testing.db)
 
-    def test_commits(self):
-        connection = testing.db.connect()
+    @testing.fixture
+    def local_connection(self):
+        with testing.db.connect() as conn:
+            yield conn
+
+    def test_commits(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         transaction.commit()
@@ -66,10 +73,10 @@ class TransactionTest(fixtures.TestBase):
         transaction.commit()
         connection.close()
 
-    def test_rollback(self):
+    def test_rollback(self, local_connection):
         """test a basic rollback"""
 
-        connection = testing.db.connect()
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -77,10 +84,9 @@ class TransactionTest(fixtures.TestBase):
         transaction.rollback()
         result = connection.exec_driver_sql("select * from query_users")
         assert len(result.fetchall()) == 0
-        connection.close()
 
-    def test_raise(self):
-        connection = testing.db.connect()
+    def test_raise(self, local_connection):
+        connection = local_connection
 
         transaction = connection.begin()
         try:
@@ -95,10 +101,9 @@ class TransactionTest(fixtures.TestBase):
 
         result = connection.exec_driver_sql("select * from query_users")
         assert len(result.fetchall()) == 0
-        connection.close()
 
-    def test_nested_rollback(self):
-        connection = testing.db.connect()
+    def test_nested_rollback(self, local_connection):
+        connection = local_connection
         try:
             transaction = connection.begin()
             try:
@@ -129,176 +134,338 @@ class TransactionTest(fixtures.TestBase):
                 transaction.rollback()
                 raise
         except Exception as e:
-            try:
-                # and not "This transaction is inactive"
-                # comment moved here to fix pep8
-                assert str(e) == "uh oh"
-            finally:
-                connection.close()
+            # and not "This transaction is inactive"
+            # comment moved here to fix pep8
+            assert str(e) == "uh oh"
+        else:
+            assert False
 
-    def test_branch_nested_rollback(self):
-        connection = testing.db.connect()
-        try:
-            connection.begin()
-            branched = connection.connect()
-            assert branched.in_transaction()
-            branched.execute(users.insert(), user_id=1, user_name="user1")
-            nested = branched.begin()
-            branched.execute(users.insert(), user_id=2, user_name="user2")
-            nested.rollback()
-            assert not connection.in_transaction()
+    def test_branch_nested_rollback(self, local_connection):
+        connection = local_connection
+        connection.begin()
+        branched = connection.connect()
+        assert branched.in_transaction()
+        branched.execute(users.insert(), user_id=1, user_name="user1")
+        nested = branched.begin()
+        branched.execute(users.insert(), user_id=2, user_name="user2")
+        nested.rollback()
+        assert not connection.in_transaction()
 
-            assert_raises_message(
-                exc.InvalidRequestError,
-                "This connection is on an inactive transaction.  Please",
-                connection.exec_driver_sql,
-                "select 1",
-            )
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "This connection is on an inactive transaction.  Please",
+            connection.exec_driver_sql,
+            "select 1",
+        )
 
-        finally:
-            connection.close()
+    def test_no_marker_on_inactive_trans(self, local_connection):
+        conn = local_connection
+        conn.begin()
 
-    def test_inactive_due_to_subtransaction_no_commit(self):
-        connection = testing.db.connect()
+        mk1 = conn.begin()
+
+        mk1.rollback()
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "the current transaction on this connection is inactive.",
+            conn.begin,
+        )
+
+    @testing.requires.savepoints
+    def test_savepoint_cancelled_by_toplevel_marker(self, local_connection):
+        conn = local_connection
+        trans = conn.begin()
+        conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
+
+        mk1 = conn.begin()
+
+        sp1 = conn.begin_nested()
+        conn.execute(users.insert(), {"user_id": 2, "user_name": "name2"})
+
+        mk1.rollback()
+
+        assert not sp1.is_active
+        assert not trans.is_active
+        assert conn._transaction is trans
+        assert conn._nested_transaction is None
+
+        with testing.db.connect() as conn:
+            eq_(
+                conn.scalar(future_select(func.count(1)).select_from(users)),
+                0,
+            )
+
+    def test_inactive_due_to_subtransaction_no_commit(self, local_connection):
+        connection = local_connection
         trans = connection.begin()
         trans2 = connection.begin()
         trans2.rollback()
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "This connection is on an inactive transaction.  Please rollback",
+            trans.commit,
+        )
+
+        trans.rollback()
+
         assert_raises_message(
             exc.InvalidRequestError,
             "This transaction is inactive",
             trans.commit,
         )
 
-    def test_branch_autorollback(self):
-        connection = testing.db.connect()
-        try:
-            branched = connection.connect()
-            branched.execute(users.insert(), user_id=1, user_name="user1")
-            try:
-                branched.execute(users.insert(), user_id=1, user_name="user1")
-            except exc.DBAPIError:
-                pass
-        finally:
-            connection.close()
+    @testing.requires.savepoints
+    def test_inactive_due_to_subtransaction_on_nested_no_commit(
+        self, local_connection
+    ):
+        connection = local_connection
+        trans = connection.begin()
 
-    def test_branch_orig_rollback(self):
-        connection = testing.db.connect()
-        try:
-            branched = connection.connect()
-            branched.execute(users.insert(), user_id=1, user_name="user1")
-            nested = branched.begin()
-            assert branched.in_transaction()
-            branched.execute(users.insert(), user_id=2, user_name="user2")
-            nested.rollback()
-            eq_(
-                connection.exec_driver_sql(
-                    "select count(*) from query_users"
-                ).scalar(),
-                1,
-            )
+        nested = connection.begin_nested()
 
-        finally:
-            connection.close()
+        trans2 = connection.begin()
+        trans2.rollback()
 
-    def test_branch_autocommit(self):
-        connection = testing.db.connect()
-        try:
-            branched = connection.connect()
-            branched.execute(users.insert(), user_id=1, user_name="user1")
-        finally:
-            connection.close()
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "This connection is on an inactive savepoint transaction.  "
+            "Please rollback",
+            nested.commit,
+        )
+        trans.commit()
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "This nested transaction is inactive",
+            nested.commit,
+        )
+
+    def test_branch_autorollback(self, local_connection):
+        connection = local_connection
+        branched = connection.connect()
+        branched.execute(users.insert(), dict(user_id=1, user_name="user1"))
+        assert_raises(
+            exc.DBAPIError,
+            branched.execute,
+            users.insert(),
+            dict(user_id=1, user_name="user1"),
+        )
+        # can continue w/o issue
+        branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
+
+    def test_branch_orig_rollback(self, local_connection):
+        connection = local_connection
+        branched = connection.connect()
+        branched.execute(users.insert(), dict(user_id=1, user_name="user1"))
+        nested = branched.begin()
+        assert branched.in_transaction()
+        branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
+        nested.rollback()
         eq_(
-            testing.db.execute(
-                text("select count(*) from query_users")
+            connection.exec_driver_sql(
+                "select count(*) from query_users"
             ).scalar(),
             1,
         )
 
-    @testing.requires.savepoints
-    def test_branch_savepoint_rollback(self):
-        connection = testing.db.connect()
-        try:
-            trans = connection.begin()
+    @testing.requires.independent_connections
+    def test_branch_autocommit(self, local_connection):
+        with testing.db.connect() as connection:
             branched = connection.connect()
-            assert branched.in_transaction()
-            branched.execute(users.insert(), user_id=1, user_name="user1")
-            nested = branched.begin_nested()
-            branched.execute(users.insert(), user_id=2, user_name="user2")
-            nested.rollback()
-            assert connection.in_transaction()
-            trans.commit()
-            eq_(
-                connection.exec_driver_sql(
-                    "select count(*) from query_users"
-                ).scalar(),
-                1,
+            branched.execute(
+                users.insert(), dict(user_id=1, user_name="user1")
             )
 
-        finally:
-            connection.close()
+        eq_(
+            local_connection.execute(
+                text("select count(*) from query_users")
+            ).scalar(),
+            1,
+        )
+
+    @testing.requires.savepoints
+    def test_branch_savepoint_rollback(self, local_connection):
+        connection = local_connection
+        trans = connection.begin()
+        branched = connection.connect()
+        assert branched.in_transaction()
+        branched.execute(users.insert(), user_id=1, user_name="user1")
+        nested = branched.begin_nested()
+        branched.execute(users.insert(), user_id=2, user_name="user2")
+        nested.rollback()
+        assert connection.in_transaction()
+        trans.commit()
+        eq_(
+            connection.exec_driver_sql(
+                "select count(*) from query_users"
+            ).scalar(),
+            1,
+        )
 
     @testing.requires.two_phase_transactions
-    def test_branch_twophase_rollback(self):
-        connection = testing.db.connect()
-        try:
-            branched = connection.connect()
-            assert not branched.in_transaction()
-            branched.execute(users.insert(), user_id=1, user_name="user1")
-            nested = branched.begin_twophase()
-            branched.execute(users.insert(), user_id=2, user_name="user2")
-            nested.rollback()
-            assert not connection.in_transaction()
-            eq_(
-                connection.exec_driver_sql(
-                    "select count(*) from query_users"
-                ).scalar(),
-                1,
-            )
+    def test_branch_twophase_rollback(self, local_connection):
+        connection = local_connection
+        branched = connection.connect()
+        assert not branched.in_transaction()
+        branched.execute(users.insert(), user_id=1, user_name="user1")
+        nested = branched.begin_twophase()
+        branched.execute(users.insert(), user_id=2, user_name="user2")
+        nested.rollback()
+        assert not connection.in_transaction()
+        eq_(
+            connection.exec_driver_sql(
+                "select count(*) from query_users"
+            ).scalar(),
+            1,
+        )
+
+    def test_commit_fails_flat(self, local_connection):
+        connection = local_connection
+
+        t1 = connection.begin()
+
+        with mock.patch.object(
+            connection,
+            "_commit_impl",
+            mock.Mock(side_effect=exc.DBAPIError("failure", None, None, None)),
+        ):
+            assert_raises_message(exc.DBAPIError, r"failure", t1.commit)
+
+        assert not t1.is_active
+        t1.rollback()  # no error
+
+    def test_commit_fails_ctxmanager(self, local_connection):
+        connection = local_connection
+
+        transaction = [None]
+
+        def go():
+            with mock.patch.object(
+                connection,
+                "_commit_impl",
+                mock.Mock(
+                    side_effect=exc.DBAPIError("failure", None, None, None)
+                ),
+            ):
+                with connection.begin() as t1:
+                    transaction[0] = t1
+
+        assert_raises_message(exc.DBAPIError, r"failure", go)
+
+        t1 = transaction[0]
+        assert not t1.is_active
+        t1.rollback()  # no error
+
+    @testing.requires.savepoints_w_release
+    def test_savepoint_rollback_fails_flat(self, local_connection):
+        connection = local_connection
+        t1 = connection.begin()
+
+        s1 = connection.begin_nested()
+
+        # force the "commit" of the savepoint that occurs
+        # when the "with" block fails, e.g.
+        # the RELEASE, to fail, because the savepoint is already
+        # released.
+        connection.dialect.do_release_savepoint(connection, s1._savepoint)
+
+        assert_raises_message(
+            exc.DBAPIError, r".*SQL\:.*ROLLBACK TO SAVEPOINT", s1.rollback
+        )
 
-        finally:
-            connection.close()
+        assert not s1.is_active
+
+        with testing.expect_warnings("nested transaction already"):
+            s1.rollback()  # no error (though it warns)
+
+        t1.commit()  # no error
 
-    @testing.requires.python2
     @testing.requires.savepoints_w_release
-    def test_savepoint_release_fails_warning(self):
+    def test_savepoint_release_fails_flat(self):
         with testing.db.connect() as connection:
-            connection.begin()
+            t1 = connection.begin()
 
-            with expect_warnings(
-                "An exception has occurred during handling of a previous "
-                "exception.  The previous exception "
-                r"is:.*..SQL\:.*RELEASE SAVEPOINT"
-            ):
+            s1 = connection.begin_nested()
+
+            # force the "commit" of the savepoint that occurs
+            # when the "with" block fails, e.g.
+            # the RELEASE, to fail, because the savepoint is already
+            # released.
+            connection.dialect.do_release_savepoint(connection, s1._savepoint)
+
+            assert_raises_message(
+                exc.DBAPIError, r".*SQL\:.*RELEASE SAVEPOINT", s1.commit
+            )
 
-                def go():
-                    with connection.begin_nested() as savepoint:
-                        connection.dialect.do_release_savepoint(
-                            connection, savepoint._savepoint
-                        )
+            assert not s1.is_active
+            s1.rollback()  # no error.  prior to 1.4 this would try to rollback
 
-                assert_raises_message(
-                    exc.DBAPIError, r".*SQL\:.*ROLLBACK TO SAVEPOINT", go
+            t1.commit()  # no error
+
+    @testing.requires.savepoints_w_release
+    def test_savepoint_release_fails_ctxmanager(self, local_connection):
+        connection = local_connection
+        connection.begin()
+
+        savepoint = [None]
+
+        def go():
+
+            with connection.begin_nested() as sp:
+                savepoint[0] = sp
+                # force the "commit" of the savepoint that occurs
+                # when the "with" block fails, e.g.
+                # the RELEASE, to fail, because the savepoint is already
+                # released.
+                connection.dialect.do_release_savepoint(
+                    connection, sp._savepoint
                 )
 
-    def test_retains_through_options(self):
-        connection = testing.db.connect()
-        try:
-            transaction = connection.begin()
-            connection.execute(users.insert(), user_id=1, user_name="user1")
-            conn2 = connection.execution_options(dummy=True)
-            conn2.execute(users.insert(), user_id=2, user_name="user2")
-            transaction.rollback()
-            eq_(
-                connection.exec_driver_sql(
-                    "select count(*) from query_users"
-                ).scalar(),
-                0,
-            )
-        finally:
-            connection.close()
+        # prior to SQLAlchemy 1.4, the above release would fail
+        # and then the savepoint would try to rollback, and that failed
+        # also, causing a long exception chain that under Python 2
+        # was particularly hard to diagnose, leading to issue
+        # #2696 which eventually impacted Openstack, and we
+        # had to add warnings that show what the "context" for an
+        # exception was.   The SQL for the exception was
+        # ROLLBACK TO SAVEPOINT, and up the exception chain would be
+        # the RELEASE failing.
+        #
+        # now, when the savepoint "commit" fails, it sets itself as
+        # inactive.   so it does not try to rollback and it cleans
+        # itself out appropriately.
+        #
+
+        exc_ = assert_raises_message(
+            exc.DBAPIError, r".*SQL\:.*RELEASE SAVEPOINT", go
+        )
+        savepoint = savepoint[0]
+        assert not savepoint.is_active
 
-    def test_nesting(self):
-        connection = testing.db.connect()
+        if util.py3k:
+            # driver error
+            assert exc_.__cause__
+
+            # and that's it, no other context
+            assert not exc_.__cause__.__context__
+
+    def test_retains_through_options(self, local_connection):
+        connection = local_connection
+        transaction = connection.begin()
+        connection.execute(users.insert(), user_id=1, user_name="user1")
+        conn2 = connection.execution_options(dummy=True)
+        conn2.execute(users.insert(), user_id=2, user_name="user2")
+        transaction.rollback()
+        eq_(
+            connection.exec_driver_sql(
+                "select count(*) from query_users"
+            ).scalar(),
+            0,
+        )
+
+    def test_nesting(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -316,10 +483,9 @@ class TransactionTest(fixtures.TestBase):
         )
         result = connection.exec_driver_sql("select * from query_users")
         assert len(result.fetchall()) == 0
-        connection.close()
 
-    def test_with_interface(self):
-        connection = testing.db.connect()
+    def test_with_interface(self, local_connection):
+        connection = local_connection
         trans = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -346,10 +512,9 @@ class TransactionTest(fixtures.TestBase):
             ).scalar()
             == 1
         )
-        connection.close()
 
-    def test_close(self):
-        connection = testing.db.connect()
+    def test_close(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -370,10 +535,9 @@ class TransactionTest(fixtures.TestBase):
         )
         result = connection.exec_driver_sql("select * from query_users")
         assert len(result.fetchall()) == 5
-        connection.close()
 
-    def test_close2(self):
-        connection = testing.db.connect()
+    def test_close2(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -394,11 +558,10 @@ class TransactionTest(fixtures.TestBase):
         )
         result = connection.exec_driver_sql("select * from query_users")
         assert len(result.fetchall()) == 0
-        connection.close()
 
     @testing.requires.savepoints
-    def test_nested_subtransaction_rollback(self):
-        connection = testing.db.connect()
+    def test_nested_subtransaction_rollback(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         trans2 = connection.begin_nested()
@@ -412,11 +575,10 @@ class TransactionTest(fixtures.TestBase):
             ).fetchall(),
             [(1,), (3,)],
         )
-        connection.close()
 
     @testing.requires.savepoints
-    def test_nested_subtransaction_commit(self):
-        connection = testing.db.connect()
+    def test_nested_subtransaction_commit(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         trans2 = connection.begin_nested()
@@ -430,11 +592,10 @@ class TransactionTest(fixtures.TestBase):
             ).fetchall(),
             [(1,), (2,), (3,)],
         )
-        connection.close()
 
     @testing.requires.savepoints
-    def test_rollback_to_subtransaction(self):
-        connection = testing.db.connect()
+    def test_rollback_to_subtransaction(self, local_connection):
+        connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         trans2 = connection.begin_nested()
@@ -451,6 +612,7 @@ class TransactionTest(fixtures.TestBase):
             "select 1",
         )
         trans2.rollback()
+        assert connection._nested_transaction is None
 
         connection.execute(users.insert(), user_id=4, user_name="user4")
         transaction.commit()
@@ -460,11 +622,10 @@ class TransactionTest(fixtures.TestBase):
             ).fetchall(),
             [(1,), (4,)],
         )
-        connection.close()
 
     @testing.requires.two_phase_transactions
-    def test_two_phase_transaction(self):
-        connection = testing.db.connect()
+    def test_two_phase_transaction(self, local_connection):
+        connection = local_connection
         transaction = connection.begin_twophase()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         transaction.prepare()
@@ -487,7 +648,6 @@ class TransactionTest(fixtures.TestBase):
             ).fetchall(),
             [(1,), (2,)],
         )
-        connection.close()
 
     # PG emergency shutdown:
     # select * from pg_prepared_xacts
@@ -495,12 +655,11 @@ class TransactionTest(fixtures.TestBase):
     # MySQL emergency shutdown:
     # for arg in `mysql -u root -e "xa recover" | cut -c 8-100 |
     #     grep sa`; do mysql -u root -e "xa rollback '$arg'"; done
-    @testing.crashes("mysql", "Crashing on 5.5, not worth it")
     @testing.requires.skip_mysql_on_windows
     @testing.requires.two_phase_transactions
     @testing.requires.savepoints
-    def test_mixed_two_phase_transaction(self):
-        connection = testing.db.connect()
+    def test_mixed_two_phase_transaction(self, local_connection):
+        connection = local_connection
         transaction = connection.begin_twophase()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         transaction2 = connection.begin()
@@ -521,44 +680,46 @@ class TransactionTest(fixtures.TestBase):
             ).fetchall(),
             [(1,), (2,), (5,)],
         )
-        connection.close()
 
     @testing.requires.two_phase_transactions
     @testing.requires.two_phase_recovery
     def test_two_phase_recover(self):
 
-        # MySQL recovery doesn't currently seem to work correctly
-        # Prepared transactions disappear when connections are closed
-        # and even when they aren't it doesn't seem possible to use the
-        # recovery id.
+        # 2020, still can't get this to work w/ modern MySQL or MariaDB.
+        # the XA RECOVER comes back as bytes, OK, convert to string,
+        # XA COMMIT then says Unknown XID. Also, the drivers seem to be
+        # killing off the XID if I use the connection.invalidate() before
+        # trying to access in another connection.    Not really worth it
+        # unless someone wants to step through how mysqlclient / pymysql
+        # support this correctly.
 
         connection = testing.db.connect()
+
         transaction = connection.begin_twophase()
-        connection.execute(users.insert(), user_id=1, user_name="user1")
+        connection.execute(users.insert(), dict(user_id=1, user_name="user1"))
         transaction.prepare()
         connection.invalidate()
 
-        connection2 = testing.db.connect()
-        eq_(
-            connection2.execution_options(autocommit=True)
-            .execute(select([users.c.user_id]).order_by(users.c.user_id))
-            .fetchall(),
-            [],
-        )
-        recoverables = connection2.recover_twophase()
-        assert transaction.xid in recoverables
-        connection2.commit_prepared(transaction.xid, recover=True)
-        eq_(
-            connection2.execute(
-                select([users.c.user_id]).order_by(users.c.user_id)
-            ).fetchall(),
-            [(1,)],
-        )
-        connection2.close()
+        with testing.db.connect() as connection2:
+            eq_(
+                connection2.execution_options(autocommit=True)
+                .execute(select([users.c.user_id]).order_by(users.c.user_id))
+                .fetchall(),
+                [],
+            )
+            recoverables = connection2.recover_twophase()
+            assert transaction.xid in recoverables
+            connection2.commit_prepared(transaction.xid, recover=True)
+            eq_(
+                connection2.execute(
+                    select([users.c.user_id]).order_by(users.c.user_id)
+                ).fetchall(),
+                [(1,)],
+            )
 
     @testing.requires.two_phase_transactions
-    def test_multiple_two_phase(self):
-        conn = testing.db.connect()
+    def test_multiple_two_phase(self, local_connection):
+        conn = local_connection
         xa = conn.begin_twophase()
         conn.execute(users.insert(), user_id=1, user_name="user1")
         xa.prepare()
@@ -578,7 +739,6 @@ class TransactionTest(fixtures.TestBase):
             select([users.c.user_name]).order_by(users.c.user_id)
         )
         eq_(result.fetchall(), [("user1",), ("user4",)])
-        conn.close()
 
     @testing.requires.two_phase_transactions
     def test_reset_rollback_two_phase_no_rollback(self):
@@ -652,7 +812,7 @@ class ResetAgentTest(fixtures.TestBase):
         with expect_warnings("Reset agent is not active"):
             conn.close()
 
-    def test_trans_commit_reset_agent_broken_ensure(self):
+    def test_trans_commit_reset_agent_broken_ensure_pool(self):
         eng = testing_engine(options={"pool_reset_on_return": "commit"})
         conn = eng.connect()
         trans = conn.begin()
@@ -669,8 +829,10 @@ class ResetAgentTest(fixtures.TestBase):
             assert connection.connection._reset_agent is t1
             t2 = connection.begin_nested()
             assert connection.connection._reset_agent is t1
-            assert connection._transaction is t2
+            assert connection._nested_transaction is t2
+            assert connection._transaction is t1
             t2.close()
+            assert connection._nested_transaction is None
             assert connection._transaction is t1
             assert connection.connection._reset_agent is t1
             t1.close()
@@ -684,10 +846,15 @@ class ResetAgentTest(fixtures.TestBase):
             assert connection.connection._reset_agent is t1
             t2 = connection.begin_nested()
             assert connection.connection._reset_agent is t1
-            assert connection._transaction is t2
+            assert connection._nested_transaction is t2
+            assert connection._transaction is t1
 
             assert connection.connection._reset_agent is t1
             t1.close()
+
+            assert connection._nested_transaction is None
+            assert connection._transaction is None
+
             assert connection.connection._reset_agent is None
         assert not t1.is_active
 
@@ -698,19 +865,25 @@ class ResetAgentTest(fixtures.TestBase):
             assert connection.connection._reset_agent is t1
             t2 = connection.begin_nested()
             assert connection.connection._reset_agent is t1
-            assert connection._transaction is t2
+            assert connection._nested_transaction is t2
+            assert connection._transaction is t1
             t2.close()
+            assert connection._nested_transaction is None
             assert connection._transaction is t1
             assert connection.connection._reset_agent is t1
             t1.rollback()
+            assert connection._transaction is None
             assert connection.connection._reset_agent is None
+        assert not t2.is_active
         assert not t1.is_active
 
     @testing.requires.savepoints
     def test_begin_nested_close(self):
         with testing.db.connect() as connection:
             trans = connection.begin_nested()
-            assert connection.connection._reset_agent is trans
+            assert (
+                connection.connection._reset_agent is connection._transaction
+            )
         assert not trans.is_active
 
     @testing.requires.savepoints
@@ -719,7 +892,7 @@ class ResetAgentTest(fixtures.TestBase):
             trans = connection.begin()
             trans2 = connection.begin_nested()
             assert connection.connection._reset_agent is trans
-        assert trans2.is_active  # was never closed
+        assert not trans2.is_active
         assert not trans.is_active
 
     @testing.requires.savepoints
@@ -1177,11 +1350,9 @@ class IsolationLevelTest(fixtures.TestBase):
 
 
 class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
-    """The SQLAlchemy 2.0 Connection ensures its own transaction is rolled
-    back upon close.  Therefore the whole "reset agent" thing can go away.
-    this suite runs through all the reset agent tests to ensure the state
-    of the transaction is maintained while the "reset agent" feature is not
-    needed at all.
+    """Still some debate over if the "reset agent" should apply to the
+    future connection or not.
+
 
     """
 
@@ -1192,7 +1363,8 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
         with testing.db.connect() as connection:
             event.listen(connection, "rollback", canary)
             trans = connection.begin()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
+
         assert not trans.is_active
         eq_(canary.mock_calls, [mock.call(connection)])
 
@@ -1201,7 +1373,7 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
         with testing.db.connect() as connection:
             event.listen(connection, "rollback", canary)
             trans = connection.begin()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans.rollback()
             assert connection.connection._reset_agent is None
         assert not trans.is_active
@@ -1213,7 +1385,7 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             event.listen(connection, "rollback", canary.rollback)
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans.commit()
             assert connection.connection._reset_agent is None
         assert not trans.is_active
@@ -1226,8 +1398,11 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             event.listen(connection, "rollback", canary.rollback)
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin_nested()
-            assert connection.connection._reset_agent is None
-        assert trans.is_active  # it's a savepoint
+            assert (
+                connection.connection._reset_agent is connection._transaction
+            )
+        # it's a savepoint, but root made sure it closed
+        assert not trans.is_active
         eq_(canary.mock_calls, [mock.call.rollback(connection)])
 
     @testing.requires.savepoints
@@ -1238,8 +1413,8 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin()
             trans2 = connection.begin_nested()
-            assert connection.connection._reset_agent is None
-        assert trans2.is_active  # was never closed
+            assert connection.connection._reset_agent is trans
+        assert not trans2.is_active
         assert not trans.is_active
         eq_(canary.mock_calls, [mock.call.rollback(connection)])
 
@@ -1254,15 +1429,15 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin()
             trans2 = connection.begin_nested()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans2.rollback()  # this is not a connection level event
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans.commit()
             assert connection.connection._reset_agent is None
         eq_(
             canary.mock_calls,
             [
-                mock.call.rollback_savepoint(connection, mock.ANY, trans),
+                mock.call.rollback_savepoint(connection, mock.ANY, None),
                 mock.call.commit(connection),
             ],
         )
@@ -1275,9 +1450,9 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin()
             trans2 = connection.begin_nested()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans2.rollback()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans.rollback()
             assert connection.connection._reset_agent is None
         eq_(canary.mock_calls, [mock.call.rollback(connection)])
@@ -1292,7 +1467,7 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             )
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin_twophase()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
         assert not trans.is_active
         eq_(
             canary.mock_calls,
@@ -1307,7 +1482,7 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             event.listen(connection, "commit", canary.commit)
             event.listen(connection, "commit_twophase", canary.commit_twophase)
             trans = connection.begin_twophase()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans.commit()
             assert connection.connection._reset_agent is None
         eq_(
@@ -1325,7 +1500,7 @@ class FutureResetAgentTest(fixtures.FutureEngineMixin, fixtures.TestBase):
             )
             event.listen(connection, "commit", canary.commit)
             trans = connection.begin_twophase()
-            assert connection.connection._reset_agent is None
+            assert connection.connection._reset_agent is trans
             trans.rollback()
             assert connection.connection._reset_agent is None
         eq_(
@@ -1520,7 +1695,7 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
             conn.invalidate()
 
             assert_raises_message(
-                exc.StatementError,
+                exc.PendingRollbackError,
                 "Can't reconnect",
                 conn.execute,
                 select([1]),
@@ -1672,7 +1847,7 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
         with testing.db.begin() as conn:
             conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
 
-            conn.begin_nested()
+            sp1 = conn.begin_nested()
             conn.execute(users.insert(), {"user_id": 2, "user_name": "name2"})
 
             sp2 = conn.begin_nested()
@@ -1680,8 +1855,12 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
 
             sp2.rollback()
 
+            assert not sp2.is_active
+            assert sp1.is_active
             assert conn.in_transaction()
 
+        assert not sp1.is_active
+
         with testing.db.connect() as conn:
             eq_(
                 conn.scalar(future_select(func.count(1)).select_from(users)),
@@ -1721,13 +1900,21 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
             sp1 = conn.begin_nested()
             conn.execute(users.insert(), {"user_id": 2, "user_name": "name2"})
 
+            assert conn._nested_transaction is sp1
+
             sp2 = conn.begin_nested()
             conn.execute(users.insert(), {"user_id": 3, "user_name": "name3"})
 
+            assert conn._nested_transaction is sp2
+
             sp2.commit()
 
+            assert conn._nested_transaction is sp1
+
             sp1.rollback()
 
+            assert conn._nested_transaction is None
+
             assert conn.in_transaction()
 
         with testing.db.connect() as conn:
@@ -1735,3 +1922,33 @@ class FutureTransactionTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
                 conn.scalar(future_select(func.count(1)).select_from(users)),
                 1,
             )
+
+    @testing.requires.savepoints
+    def test_savepoint_seven(self):
+        users = self.tables.users
+
+        conn = testing.db.connect()
+        trans = conn.begin()
+        conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
+
+        sp1 = conn.begin_nested()
+        conn.execute(users.insert(), {"user_id": 2, "user_name": "name2"})
+
+        sp2 = conn.begin_nested()
+        conn.execute(users.insert(), {"user_id": 3, "user_name": "name3"})
+
+        assert conn.in_transaction()
+
+        trans.close()
+
+        assert not sp1.is_active
+        assert not sp2.is_active
+        assert not trans.is_active
+        assert conn._transaction is None
+        assert conn._nested_transaction is None
+
+        with testing.db.connect() as conn:
+            eq_(
+                conn.scalar(future_select(func.count(1)).select_from(users)),
+                0,
+            )
index 78a62199aa439c5acdfc8d527977e1fb724d21b4..22e7363b00b655816861252a1ab8771ed3995a6a 100644 (file)
@@ -367,13 +367,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         sess.add(u)
         sess.flush()
         c1 = sess.connection(User)
+        dbapi_conn = c1.connection
+        assert dbapi_conn.is_valid
 
         sess.invalidate()
-        assert c1.invalidated
+
+        # Connection object is closed
+        assert c1.closed
+
+        # "invalidated" is not part of "closed" state
+        assert not c1.invalidated
+
+        # but the DBAPI conn (really ConnectionFairy)
+        # is invalidated
+        assert not dbapi_conn.is_valid
 
         eq_(sess.query(User).all(), [])
         c2 = sess.connection(User)
         assert not c2.invalidated
+        assert c2.connection.is_valid
 
     def test_subtransaction_on_noautocommit(self):
         User, users = self.classes.User, self.tables.users
@@ -859,7 +871,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         except Exception:
             trans2.rollback(_capture_exception=True)
         assert_raises_message(
-            sa_exc.InvalidRequestError,
+            sa_exc.PendingRollbackError,
             r"This Session's transaction has been rolled back due to a "
             r"previous exception during flush. To begin a new transaction "
             r"with this Session, first issue Session.rollback\(\). "
@@ -1001,7 +1013,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
 
         for i in range(5):
             assert_raises_message(
-                sa_exc.InvalidRequestError,
+                sa_exc.PendingRollbackError,
                 "^This Session's transaction has been "
                 r"rolled back due to a previous exception "
                 "during flush. To "
@@ -1037,7 +1049,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
 
         with expect_warnings(".*during handling of a previous exception.*"):
             session.begin_nested()
-            savepoint = session.connection()._transaction._savepoint
+            savepoint = session.connection()._nested_transaction._savepoint
 
             # force the savepoint to disappear
             session.connection().dialect.do_release_savepoint(
@@ -1708,7 +1720,12 @@ class SavepointTest(_LocalFixture):
         nested_trans._do_commit()
 
         is_(s.transaction, trans)
-        assert_raises(sa_exc.DBAPIError, s.rollback)
+
+        with expect_warnings("nested transaction already deassociated"):
+            # this previously would raise
+            # "savepoint "sa_savepoint_1" does not exist", however as of
+            # #5327 the savepoint already knows it's inactive
+            s.rollback()
 
         assert u1 not in s.new
 
index ed047d790739f345c6bf461629947cbf37f34008..c07717aa8a6f2738dfe9c64e42a8f8d693f27cec 100644 (file)
@@ -720,7 +720,7 @@ class DefaultRequirements(SuiteRequirements):
 
         def pg_prepared_transaction(config):
             if not against(config, "postgresql"):
-                return False
+                return True
 
             with config.db.connect() as conn:
                 try:
@@ -742,20 +742,20 @@ class DefaultRequirements(SuiteRequirements):
                 no_support(
                     "oracle", "two-phase xact not implemented in SQLA/oracle"
                 ),
-                no_support(
-                    "drizzle", "two-phase xact not supported by database"
-                ),
                 no_support(
                     "sqlite", "two-phase xact not supported by database"
                 ),
                 no_support(
                     "sybase", "two-phase xact not supported by drivers/SQLA"
                 ),
-                no_support(
-                    "mysql",
-                    "recent MySQL communiity editions have too many issues "
-                    "(late 2016), disabling for now",
-                ),
+                # in Ia3cbbf56d4882fcc7980f90519412f1711fae74d
+                # we are evaluating which modern MySQL / MariaDB versions
+                # can handle two-phase testing without too many problems
+                # no_support(
+                #     "mysql",
+                #    "recent MySQL communiity editions have too many issues "
+                #    "(late 2016), disabling for now",
+                # ),
                 NotPredicate(
                     LambdaPredicate(
                         pg_prepared_transaction,
@@ -768,7 +768,9 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def two_phase_recovery(self):
         return self.two_phase_transactions + (
-            skip_if("mysql", "crashes on most mariadb and mysql versions")
+            skip_if(
+                "mysql", "still can't get recover to work w/ MariaDB / MySQL"
+            )
         )
 
     @property