]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Reworked internal exception raises that emit
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Apr 2013 15:00:12 +0000 (11:00 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Apr 2013 15:00:12 +0000 (11:00 -0400)
a rollback() before re-raising, so that the stack
trace is preserved from sys.exc_info() before entering
the rollback.  This so that the traceback is preserved
when using coroutine frameworks which may have switched
contexts before the rollback function returns.
[ticket:2703]

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/langhelpers.py

index 60d61cb8dadc49f6dbd3e6f7c42330c15b971919..a23d05fe799e350e47a33016b15851244caa4c85 100644 (file)
@@ -6,6 +6,17 @@
 .. changelog::
     :version: 0.8.1
 
+    .. change::
+      :tags: bug, sql
+      :tickets: 2703
+
+      Reworked internal exception raises that emit
+      a rollback() before re-raising, so that the stack
+      trace is preserved from sys.exc_info() before entering
+      the rollback.  This so that the traceback is preserved
+      when using coroutine frameworks which may have switched
+      contexts before the rollback function returns.
+
     .. change::
       :tags: bug, orm
       :tickets: 2697
index e40af621990ed978f7945d47a28e28c63d2ca259..b4c9b1e1c7122c91e8ce31a6c9b40d3b3d2f6313 100644 (file)
@@ -462,7 +462,6 @@ class Connection(Connectable):
             self.engine.dialect.do_begin(self.connection)
         except Exception, e:
             self._handle_dbapi_exception(e, None, None, None, None)
-            raise
 
     def _rollback_impl(self):
         if self._has_events:
@@ -476,7 +475,6 @@ class Connection(Connectable):
                 self.__transaction = None
             except Exception, e:
                 self._handle_dbapi_exception(e, None, None, None, None)
-                raise
         else:
             self.__transaction = None
 
@@ -491,7 +489,6 @@ class Connection(Connectable):
             self.__transaction = None
         except Exception, e:
             self._handle_dbapi_exception(e, None, None, None, None)
-            raise
 
     def _savepoint_impl(self, name=None):
         if self._has_events:
@@ -693,7 +690,6 @@ class Connection(Connectable):
                                 dialect, self, conn)
         except Exception, e:
             self._handle_dbapi_exception(e, None, None, None, None)
-            raise
 
         ret = ctx._exec_default(default, None)
         if self.should_close_with_result:
@@ -830,7 +826,6 @@ class Connection(Connectable):
             self._handle_dbapi_exception(e,
                         str(statement), parameters,
                         None, None)
-            raise
 
         if context.compiled:
             context.pre_exec()
@@ -877,7 +872,6 @@ class Connection(Connectable):
                                 parameters,
                                 cursor,
                                 context)
-            raise
 
         if self._has_events:
             self.dispatch.after_cursor_execute(self, cursor,
@@ -952,7 +946,6 @@ class Connection(Connectable):
                                 parameters,
                                 cursor,
                                 None)
-            raise
 
     def _safe_close_cursor(self, cursor):
         """Close the given cursor, catching exceptions
@@ -983,22 +976,21 @@ class Connection(Connectable):
                                     cursor,
                                     context):
 
+        exc_info = sys.exc_info()
+
         if not self._is_disconnect:
             self._is_disconnect = isinstance(e, self.dialect.dbapi.Error) and \
                 not self.closed and \
                 self.dialect.is_disconnect(e, self.__connection, cursor)
 
         if self._reentrant_error:
-            # Py3K
-            #raise exc.DBAPIError.instance(statement, parameters, e,
-            #                               self.dialect.dbapi.Error) from e
-            # Py2K
-            raise exc.DBAPIError.instance(statement,
+            util.raise_from_cause(
+                        exc.DBAPIError.instance(statement,
                                             parameters,
                                             e,
-                                            self.dialect.dbapi.Error), \
-                                            None, sys.exc_info()[2]
-            # end Py2K
+                                            self.dialect.dbapi.Error),
+                        exc_info
+                        )
         self._reentrant_error = True
         try:
             # non-DBAPI error - if we already got a context,
@@ -1021,26 +1013,18 @@ class Connection(Connectable):
                     self._safe_close_cursor(cursor)
                 self._autorollback()
 
-            if not should_wrap:
-                return
-
-            # Py3K
-            #raise exc.DBAPIError.instance(
-            #                        statement,
-            #                        parameters,
-            #                        e,
-            #                        self.dialect.dbapi.Error,
-            #                        connection_invalidated=self._is_disconnect) \
-            #                        from e
-            # Py2K
-            raise exc.DBAPIError.instance(
-                                    statement,
-                                    parameters,
-                                    e,
-                                    self.dialect.dbapi.Error,
-                                    connection_invalidated=self._is_disconnect), \
-                                    None, sys.exc_info()[2]
-            # end Py2K
+            if should_wrap:
+                util.raise_from_cause(
+                                    exc.DBAPIError.instance(
+                                        statement,
+                                        parameters,
+                                        e,
+                                        self.dialect.dbapi.Error,
+                                        connection_invalidated=self._is_disconnect),
+                                    exc_info
+                                )
+
+            util.reraise(*exc_info)
 
         finally:
             del self._reentrant_error
@@ -1115,8 +1099,8 @@ class Connection(Connectable):
             trans.commit()
             return ret
         except:
-            trans.rollback()
-            raise
+            with util.safe_reraise():
+                trans.rollback()
 
     def run_callable(self, callable_, *args, **kwargs):
         """Given a callable object or function, execute it, passing
@@ -1222,8 +1206,8 @@ class Transaction(object):
             try:
                 self.commit()
             except:
-                self.rollback()
-                raise
+                with util.safe_reraise():
+                    self.rollback()
         else:
             self.rollback()
 
@@ -1548,8 +1532,8 @@ class Engine(Connectable, log.Identified):
         try:
             trans = conn.begin()
         except:
-            conn.close()
-            raise
+            with util.safe_reraise():
+                conn.close()
         return Engine._trans_ctx(conn, trans, close_with_result)
 
     def transaction(self, callable_, *args, **kwargs):
index abb9f0fc3dc293a62ebb937174889154c0a03291..daa9fe0855a8eff1999e6e8075055f8cd3e68dcf 100644 (file)
@@ -737,7 +737,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             except Exception, e:
                 self.root_connection._handle_dbapi_exception(
                                 e, None, None, None, self)
-                raise
         else:
             inputsizes = {}
             for key in self.compiled.bind_names.values():
@@ -756,7 +755,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             except Exception, e:
                 self.root_connection._handle_dbapi_exception(
                                 e, None, None, None, self)
-                raise
 
     def _exec_default(self, default, type_):
         if default.is_sequence:
index 1c148e1f0e5a02be198372ba1620c0819bfff5b8..88930081e9ffd98992e4ee699e5e28411bd7fc2f 100644 (file)
@@ -443,7 +443,6 @@ class ResultProxy(object):
         except Exception, e:
             self.connection._handle_dbapi_exception(
                               e, None, None, self.cursor, self.context)
-            raise
 
     @property
     def lastrowid(self):
@@ -467,7 +466,6 @@ class ResultProxy(object):
             self.connection._handle_dbapi_exception(
                                  e, None, None,
                                  self._saved_cursor, self.context)
-            raise
 
     @property
     def returns_rows(self):
@@ -752,7 +750,6 @@ class ResultProxy(object):
             self.connection._handle_dbapi_exception(
                                     e, None, None,
                                     self.cursor, self.context)
-            raise
 
     def fetchmany(self, size=None):
         """Fetch many rows, just like DB-API
@@ -772,7 +769,6 @@ class ResultProxy(object):
             self.connection._handle_dbapi_exception(
                                     e, None, None,
                                     self.cursor, self.context)
-            raise
 
     def fetchone(self):
         """Fetch one row, just like DB-API ``cursor.fetchone()``.
@@ -792,7 +788,6 @@ class ResultProxy(object):
             self.connection._handle_dbapi_exception(
                                     e, None, None,
                                     self.cursor, self.context)
-            raise
 
     def first(self):
         """Fetch the first row and then close the result set unconditionally.
@@ -809,7 +804,6 @@ class ResultProxy(object):
             self.connection._handle_dbapi_exception(
                                     e, None, None,
                                     self.cursor, self.context)
-            raise
 
         try:
             if row is not None:
index 71e617e365787dd552e7c0b98d1ca95475c01a51..91e4f9736112bc8c9f6ca9d65afc7ad4a61061a4 100644 (file)
@@ -341,8 +341,8 @@ class SessionTransaction(object):
                 for t in set(self._connections.values()):
                     t[1].prepare()
             except:
-                self.rollback()
-                raise
+                with util.safe_reraise():
+                    self.rollback()
 
         self._state = PREPARED
 
@@ -441,8 +441,8 @@ class SessionTransaction(object):
             try:
                 self.commit()
             except:
-                self.rollback()
-                raise
+                with util.safe_reraise():
+                    self.rollback()
         else:
             self.rollback()
 
@@ -1928,8 +1928,8 @@ class Session(_SessionClassMethods):
             transaction.commit()
 
         except:
-            transaction.rollback(_capture_exception=True)
-            raise
+            with util.safe_reraise():
+                transaction.rollback(_capture_exception=True)
 
     def is_modified(self, instance, include_collections=True,
                             passive=True):
index 57bbdca85bf7aea3e6e541035f4ffe6e18fe252d..3fa06c79325b96c11b4f4c4481eb4e7b303856e6 100644 (file)
@@ -6,7 +6,8 @@
 
 from .compat import callable, cmp, reduce,  \
     threading, py3k, py3k_warning, jython, pypy, cpython, win32, set_types, \
-    pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet
+    pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \
+    raise_from_cause
 
 from ._collections import KeyedTuple, ImmutableContainer, immutabledict, \
     Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
@@ -26,7 +27,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
     duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
     classproperty, set_creation_order, warn_exception, warn, NoneType,\
     constructor_copy, methods_equivalent, chop_traceback, asint,\
-    generic_repr, counter, PluginLoader, hybridmethod
+    generic_repr, counter, PluginLoader, hybridmethod, safe_reraise
 
 from .deprecations import warn_deprecated, warn_pending_deprecation, \
     deprecated, pending_deprecation
index 2a0f06f8ecbbafaa9f3248255b506de2ac2b3b54..0f0a2f6742dc613f707ffb9298ab0afa16190d50 100644 (file)
@@ -140,3 +140,29 @@ else:
     def b(s):
         return s
 
+
+if py3k:
+    def reraise(tp, value, tb=None, cause=None):
+        if cause is not None:
+            value.__cause__ = cause
+        if value.__traceback__ is not tb:
+            raise value.with_traceback(tb)
+        raise value
+
+    def raise_from_cause(exception, exc_info):
+        exc_type, exc_value, exc_tb = exc_info
+        reraise(type(exception), exception, tb=exc_tb, cause=exc_value)
+else:
+    exec("""def reraise(tp, value, tb=None, cause=None):
+        raise tp, value, tb
+    """)
+
+    def raise_from_cause(exception, exc_info):
+        # not as nice as that of Py3K, but at least preserves
+        # the code line where the issue occurred
+        exc_type, exc_value, exc_tb = exc_info
+        reraise(type(exception), exception, tb=exc_tb)
+
+
+
+
index e3aed24d8a487b397bc801bc2c20af9a3a7ef429..bba8ad734e18b3b4ba07594a1e20fd465bd14043 100644 (file)
@@ -20,6 +20,7 @@ from .compat import set_types, threading, \
 from functools import update_wrapper
 from .. import exc
 import hashlib
+from . import compat
 
 def md5_hex(x):
     # Py3K
@@ -28,6 +29,34 @@ def md5_hex(x):
     m.update(x)
     return m.hexdigest()
 
+class safe_reraise(object):
+    """Reraise an exception after invoking some
+    handler code.
+
+    Stores the existing exception info before
+    invoking so that it is maintained across a potential
+    coroutine context switch.
+
+    e.g.::
+
+        try:
+            sess.commit()
+        except:
+            with safe_reraise():
+                sess.rollback()
+
+    """
+
+    def __enter__(self):
+        self._exc_info = sys.exc_info()
+
+    def __exit__(self, type_, value, traceback):
+        if type_ is None:
+            exc_type, exc_value, exc_tb = self._exc_info
+            compat.reraise(exc_type, exc_value, exc_tb)
+        else:
+            compat.reraise(type_, value, traceback)
+
 def decode_slice(slc):
     """decode a slice object as sent to __getitem__.