]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure all nested exception throws have a cause
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Feb 2020 19:40:45 +0000 (14:40 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Mar 2020 23:49:12 +0000 (18:49 -0500)
Applied an explicit "cause" to most if not all internally raised exceptions
that are raised from within an internal exception catch, to avoid
misleading stacktraces that suggest an error within the handling of an
exception.  While it would be preferable to suppress the internally caught
exception in the way that the ``__suppress_context__`` attribute would,
there does not as yet seem to be a way to do this without suppressing an
enclosing user constructed context, so for now it exposes the internally
caught exception as the cause so that full information about the context
of the error is maintained.

Fixes: #4849
Change-Id: I55a86b29023675d9e5e49bc7edc5a2dc0bcd4751
(cherry picked from commit 8be0dae77a7e0747f0d0fb4282db4aea7f41e03a)

48 files changed:
doc/build/changelog/unreleased_13/4849.rst [new file with mode: 0644]
lib/sqlalchemy/cextension/resultproxy.c
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/baked.py
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/ext/declarative/clsregistry.py
lib/sqlalchemy/ext/indexable.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/processors.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/langhelpers.py
test/aaa_profiling/test_memusage.py
test/aaa_profiling/test_resultset.py
test/base/test_utils.py
test/engine/test_execute.py
test/engine/test_pool.py
test/engine/test_reconnect.py
test/engine/test_reflection.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_13/4849.rst b/doc/build/changelog/unreleased_13/4849.rst
new file mode 100644 (file)
index 0000000..5a649dc
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, general, py3k
+    :tickets: 4849
+
+    Applied an explicit "cause" to most if not all internally raised exceptions
+    that are raised from within an internal exception catch, to avoid
+    misleading stacktraces that suggest an error within the handling of an
+    exception.  While it would be preferable to suppress the internally caught
+    exception in the way that the ``__suppress_context__`` attribute would,
+    there does not as yet seem to be a way to do this without suppressing an
+    enclosing user constructed context, so for now it exposes the internally
+    caught exception as the cause so that full information about the context
+    of the error is maintained.
index 88c57dddee9a2da73ccc52727c95d8b514a3d163..981ae6e2ea71383d57c4e4dbb912c5477ac3faef 100644 (file)
@@ -294,7 +294,7 @@ BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
         record = PyDict_GetItem((PyObject *)self->keymap, key);
         if (record == NULL) {
             record = PyObject_CallMethod(self->parent, "_key_fallback",
-                                         "O", key);
+                                         "OO", key, Py_None);
             if (record == NULL)
                 return NULL;
             key_fallback = 1;
index 40b14ad4ab84f088847ddafbde95cff7b93676bb..5c2f114bddd5d1d16d6177894021dae971824bc6 100644 (file)
@@ -2922,7 +2922,7 @@ class MySQLDialect(default.DefaultDialect):
             ).execute(st)
         except exc.DBAPIError as e:
             if self._extract_error_code(e.orig) == 1146:
-                raise exc.NoSuchTableError(full_name)
+                util.raise_(exc.NoSuchTableError(full_name), replace_context=e)
             else:
                 raise
         row = self._compat_first(rp, charset=charset)
@@ -2948,11 +2948,16 @@ class MySQLDialect(default.DefaultDialect):
             except exc.DBAPIError as e:
                 code = self._extract_error_code(e.orig)
                 if code == 1146:
-                    raise exc.NoSuchTableError(full_name)
+                    util.raise_(
+                        exc.NoSuchTableError(full_name), replace_context=e
+                    )
                 elif code == 1356:
-                    raise exc.UnreflectableTableError(
-                        "Table or view named %s could not be "
-                        "reflected: %s" % (full_name, e)
+                    util.raise_(
+                        exc.UnreflectableTableError(
+                            "Table or view named %s could not be "
+                            "reflected: %s" % (full_name, e)
+                        ),
+                        replace_context=e,
                     )
                 else:
                     raise
index 55123bfd40679c683e9e313ee38314c943fd00f3..9ac94d4168c86612f0b01592c74c42a8c529a7ea 100644 (file)
@@ -764,11 +764,14 @@ class PGDialect_psycopg2(PGDialect):
     def set_isolation_level(self, connection, level):
         try:
             level = self._isolation_lookup[level.replace("_", " ")]
-        except KeyError:
-            raise exc.ArgumentError(
-                "Invalid value '%s' for isolation_level. "
-                "Valid isolation levels for %s are %s"
-                % (level, self.name, ", ".join(self._isolation_lookup))
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Invalid value '%s' for isolation_level. "
+                    "Valid isolation levels for %s are %s"
+                    % (level, self.name, ", ".join(self._isolation_lookup))
+                ),
+                replace_context=err,
             )
 
         connection.set_isolation_level(level)
index 904829e9e1edc31d98abf0c5bdc09329739c9d3b..4c630b7d27bef5ba7e4e3fbf35b1f887caca33b5 100644 (file)
@@ -997,9 +997,12 @@ class SQLiteCompiler(compiler.SQLCompiler):
                 self.extract_map[extract.field],
                 self.process(extract.expr, **kw),
             )
-        except KeyError:
-            raise exc.CompileError(
-                "%s is not a valid extract argument." % extract.field
+        except KeyError as err:
+            util.raise_(
+                exc.CompileError(
+                    "%s is not a valid extract argument." % extract.field
+                ),
+                replace_context=err,
             )
 
     def limit_clause(self, select, **kw):
@@ -1531,11 +1534,14 @@ class SQLiteDialect(default.DefaultDialect):
     def set_isolation_level(self, connection, level):
         try:
             isolation_level = self._isolation_lookup[level.replace("_", " ")]
-        except KeyError:
-            raise exc.ArgumentError(
-                "Invalid value '%s' for isolation_level. "
-                "Valid isolation levels for %s are %s"
-                % (level, self.name, ", ".join(self._isolation_lookup))
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Invalid value '%s' for isolation_level. "
+                    "Valid isolation levels for %s are %s"
+                    % (level, self.name, ", ".join(self._isolation_lookup))
+                ),
+                replace_context=err,
             )
         cursor = connection.cursor()
         cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
index 7a47fd751a5784395c8aeef22a424f66efbe32eb..aa6c2ee2515f4675dd76dcfa1311bea9662706ef 100644 (file)
@@ -976,8 +976,10 @@ class Connection(Connectable):
             return self._execute_text(object_, multiparams, params)
         try:
             meth = object_._execute_on_connection
-        except AttributeError:
-            raise exc.ObjectNotExecutableError(object_)
+        except AttributeError as err:
+            util.raise_(
+                exc.ObjectNotExecutableError(object_), replace_context=err
+            )
         else:
             return meth(self, multiparams, params)
 
@@ -1369,7 +1371,7 @@ class Connection(Connectable):
         invalidate_pool_on_disconnect = not is_exit_exception
 
         if self._reentrant_error:
-            util.raise_from_cause(
+            util.raise_(
                 exc.DBAPIError.instance(
                     statement,
                     parameters,
@@ -1381,7 +1383,8 @@ class Connection(Connectable):
                     if context is not None
                     else None,
                 ),
-                exc_info,
+                with_traceback=exc_info[2],
+                from_=e,
             )
         self._reentrant_error = True
         try:
@@ -1471,11 +1474,13 @@ class Connection(Connectable):
                     self._autorollback()
 
             if newraise:
-                util.raise_from_cause(newraise, exc_info)
+                util.raise_(newraise, with_traceback=exc_info[2], from_=e)
             elif should_wrap:
-                util.raise_from_cause(sqlalchemy_exception, exc_info)
+                util.raise_(
+                    sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+                )
             else:
-                util.reraise(*exc_info)
+                util.raise_(exc_info[1], with_traceback=exc_info[2])
 
         finally:
             del self._reentrant_error
@@ -1542,11 +1547,13 @@ class Connection(Connectable):
                 ) = ctx.is_disconnect
 
         if newraise:
-            util.raise_from_cause(newraise, exc_info)
+            util.raise_(newraise, with_traceback=exc_info[2], from_=e)
         elif should_wrap:
-            util.raise_from_cause(sqlalchemy_exception, exc_info)
+            util.raise_(
+                sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+            )
         else:
-            util.reraise(*exc_info)
+            util.raise_(exc_info[1], with_traceback=exc_info[2])
 
     def transaction(self, callable_, *args, **kwargs):
         r"""Execute the given function within a transaction boundary.
@@ -2280,7 +2287,9 @@ class Engine(Connectable, log.Identified):
                     e, dialect, self
                 )
             else:
-                util.reraise(*sys.exc_info())
+                util.raise_(
+                    sys.exc_info()[1], with_traceback=sys.exc_info()[2]
+                )
 
     def raw_connection(self, _connection=None):
         """Return a "raw" DBAPI connection from the connection pool.
index e7f937c25d37d4ae100021e5490c8c3147ccdbe9..c142ae8513041e222b19e9d53464def2414a5ac5 100644 (file)
@@ -83,8 +83,8 @@ except ImportError:
         def __getitem__(self, key):
             try:
                 processor, obj, index = self._keymap[key]
-            except KeyError:
-                processor, obj, index = self._parent._key_fallback(key)
+            except KeyError as err:
+                processor, obj, index = self._parent._key_fallback(key, err)
             except TypeError:
                 if isinstance(key, slice):
                     l = []
@@ -112,7 +112,7 @@ except ImportError:
             try:
                 return self[name]
             except KeyError as e:
-                raise AttributeError(e.args[0])
+                util.raise_(AttributeError(e.args[0]), replace_context=e)
 
 
 class RowProxy(BaseRowProxy):
@@ -639,7 +639,7 @@ class ResultMetaData(object):
                 d[key] = rec
         return d
 
-    def _key_fallback(self, key, raiseerr=True):
+    def _key_fallback(self, key, err, raiseerr=True):
         map_ = self._keymap
         result = None
         if isinstance(key, util.string_types):
@@ -678,9 +678,12 @@ class ResultMetaData(object):
                     result = None
         if result is None:
             if raiseerr:
-                raise exc.NoSuchColumnError(
-                    "Could not locate column in row for column '%s'"
-                    % expression._string_or_unprintable(key)
+                util.raise_(
+                    exc.NoSuchColumnError(
+                        "Could not locate column in row for column '%s'"
+                        % expression._string_or_unprintable(key)
+                    ),
+                    replace_context=err,
                 )
             else:
                 return None
@@ -692,21 +695,24 @@ class ResultMetaData(object):
         if key in self._keymap:
             return True
         else:
-            return self._key_fallback(key, False) is not None
+            return self._key_fallback(key, None, False) is not None
 
     def _getter(self, key, raiseerr=True):
         if key in self._keymap:
             processor, obj, index = self._keymap[key]
         else:
-            ret = self._key_fallback(key, raiseerr)
+            ret = self._key_fallback(key, None, raiseerr)
             if ret is None:
                 return None
             processor, obj, index = ret
 
         if index is None:
-            raise exc.InvalidRequestError(
-                "Ambiguous column name '%s' in "
-                "result set column descriptions" % obj
+            util.raise_(
+                exc.InvalidRequestError(
+                    "Ambiguous column name '%s' in "
+                    "result set column descriptions" % obj
+                ),
+                from_=None,
             )
 
         return operator.itemgetter(index)
@@ -771,16 +777,16 @@ class ResultProxy(object):
     def _getter(self, key, raiseerr=True):
         try:
             getter = self._metadata._getter
-        except AttributeError:
-            return self._non_result(None)
+        except AttributeError as err:
+            return self._non_result(None, err)
         else:
             return getter(key, raiseerr)
 
     def _has_key(self, key):
         try:
             has_key = self._metadata._has_key
-        except AttributeError:
-            return self._non_result(None)
+        except AttributeError as err:
+            return self._non_result(None, err)
         else:
             return has_key(key)
 
@@ -1196,8 +1202,8 @@ class ResultProxy(object):
     def _fetchone_impl(self):
         try:
             return self.cursor.fetchone()
-        except AttributeError:
-            return self._non_result(None)
+        except AttributeError as err:
+            return self._non_result(None, err)
 
     def _fetchmany_impl(self, size=None):
         try:
@@ -1205,23 +1211,29 @@ class ResultProxy(object):
                 return self.cursor.fetchmany()
             else:
                 return self.cursor.fetchmany(size)
-        except AttributeError:
-            return self._non_result([])
+        except AttributeError as err:
+            return self._non_result([], err)
 
     def _fetchall_impl(self):
         try:
             return self.cursor.fetchall()
-        except AttributeError:
-            return self._non_result([])
+        except AttributeError as err:
+            return self._non_result([], err)
 
-    def _non_result(self, default):
+    def _non_result(self, default, err=None):
         if self._metadata is None:
-            raise exc.ResourceClosedError(
-                "This result object does not return rows. "
-                "It has been closed automatically."
+            util.raise_(
+                exc.ResourceClosedError(
+                    "This result object does not return rows. "
+                    "It has been closed automatically."
+                ),
+                replace_context=err,
             )
         elif self.closed:
-            raise exc.ResourceClosedError("This result object is closed.")
+            util.raise_(
+                exc.ResourceClosedError("This result object is closed."),
+                replace_context=err,
+            )
         else:
             return default
 
index da9ce12f1b5fef40d610b8703374073f1fb395b5..fb92e3f0ceec3dd432ce5ef23f113608cb717bfa 100644 (file)
@@ -244,6 +244,10 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         try:
             inst = class_.__dict__[self.key + "_inst"]
         except KeyError:
+            inst = None
+
+        # avoid exception context
+        if inst is None:
             owner = self._calc_owner(class_)
             if owner is not None:
                 inst = AssociationProxyInstance.for_proxy(self, owner, obj)
@@ -358,9 +362,12 @@ class AssociationProxyInstance(object):
 
         # this was never asserted before but this should be made clear.
         if not isinstance(prop, orm.RelationshipProperty):
-            raise NotImplementedError(
-                "association proxy to a non-relationship "
-                "intermediary is not supported"
+            util.raise_(
+                NotImplementedError(
+                    "association proxy to a non-relationship "
+                    "intermediary is not supported"
+                ),
+                replace_context=None,
             )
 
         target_class = prop.mapper.class_
@@ -1323,10 +1330,13 @@ class _AssociationDict(_AssociationCollection):
                 try:
                     for k, v in seq_or_map:
                         self[k] = v
-                except ValueError:
-                    raise ValueError(
-                        "dictionary update sequence "
-                        "requires 2-element tuples"
+                except ValueError as err:
+                    util.raise_(
+                        ValueError(
+                            "dictionary update sequence "
+                            "requires 2-element tuples"
+                        ),
+                        replace_context=err,
                     )
 
         for key, value in kw:
index 77cfa47c5acb9bf079e069ac0411262ffa8f41eb..fc07f83ebb1d9e983258bf0fe9a1686213923dca 100644 (file)
@@ -502,9 +502,12 @@ class Result(object):
         """
         try:
             ret = self.one_or_none()
-        except orm_exc.MultipleResultsFound:
-            raise orm_exc.MultipleResultsFound(
-                "Multiple rows were found for one()"
+        except orm_exc.MultipleResultsFound as err:
+            util.raise_(
+                orm_exc.MultipleResultsFound(
+                    "Multiple rows were found for one()"
+                ),
+                replace_context=err,
             )
         else:
             if ret is None:
index 32f352b184288056915043592f88d62b17a21dc6..70d005b88f99e2b30e6f25c3f0c756a44476029f 100644 (file)
@@ -398,6 +398,7 @@ Example usage::
 
 """
 from .. import exc
+from .. import util
 from ..sql import visitors
 
 
@@ -421,10 +422,13 @@ def compiles(class_, *specs):
                 def _wrap_existing_dispatch(element, compiler, **kw):
                     try:
                         return existing_dispatch(element, compiler, **kw)
-                    except exc.UnsupportedCompilationError:
-                        raise exc.CompileError(
-                            "%s construct has no default "
-                            "compilation handler." % type(element)
+                    except exc.UnsupportedCompilationError as uce:
+                        util.raise_(
+                            exc.CompileError(
+                                "%s construct has no default "
+                                "compilation handler." % type(element)
+                            ),
+                            from_=uce,
                         )
 
                 existing.specs["default"] = _wrap_existing_dispatch
@@ -469,10 +473,13 @@ class _dispatcher(object):
         if not fn:
             try:
                 fn = self.specs["default"]
-            except KeyError:
-                raise exc.CompileError(
-                    "%s construct has no default "
-                    "compilation handler." % type(element)
+            except KeyError as ke:
+                util.raise_(
+                    exc.CompileError(
+                        "%s construct has no default "
+                        "compilation handler." % type(element)
+                    ),
+                    replace_context=ke,
                 )
 
         return fn(element, compiler, **kw)
index 9e253ab39052a7ff6983cd882acfdca4d34e164c..57a6d1a05dabee9f531dacb0d5db85c029b741c4 100644 (file)
@@ -298,12 +298,15 @@ class _class_resolver(object):
             else:
                 return x
         except NameError as n:
-            raise exc.InvalidRequestError(
-                "When initializing mapper %s, expression %r failed to "
-                "locate a name (%r). If this is a class name, consider "
-                "adding this relationship() to the %r class after "
-                "both dependent classes have been defined."
-                % (self.prop.parent, self.arg, n.args[0], self.cls)
+            util.raise_(
+                exc.InvalidRequestError(
+                    "When initializing mapper %s, expression %r failed to "
+                    "locate a name (%r). If this is a class name, consider "
+                    "adding this relationship() to the %r class after "
+                    "both dependent classes have been defined."
+                    % (self.prop.parent, self.arg, n.args[0], self.cls)
+                ),
+                from_=n,
             )
 
 
index f2e0501bb34b87b310b62f73373b6b5f5113422b..6eb7e11850e5d96c89166326c3cf6bc1e35e28cb 100644 (file)
@@ -223,7 +223,8 @@ The above query will render::
 """  # noqa
 from __future__ import absolute_import
 
-from sqlalchemy import inspect
+from .. import inspect
+from .. import util
 from ..ext.hybrid import hybrid_property
 from ..orm.attributes import flag_modified
 
@@ -301,9 +302,9 @@ class index_property(hybrid_property):  # noqa
                 self.datatype = dict
         self.onebased = onebased
 
-    def _fget_default(self):
+    def _fget_default(self, err=None):
         if self.default == self._NO_DEFAULT_ARGUMENT:
-            raise AttributeError(self.attr_name)
+            util.raise_(AttributeError(self.attr_name), replace_context=err)
         else:
             return self.default
 
@@ -314,8 +315,8 @@ class index_property(hybrid_property):  # noqa
             return self._fget_default()
         try:
             value = column_value[self.index]
-        except (KeyError, IndexError):
-            return self._fget_default()
+        except (KeyError, IndexError) as err:
+            return self._fget_default(err)
         else:
             return value
 
@@ -337,8 +338,8 @@ class index_property(hybrid_property):  # noqa
             raise AttributeError(self.attr_name)
         try:
             del column_value[self.index]
-        except KeyError:
-            raise AttributeError(self.attr_name)
+        except KeyError as err:
+            util.raise_(AttributeError(self.attr_name), replace_context=err)
         else:
             setattr(instance, attr_name, column_value)
             flag_modified(instance, attr_name)
index 521775490cbf5522f0e75faefe6d21b3644ef6d2..f7416efb882c0298b625978256614e7496cc5b5d 100644 (file)
@@ -225,16 +225,19 @@ class QueryableAttribute(
     def __getattr__(self, key):
         try:
             return getattr(self.comparator, key)
-        except AttributeError:
-            raise AttributeError(
-                "Neither %r object nor %r object associated with %s "
-                "has an attribute %r"
-                % (
-                    type(self).__name__,
-                    type(self.comparator).__name__,
-                    self,
-                    key,
-                )
+        except AttributeError as err:
+            util.raise_(
+                AttributeError(
+                    "Neither %r object nor %r object associated with %s "
+                    "has an attribute %r"
+                    % (
+                        type(self).__name__,
+                        type(self.comparator).__name__,
+                        self,
+                        key,
+                    )
+                ),
+                replace_context=err,
             )
 
     def __str__(self):
@@ -367,31 +370,39 @@ def create_proxied_attribute(descriptor):
             comparator."""
             try:
                 return getattr(descriptor, attribute)
-            except AttributeError:
+            except AttributeError as err:
                 if attribute == "comparator":
-                    raise AttributeError("comparator")
+                    util.raise_(
+                        AttributeError("comparator"), replace_context=err
+                    )
                 try:
                     # comparator itself might be unreachable
                     comparator = self.comparator
-                except AttributeError:
-                    raise AttributeError(
-                        "Neither %r object nor unconfigured comparator "
-                        "object associated with %s has an attribute %r"
-                        % (type(descriptor).__name__, self, attribute)
+                except AttributeError as err2:
+                    util.raise_(
+                        AttributeError(
+                            "Neither %r object nor unconfigured comparator "
+                            "object associated with %s has an attribute %r"
+                            % (type(descriptor).__name__, self, attribute)
+                        ),
+                        replace_context=err2,
                     )
                 else:
                     try:
                         return getattr(comparator, attribute)
-                    except AttributeError:
-                        raise AttributeError(
-                            "Neither %r object nor %r object "
-                            "associated with %s has an attribute %r"
-                            % (
-                                type(descriptor).__name__,
-                                type(comparator).__name__,
-                                self,
-                                attribute,
-                            )
+                    except AttributeError as err3:
+                        util.raise_(
+                            AttributeError(
+                                "Neither %r object nor %r object "
+                                "associated with %s has an attribute %r"
+                                % (
+                                    type(descriptor).__name__,
+                                    type(comparator).__name__,
+                                    self,
+                                    attribute,
+                                )
+                            ),
+                            replace_context=err3,
                         )
 
     Proxy.__name__ = type(descriptor).__name__ + "Proxy"
@@ -716,12 +727,15 @@ class AttributeImpl(object):
                 elif value is ATTR_WAS_SET:
                     try:
                         return dict_[key]
-                    except KeyError:
+                    except KeyError as err:
                         # TODO: no test coverage here.
-                        raise KeyError(
-                            "Deferred loader for attribute "
-                            "%r failed to populate "
-                            "correctly" % key
+                        util.raise_(
+                            KeyError(
+                                "Deferred loader for attribute "
+                                "%r failed to populate "
+                                "correctly" % key
+                            ),
+                            replace_context=err,
                         )
                 elif value is not ATTR_EMPTY:
                     return self.set_committed_value(state, dict_, value)
index 6c4950310c6b5723659a427aab1f0d6583f67a7b..fab43e4cff5b1fcc1e85525db37c4b31ab857f20 100644 (file)
@@ -397,9 +397,12 @@ def _entity_descriptor(entity, key):
 
     try:
         return getattr(entity, key)
-    except AttributeError:
-        raise sa_exc.InvalidRequestError(
-            "Entity '%s' has no property '%s'" % (description, key)
+    except AttributeError as err:
+        util.raise_(
+            sa_exc.InvalidRequestError(
+                "Entity '%s' has no property '%s'" % (description, key)
+            ),
+            replace_context=err,
         )
 
 
index 848d47da0e3c5cc0f84745d17bbd80b59fdf2092..ff2eaae82f2a262ff00b5b72a8ad9733e8262454 100644 (file)
@@ -536,11 +536,14 @@ class StrategizedProperty(MapperProperty):
         try:
             return self._strategies[key]
         except KeyError:
-            cls = self._strategy_lookup(self, *key)
-            self._strategies[key] = self._strategies[cls] = strategy = cls(
-                self, key
-            )
-            return strategy
+            pass
+
+        # run outside to prevent transfer of exception context
+        cls = self._strategy_lookup(self, *key)
+        self._strategies[key] = self._strategies[cls] = strategy = cls(
+            self, key
+        )
+        return strategy
 
     def setup(self, context, query_entity, path, adapter, **kwargs):
         loader = self._get_context_loader(context, path)
index 4ad6c878e807f22effc340f541b9cc10bc37ed53..23c0c4e49b6fcda988edddac5c3951b4a5b10689 100644 (file)
@@ -96,9 +96,9 @@ def instances(query, cursor, context):
 
             if not query._yield_per:
                 break
-    except Exception as err:
-        cursor.close()
-        util.raise_from_cause(err)
+    except Exception:
+        with util.safe_reraise():
+            cursor.close()
 
 
 @util.dependencies("sqlalchemy.orm.query")
index 3cceb595659784490926d41f0f5c8a7b19269d86..40c6359fef77e2690e37c9e3a3fd96da37643992 100644 (file)
@@ -1525,11 +1525,14 @@ class Mapper(InspectionAttr):
                 # it to mapped ColumnProperty
                 try:
                     self.polymorphic_on = self._props[self.polymorphic_on]
-                except KeyError:
-                    raise sa_exc.ArgumentError(
-                        "Can't determine polymorphic_on "
-                        "value '%s' - no attribute is "
-                        "mapped to this name." % self.polymorphic_on
+                except KeyError as err:
+                    util.raise_(
+                        sa_exc.ArgumentError(
+                            "Can't determine polymorphic_on "
+                            "value '%s' - no attribute is "
+                            "mapped to this name." % self.polymorphic_on
+                        ),
+                        replace_context=err,
                     )
 
             if self.polymorphic_on in self._columntoproperty:
@@ -2041,9 +2044,12 @@ class Mapper(InspectionAttr):
 
         try:
             return self._props[key]
-        except KeyError:
-            raise sa_exc.InvalidRequestError(
-                "Mapper '%s' has no property '%s'" % (self, key)
+        except KeyError as err:
+            util.raise_(
+                sa_exc.InvalidRequestError(
+                    "Mapper '%s' has no property '%s'" % (self, key)
+                ),
+                replace_context=err,
             )
 
     def get_property_by_column(self, column):
index 28a612068380fc309f438296488d86e2262a05b6..40eb8e72c292082692097a6c2138589bad8b758d 100644 (file)
@@ -1630,9 +1630,12 @@ def _sort_states(mapper, states):
             persistent, key=mapper._persistent_sortkey_fn
         )
     except TypeError as err:
-        raise sa_exc.InvalidRequestError(
-            "Could not sort objects by primary key; primary key "
-            "values must be sortable in Python (was: %s)" % err
+        util.raise_(
+            sa_exc.InvalidRequestError(
+                "Could not sort objects by primary key; primary key "
+                "values must be sortable in Python (was: %s)" % err
+            ),
+            replace_context=err,
         )
     return (
         sorted(pending, key=operator.attrgetter("insert_order"))
@@ -1676,10 +1679,13 @@ class BulkUD(object):
     def _factory(cls, lookup, synchronize_session, *arg):
         try:
             klass = lookup[synchronize_session]
-        except KeyError:
-            raise sa_exc.ArgumentError(
-                "Valid strategies for session synchronization "
-                "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+        except KeyError as err:
+            util.raise_(
+                sa_exc.ArgumentError(
+                    "Valid strategies for session synchronization "
+                    "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+                ),
+                replace_context=err,
             )
         else:
             return klass(*arg)
@@ -1763,10 +1769,13 @@ class BulkEvaluate(BulkUD):
             self._additional_evaluators(evaluator_compiler)
 
         except evaluator.UnevaluatableError as err:
-            raise sa_exc.InvalidRequestError(
-                'Could not evaluate current criteria in Python: "%s". '
-                "Specify 'fetch' or False for the "
-                "synchronize_session parameter." % err
+            util.raise_(
+                sa_exc.InvalidRequestError(
+                    'Could not evaluate current criteria in Python: "%s". '
+                    "Specify 'fetch' or False for the "
+                    "synchronize_session parameter." % err
+                ),
+                from_=err,
             )
 
         # TODO: detect when the where clause is a trivial primary key match
index 10ab7259613fea6cefb770288d3e7308d0a15a29..4737b16de9e8d7c429ec91962479d63aaf3eabe5 100644 (file)
@@ -1084,15 +1084,18 @@ class Query(object):
                     for prop in mapper._identity_key_props
                 )
 
-            except KeyError:
-                raise sa_exc.InvalidRequestError(
-                    "Incorrect names of values in identifier to formulate "
-                    "primary key for query.get(); primary key attribute names"
-                    " are %s"
-                    % ",".join(
-                        "'%s'" % prop.key
-                        for prop in mapper._identity_key_props
-                    )
+            except KeyError as err:
+                util.raise_(
+                    sa_exc.InvalidRequestError(
+                        "Incorrect names of values in identifier to formulate "
+                        "primary key for query.get(); primary key attribute "
+                        "names are %s"
+                        % ",".join(
+                            "'%s'" % prop.key
+                            for prop in mapper._identity_key_props
+                        )
+                    ),
+                    replace_context=err,
                 )
 
         if (
@@ -3345,9 +3348,12 @@ class Query(object):
         """
         try:
             ret = self.one_or_none()
-        except orm_exc.MultipleResultsFound:
-            raise orm_exc.MultipleResultsFound(
-                "Multiple rows were found for one()"
+        except orm_exc.MultipleResultsFound as err:
+            util.raise_(
+                orm_exc.MultipleResultsFound(
+                    "Multiple rows were found for one()"
+                ),
+                replace_context=err,
             )
         else:
             if ret is None:
index 511e6c3ee77f0ba85929be85ae018a63c07100e9..c4a6ac26e1f161be16ddc6214382c1c827b0bfff 100644 (file)
@@ -2474,50 +2474,64 @@ class JoinCondition(object):
                         a_subset=self.parent_local_selectable,
                         consider_as_foreign_keys=consider_as_foreign_keys,
                     )
-        except sa_exc.NoForeignKeysError:
+        except sa_exc.NoForeignKeysError as nfe:
             if self.secondary is not None:
-                raise sa_exc.NoForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are no foreign keys "
-                    "linking these tables via secondary table '%s'.  "
-                    "Ensure that referencing columns are associated "
-                    "with a ForeignKey or ForeignKeyConstraint, or "
-                    "specify 'primaryjoin' and 'secondaryjoin' "
-                    "expressions." % (self.prop, self.secondary)
+                util.raise_(
+                    sa_exc.NoForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are no foreign keys "
+                        "linking these tables via secondary table '%s'.  "
+                        "Ensure that referencing columns are associated "
+                        "with a ForeignKey or ForeignKeyConstraint, or "
+                        "specify 'primaryjoin' and 'secondaryjoin' "
+                        "expressions." % (self.prop, self.secondary)
+                    ),
+                    from_=nfe,
                 )
             else:
-                raise sa_exc.NoForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are no foreign keys "
-                    "linking these tables.  "
-                    "Ensure that referencing columns are associated "
-                    "with a ForeignKey or ForeignKeyConstraint, or "
-                    "specify a 'primaryjoin' expression." % self.prop
+                util.raise_(
+                    sa_exc.NoForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are no foreign keys "
+                        "linking these tables.  "
+                        "Ensure that referencing columns are associated "
+                        "with a ForeignKey or ForeignKeyConstraint, or "
+                        "specify a 'primaryjoin' expression." % self.prop
+                    ),
+                    from_=nfe,
                 )
-        except sa_exc.AmbiguousForeignKeysError:
+        except sa_exc.AmbiguousForeignKeysError as afe:
             if self.secondary is not None:
-                raise sa_exc.AmbiguousForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are multiple foreign key "
-                    "paths linking the tables via secondary table '%s'.  "
-                    "Specify the 'foreign_keys' "
-                    "argument, providing a list of those columns which "
-                    "should be counted as containing a foreign key "
-                    "reference from the secondary table to each of the "
-                    "parent and child tables." % (self.prop, self.secondary)
+                util.raise_(
+                    sa_exc.AmbiguousForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are multiple foreign key "
+                        "paths linking the tables via secondary table '%s'.  "
+                        "Specify the 'foreign_keys' "
+                        "argument, providing a list of those columns which "
+                        "should be counted as containing a foreign key "
+                        "reference from the secondary table to each of the "
+                        "parent and child tables."
+                        % (self.prop, self.secondary)
+                    ),
+                    from_=afe,
                 )
             else:
-                raise sa_exc.AmbiguousForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are multiple foreign key "
-                    "paths linking the tables.  Specify the "
-                    "'foreign_keys' argument, providing a list of those "
-                    "columns which should be counted as containing a "
-                    "foreign key reference to the parent table." % self.prop
+                util.raise_(
+                    sa_exc.AmbiguousForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are multiple foreign key "
+                        "paths linking the tables.  Specify the "
+                        "'foreign_keys' argument, providing a list of those "
+                        "columns which should be counted as containing a "
+                        "foreign key reference to the parent table."
+                        % self.prop
+                    ),
+                    from_=afe,
                 )
 
     @property
index 52cd2ce836275598eb03b21c5993f5ce34ac683e..1c2a15710fcfd2e333b9e308f00e44f000ba4eb4 100644 (file)
@@ -570,7 +570,7 @@ class SessionTransaction(object):
             self._parent._rollback_exception = sys.exc_info()[1]
 
         if rollback_err:
-            util.reraise(*rollback_err)
+            util.raise_(rollback_err[1], with_traceback=rollback_err[2])
 
         sess.dispatch.after_soft_rollback(sess, self)
 
@@ -1354,10 +1354,13 @@ class Session(_SessionClassMethods):
     def _add_bind(self, key, bind):
         try:
             insp = inspect(key)
-        except sa_exc.NoInspectionAvailable:
+        except sa_exc.NoInspectionAvailable as err:
             if not isinstance(key, type):
-                raise sa_exc.ArgumentError(
-                    "Not an acceptable bind target: %s" % key
+                util.raise_(
+                    sa_exc.ArgumentError(
+                        "Not an acceptable bind target: %s" % key
+                    ),
+                    replace_context=err,
                 )
             else:
                 self.__binds[key] = bind
@@ -1507,9 +1510,11 @@ class Session(_SessionClassMethods):
         if mapper is not None:
             try:
                 mapper = inspect(mapper)
-            except sa_exc.NoInspectionAvailable:
+            except sa_exc.NoInspectionAvailable as err:
                 if isinstance(mapper, type):
-                    raise exc.UnmappedClassError(mapper)
+                    util.raise_(
+                        exc.UnmappedClassError(mapper), replace_context=err,
+                    )
                 else:
                     raise
 
@@ -1594,7 +1599,7 @@ class Session(_SessionClassMethods):
                     "consider using a session.no_autoflush block if this "
                     "flush is occurring prematurely"
                 )
-                util.raise_from_cause(e)
+                util.raise_(e, with_traceback=sys.exc_info[2])
 
     def refresh(
         self,
@@ -1649,8 +1654,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
 
         self._expire_state(state, attribute_names)
 
@@ -1755,8 +1762,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
         self._expire_state(state, attribute_names)
 
     def _expire_state(self, state, attribute_names):
@@ -1810,8 +1819,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
         if state.session_id is not self.hash_key:
             raise sa_exc.InvalidRequestError(
                 "Instance %s is not present in this Session" % state_str(state)
@@ -1962,8 +1973,10 @@ class Session(_SessionClassMethods):
 
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
 
         self._save_or_update_state(state)
 
@@ -1997,8 +2010,10 @@ class Session(_SessionClassMethods):
 
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
 
         self._delete_impl(state, instance, head=True)
 
@@ -2426,8 +2441,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
         return self._contains_state(state)
 
     def __iter__(self):
@@ -2522,8 +2539,11 @@ class Session(_SessionClassMethods):
             for o in objects:
                 try:
                     state = attributes.instance_state(o)
-                except exc.NO_STATE:
-                    raise exc.UnmappedInstanceError(o)
+
+                except exc.NO_STATE as err:
+                    util.raise_(
+                        exc.UnmappedInstanceError(o), replace_context=err,
+                    )
                 objset.add(state)
         else:
             objset = None
@@ -3385,8 +3405,10 @@ def object_session(instance):
 
     try:
         state = attributes.instance_state(instance)
-    except exc.NO_STATE:
-        raise exc.UnmappedInstanceError(instance)
+    except exc.NO_STATE as err:
+        util.raise_(
+            exc.UnmappedInstanceError(instance), replace_context=err,
+        )
     else:
         return _state_session(state)
 
index b298081864c5e1c19de920e46647322b3debc0a5..934053c4125e1324980bbd5c7cbe283ac1a568bc 100644 (file)
@@ -227,11 +227,14 @@ class Load(Generative, MapperOption):
                 # use getattr on the class to work around
                 # synonyms, hybrids, etc.
                 attr = getattr(ent.class_, attr)
-            except AttributeError:
+            except AttributeError as err:
                 if raiseerr:
-                    raise sa_exc.ArgumentError(
-                        'Can\'t find property named "%s" on '
-                        "%s in this Query." % (attr, ent)
+                    util.raise_(
+                        sa_exc.ArgumentError(
+                            'Can\'t find property named "%s" on '
+                            "%s in this Query." % (attr, ent)
+                        ),
+                        replace_context=err,
                     )
                 else:
                     return None
index 198e64f4f2a42989720a31441240c82adbe9bba9..ceaf54e5d332e1fe3436cd81ce9d0b0013afa2f9 100644 (file)
@@ -13,6 +13,7 @@ between instances based on join conditions.
 from . import attributes
 from . import exc
 from . import util as orm_util
+from .. import util
 
 
 def populate(
@@ -34,15 +35,15 @@ def populate(
             value = source.manager[prop.key].impl.get(
                 source, source_dict, attributes.PASSIVE_OFF
             )
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
 
         try:
             # inline of dest_mapper._set_state_attr_by_column
             prop = dest_mapper._columntoproperty[r]
             dest.manager[prop.key].impl.set(dest, dest_dict, value, None)
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err)
 
         # technically the "r.primary_key" check isn't
         # needed here, but we check for this condition to limit
@@ -64,8 +65,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
         try:
             prop = source_mapper._columntoproperty[l]
             value = source_dict[prop.key]
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, source_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err)
 
         try:
             prop = source_mapper._columntoproperty[r]
@@ -88,8 +89,8 @@ def clear(dest, dest_mapper, synchronize_pairs):
             )
         try:
             dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, None, l, dest_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, None, l, dest_mapper, r, err)
 
 
 def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
@@ -101,8 +102,8 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
             value = source_mapper._get_state_attr_by_column(
                 source, source.dict, l, passive=attributes.PASSIVE_OFF
             )
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, None, r, err)
         dest[r.key] = value
         dest[old_prefix + r.key] = oldvalue
 
@@ -113,8 +114,8 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
             value = source_mapper._get_state_attr_by_column(
                 source, source.dict, l, passive=attributes.PASSIVE_OFF
             )
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, None, r, err)
 
         dict_[r.key] = value
 
@@ -127,8 +128,8 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             prop = source_mapper._columntoproperty[l]
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, None, r, err)
         history = uowcommit.get_attribute_history(
             source, prop.key, attributes.PASSIVE_NO_INITIALIZE
         )
@@ -139,22 +140,28 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
 
 
 def _raise_col_to_prop(
-    isdest, source_mapper, source_column, dest_mapper, dest_column
+    isdest, source_mapper, source_column, dest_mapper, dest_column, err
 ):
     if isdest:
-        raise exc.UnmappedColumnError(
-            "Can't execute sync rule for "
-            "destination column '%s'; mapper '%s' does not map "
-            "this column.  Try using an explicit `foreign_keys` "
-            "collection which does not include this column (or use "
-            "a viewonly=True relation)." % (dest_column, dest_mapper)
+        util.raise_(
+            exc.UnmappedColumnError(
+                "Can't execute sync rule for "
+                "destination column '%s'; mapper '%s' does not map "
+                "this column.  Try using an explicit `foreign_keys` "
+                "collection which does not include this column (or use "
+                "a viewonly=True relation)." % (dest_column, dest_mapper)
+            ),
+            replace_context=err,
         )
     else:
-        raise exc.UnmappedColumnError(
-            "Can't execute sync rule for "
-            "source column '%s'; mapper '%s' does not map this "
-            "column.  Try using an explicit `foreign_keys` "
-            "collection which does not include destination column "
-            "'%s' (or use a viewonly=True relation)."
-            % (source_column, source_mapper, dest_column)
+        util.raise_(
+            exc.UnmappedColumnError(
+                "Can't execute sync rule for "
+                "source column '%s'; mapper '%s' does not map this "
+                "column.  Try using an explicit `foreign_keys` "
+                "collection which does not include destination column "
+                "'%s' (or use a viewonly=True relation)."
+                % (source_column, source_mapper, dest_column)
+            ),
+            replace_context=err,
         )
index ddaff9fc7780da9a67755acb88e54333e2f58565..bbc793d4b9fed3c1604b9bc5f0aa2223cc16c6ce 100644 (file)
@@ -653,8 +653,8 @@ class _ConnectionRecord(object):
             pool.logger.debug("Created new connection %r", connection)
             self.connection = connection
         except Exception as e:
-            pool.logger.debug("Error on connect(): %s", e)
-            raise
+            with util.safe_reraise():
+                pool.logger.debug("Error on connect(): %s", e)
         else:
             if first_connect_check:
                 pool.dispatch.first_connect.for_modify(
index 67f1564ec3c0e0957521c22cd872de39245b112d..8618d5e2aa60a1fb4ca4fa4f33a5443d4a1af6f2 100644 (file)
@@ -32,10 +32,13 @@ def str_to_datetime_processor_factory(regexp, type_):
         else:
             try:
                 m = rmatch(value)
-            except TypeError:
-                raise ValueError(
-                    "Couldn't parse %s string '%r' "
-                    "- value is not a string." % (type_.__name__, value)
+            except TypeError as err:
+                util.raise_(
+                    ValueError(
+                        "Couldn't parse %s string '%r' "
+                        "- value is not a string." % (type_.__name__, value)
+                    ),
+                    from_=err,
                 )
             if m is None:
                 raise ValueError(
index a7e2034dfaad9390da1ccf5ef8ae6e5d896f89ac..4259cecb4b5a08ef5213fe0eaade7b67c376dad7 100644 (file)
@@ -60,8 +60,8 @@ class _DialectArgView(util.collections_abc.MutableMapping):
     def _key(self, key):
         try:
             dialect, value_key = key.split("_", 1)
-        except ValueError:
-            raise KeyError(key)
+        except ValueError as err:
+            util.raise_(KeyError(key), replace_context=err)
         else:
             return dialect, value_key
 
@@ -70,17 +70,20 @@ class _DialectArgView(util.collections_abc.MutableMapping):
 
         try:
             opt = self.obj.dialect_options[dialect]
-        except exc.NoSuchModuleError:
-            raise KeyError(key)
+        except exc.NoSuchModuleError as err:
+            util.raise_(KeyError(key), replace_context=err)
         else:
             return opt[value_key]
 
     def __setitem__(self, key, value):
         try:
             dialect, value_key = self._key(key)
-        except KeyError:
-            raise exc.ArgumentError(
-                "Keys must be of the form <dialectname>_<argname>"
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Keys must be of the form <dialectname>_<argname>"
+                ),
+                replace_context=err,
             )
         else:
             self.obj.dialect_options[dialect][value_key] = value
index e851ad9e77da40f7bbd446c254c362236265dad6..562cd31ea8ca7d89903b1c01895a21c5f5ba6668 100644 (file)
@@ -792,12 +792,13 @@ class SQLCompiler(Compiled):
                 col = only_froms[element.element]
             else:
                 col = with_cols[element.element]
-        except KeyError:
+        except KeyError as ke:
             elements._no_text_coercion(
                 element.element,
                 exc.CompileError,
                 "Can't resolve label reference for ORDER BY / "
                 "GROUP BY / DISTINCT etc.",
+                err=ke,
             )
         else:
             kwargs["render_label_as_label"] = col
@@ -1313,8 +1314,11 @@ class SQLCompiler(Compiled):
         else:
             try:
                 opstring = OPERATORS[operator_]
-            except KeyError:
-                raise exc.UnsupportedCompilationError(self, operator_)
+            except KeyError as err:
+                util.raise_(
+                    exc.UnsupportedCompilationError(self, operator_),
+                    replace_context=err,
+                )
             else:
                 return self._generate_generic_binary(binary, opstring, **kw)
 
@@ -2896,11 +2900,12 @@ class DDLCompiler(Compiled):
                 if column.primary_key:
                     first_pk = True
             except exc.CompileError as ce:
-                util.raise_from_cause(
+                util.raise_(
                     exc.CompileError(
                         util.u("(in table '%s', column '%s'): %s")
                         % (table.description, column.name, ce.args[0])
-                    )
+                    ),
+                    from_=ce,
                 )
 
         const = self.create_table_constraints(
index 3fffd818469a96084e0bd6d33e707466aaeb7e6e..238f7c4d4e1af9e36cc9072ef2c4d554110926fd 100644 (file)
@@ -907,7 +907,7 @@ class SchemaDropper(DDLBase):
                 )
                 collection = [(t, ()) for t in unsorted_tables]
             else:
-                util.raise_from_cause(
+                util.raise_(
                     exc.CircularDependencyError(
                         err2.args[0],
                         err2.cycles,
@@ -924,7 +924,8 @@ class SchemaDropper(DDLBase):
                                 sorted([t.fullname for t in err2.cycles])
                             )
                         ),
-                    )
+                    ),
+                    from_=err2,
                 )
 
         seq_coll = [
index 52f7ce010d905943b9ea6231e4582824998429c5..9a42ef341a71122c6eea6cef19b5cd61f20a6f3a 100644 (file)
@@ -710,10 +710,13 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
     def comparator(self):
         try:
             comparator_factory = self.type.comparator_factory
-        except AttributeError:
-            raise TypeError(
-                "Object %r associated with '.type' attribute "
-                "is not a TypeEngine class or object" % self.type
+        except AttributeError as err:
+            util.raise_(
+                TypeError(
+                    "Object %r associated with '.type' attribute "
+                    "is not a TypeEngine class or object" % self.type
+                ),
+                replace_context=err,
             )
         else:
             return comparator_factory(self)
@@ -721,10 +724,17 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
     def __getattr__(self, key):
         try:
             return getattr(self.comparator, key)
-        except AttributeError:
-            raise AttributeError(
-                "Neither %r object nor %r object has an attribute %r"
-                % (type(self).__name__, type(self.comparator).__name__, key)
+        except AttributeError as err:
+            util.raise_(
+                AttributeError(
+                    "Neither %r object nor %r object has an attribute %r"
+                    % (
+                        type(self).__name__,
+                        type(self.comparator).__name__,
+                        key,
+                    )
+                ),
+                replace_context=err,
             )
 
     def operate(self, op, *other, **kwargs):
@@ -1614,10 +1624,13 @@ class TextClause(Executable, ClauseElement):
                 # a unique/anonymous key in any case, so use the _orig_key
                 # so that a text() construct can support unique parameters
                 existing = new_params[bind._orig_key]
-            except KeyError:
-                raise exc.ArgumentError(
-                    "This text() construct doesn't define a "
-                    "bound parameter named %r" % bind._orig_key
+            except KeyError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "This text() construct doesn't define a "
+                        "bound parameter named %r" % bind._orig_key
+                    ),
+                    replace_context=err,
                 )
             else:
                 new_params[existing._orig_key] = bind
@@ -1625,10 +1638,13 @@ class TextClause(Executable, ClauseElement):
         for key, value in names_to_values.items():
             try:
                 existing = new_params[key]
-            except KeyError:
-                raise exc.ArgumentError(
-                    "This text() construct doesn't define a "
-                    "bound parameter named %r" % key
+            except KeyError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "This text() construct doesn't define a "
+                        "bound parameter named %r" % key
+                    ),
+                    replace_context=err,
                 )
             else:
                 new_params[key] = existing._with_value(value)
@@ -3450,9 +3466,12 @@ class Over(ColumnElement):
         else:
             try:
                 lower = int(range_[0])
-            except ValueError:
-                raise exc.ArgumentError(
-                    "Integer or None expected for range value"
+            except ValueError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "Integer or None expected for range value"
+                    ),
+                    replace_context=err,
                 )
             else:
                 if lower == 0:
@@ -3463,9 +3482,12 @@ class Over(ColumnElement):
         else:
             try:
                 upper = int(range_[1])
-            except ValueError:
-                raise exc.ArgumentError(
-                    "Integer or None expected for range value"
+            except ValueError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "Integer or None expected for range value"
+                    ),
+                    replace_context=err,
                 )
             else:
                 if upper == 0:
@@ -4613,14 +4635,19 @@ def _no_column_coercion(element):
     )
 
 
-def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None):
-    raise exc_cls(
-        "%(extra)sTextual SQL expression %(expr)r should be "
-        "explicitly declared as text(%(expr)r)"
-        % {
-            "expr": util.ellipses_string(element),
-            "extra": "%s " % extra if extra else "",
-        }
+def _no_text_coercion(
+    element, exc_cls=exc.ArgumentError, extra=None, err=None
+):
+    util.raise_(
+        exc_cls(
+            "%(extra)sTextual SQL expression %(expr)r should be "
+            "explicitly declared as text(%(expr)r)"
+            % {
+                "expr": util.ellipses_string(element),
+                "extra": "%s " % extra if extra else "",
+            }
+        ),
+        replace_context=err,
     )
 
 
index b074af944ca96027d5d4fd617e231b3b87c01191..4d2cc3fee7435a2e59bf49ccf4e727bdb5d9fecc 100644 (file)
@@ -106,12 +106,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
             if item is not None:
                 try:
                     spwd = item._set_parent_with_dispatch
-                except AttributeError:
-                    util.raise_from_cause(
+                except AttributeError as err:
+                    util.raise_(
                         exc.ArgumentError(
                             "'SchemaItem' object, such as a 'Column' or a "
                             "'Constraint' expected, got %r" % item
-                        )
+                        ),
+                        replace_context=err,
                     )
                 else:
                     spwd(self)
@@ -1615,15 +1616,16 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
                 _proxies=[self],
                 *fk
             )
-        except TypeError:
-            util.raise_from_cause(
+        except TypeError as err:
+            util.raise_(
                 TypeError(
                     "Could not create a copy of this %r object.  "
                     "Ensure the class includes a _constructor() "
                     "attribute or method which accepts the "
                     "standard Column constructor arguments, or "
                     "references the Column class itself." % self.__class__
-                )
+                ),
+                from_=err,
             )
 
         c.table = selectable
@@ -3280,10 +3282,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
         try:
             ColumnCollectionConstraint._set_parent(self, table)
         except KeyError as ke:
-            raise exc.ArgumentError(
-                "Can't create ForeignKeyConstraint "
-                "on table '%s': no column "
-                "named '%s' is present." % (table.description, ke.args[0])
+            util.raise_(
+                exc.ArgumentError(
+                    "Can't create ForeignKeyConstraint "
+                    "on table '%s': no column "
+                    "named '%s' is present." % (table.description, ke.args[0])
+                ),
+                from_=ke,
             )
 
         for col, fk in zip(self.columns, self.elements):
index 72a1cfbd58d77c3911ffe388278e5a0bd1a29c61..0e82b830f0a58fee9f9f95757a71a2a795ca760c 100644 (file)
@@ -59,8 +59,10 @@ def _interpret_as_from(element):
             _no_text_coercion(element)
     try:
         return insp.selectable
-    except AttributeError:
-        raise exc.ArgumentError("FROM expression expected")
+    except AttributeError as err:
+        util.raise_(
+            exc.ArgumentError("FROM expression expected"), replace_context=err
+        )
 
 
 def _interpret_as_select(element):
@@ -108,10 +110,13 @@ def _offset_or_limit_clause_asint(clause, attrname):
         return None
     try:
         value = clause._limit_offset_value
-    except AttributeError:
-        raise exc.CompileError(
-            "This SELECT structure does not use a simple "
-            "integer value for %s" % attrname
+    except AttributeError as err:
+        util.raise_(
+            exc.CompileError(
+                "This SELECT structure does not use a simple "
+                "integer value for %s" % attrname
+            ),
+            replace_context=err,
         )
     else:
         return util.asint(value)
@@ -1330,9 +1335,13 @@ class Alias(FromClause):
     def as_scalar(self):
         try:
             return self.element.as_scalar()
-        except AttributeError:
-            raise AttributeError(
-                "Element %s does not support " "'as_scalar()'" % self.element
+        except AttributeError as err:
+            util.raise_(
+                AttributeError(
+                    "Element %s does not support "
+                    "'as_scalar()'" % self.element
+                ),
+                replace_context=err,
             )
 
     def is_derived_from(self, fromclause):
@@ -2983,10 +2992,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
 
         try:
             cols_present = bool(columns)
-        except TypeError:
-            raise exc.ArgumentError(
-                "columns argument to select() must "
-                "be a Python list or other iterable"
+        except TypeError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "columns argument to select() must "
+                    "be a Python list or other iterable"
+                ),
+                replace_context=err,
             )
 
         if cols_present:
index e134006bc84fe8009d7bba12706354b1391ba230..d25622e33c66e6ca345b4919a0bca5953185880d 100644 (file)
@@ -1477,7 +1477,7 @@ class Enum(Emulated, String, SchemaType):
     def _db_value_for_elem(self, elem):
         try:
             return self._valid_lookup[elem]
-        except KeyError:
+        except KeyError as err:
             # for unknown string values, we return as is.  While we can
             # validate these if we wanted, that does not allow for lesser-used
             # end-user use cases, such as using a LIKE comparison with an enum,
@@ -1491,8 +1491,11 @@ class Enum(Emulated, String, SchemaType):
             ):
                 return elem
             else:
-                raise LookupError(
-                    '"%s" is not among the defined enum values' % elem
+                util.raise_(
+                    LookupError(
+                        '"%s" is not among the defined enum values' % elem
+                    ),
+                    replace_context=err,
                 )
 
     class Comparator(String.Comparator):
@@ -1511,9 +1514,12 @@ class Enum(Emulated, String, SchemaType):
     def _object_value_for_elem(self, elem):
         try:
             return self._object_lookup[elem]
-        except KeyError:
-            raise LookupError(
-                '"%s" is not among the defined enum values' % elem
+        except KeyError as err:
+            util.raise_(
+                LookupError(
+                    '"%s" is not among the defined enum values' % elem
+                ),
+                replace_context=err,
             )
 
     def __repr__(self):
index d80cc27aee1b58a0302b355c5aaeb3af469010da..e23c38534196f807ed66d1898fb9f473cf6d6bab 100644 (file)
@@ -476,9 +476,12 @@ class TypeEngine(Visitable):
         try:
             return dialect._type_memos[self]["literal"]
         except KeyError:
-            d = self._dialect_info(dialect)
-            d["literal"] = lp = d["impl"].literal_processor(dialect)
-            return lp
+            pass
+        # avoid KeyError context coming into literal_processor() function
+        # raises
+        d = self._dialect_info(dialect)
+        d["literal"] = lp = d["impl"].literal_processor(dialect)
+        return lp
 
     def _cached_bind_processor(self, dialect):
         """Return a dialect-specific bind processor for this type."""
@@ -486,9 +489,12 @@ class TypeEngine(Visitable):
         try:
             return dialect._type_memos[self]["bind"]
         except KeyError:
-            d = self._dialect_info(dialect)
-            d["bind"] = bp = d["impl"].bind_processor(dialect)
-            return bp
+            pass
+        # avoid KeyError context coming into bind_processor() function
+        # raises
+        d = self._dialect_info(dialect)
+        d["bind"] = bp = d["impl"].bind_processor(dialect)
+        return bp
 
     def _cached_result_processor(self, dialect, coltype):
         """Return a dialect-specific result processor for this type."""
@@ -496,21 +502,27 @@ class TypeEngine(Visitable):
         try:
             return dialect._type_memos[self][coltype]
         except KeyError:
-            d = self._dialect_info(dialect)
-            # key assumption: DBAPI type codes are
-            # constants.  Else this dictionary would
-            # grow unbounded.
-            d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
-            return rp
+            pass
+        # avoid KeyError context coming into result_processor() function
+        # raises
+        d = self._dialect_info(dialect)
+        # key assumption: DBAPI type codes are
+        # constants.  Else this dictionary would
+        # grow unbounded.
+        d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
+        return rp
 
     def _cached_custom_processor(self, dialect, key, fn):
         try:
             return dialect._type_memos[self][key]
         except KeyError:
-            d = self._dialect_info(dialect)
-            impl = d["impl"]
-            d[key] = result = fn(impl)
-            return result
+            pass
+        # avoid KeyError context coming into fn() function
+        # raises
+        d = self._dialect_info(dialect)
+        impl = d["impl"]
+        d[key] = result = fn(impl)
+        return result
 
     def _dialect_info(self, dialect):
         """Return a dialect-specific registry which
index b77041122f133c696194c1bf4daf9bc4d05d4723..362afc5586eeaea24d1d013a3ed640a04b8f0a31 100644 (file)
@@ -86,8 +86,11 @@ def _generate_dispatch(cls):
             def _compiler_dispatch(self, visitor, **kw):
                 try:
                     meth = getter(visitor)
-                except AttributeError:
-                    raise exc.UnsupportedCompilationError(visitor, cls)
+                except AttributeError as err:
+                    util.raise_(
+                        exc.UnsupportedCompilationError(visitor, cls),
+                        replace_context=err,
+                    )
                 else:
                     return meth(self, **kw)
 
@@ -99,8 +102,11 @@ def _generate_dispatch(cls):
                 visit_attr = "visit_%s" % self.__visit_name__
                 try:
                     meth = getattr(visitor, visit_attr)
-                except AttributeError:
-                    raise exc.UnsupportedCompilationError(visitor, cls)
+                except AttributeError as err:
+                    util.raise_(
+                        exc.UnsupportedCompilationError(visitor, cls),
+                        replace_context=err,
+                    )
                 else:
                     return meth(self, **kw)
 
index ee46fe6ae589306e165f0c87ed4953b73ac4ff70..853a66e9ee72344989f924c374346a2fe349b914 100644 (file)
@@ -9,7 +9,9 @@
 from . import config  # noqa
 from . import mock  # noqa
 from .assertions import assert_raises  # noqa
+from .assertions import assert_raises_context_ok  # noqa
 from .assertions import assert_raises_message  # noqa
+from .assertions import assert_raises_message_context_ok  # noqa
 from .assertions import assert_raises_return  # noqa
 from .assertions import AssertsCompiledSQL  # noqa
 from .assertions import AssertsExecutionResults  # noqa
index 7d38d50844f8a22f0e3e696e2c8d63760b5508fa..f53ec1ba46aaff10c594bd9589c179a24679f07d 100644 (file)
@@ -9,6 +9,7 @@ from __future__ import absolute_import
 
 import contextlib
 import re
+import sys
 import warnings
 
 from . import assertsql
@@ -292,41 +293,80 @@ def eq_ignore_whitespace(a, b, msg=None):
     assert a == b, msg or "%r != %r" % (a, b)
 
 
+def _assert_proper_exception_context(exception):
+    """assert that any exception we're catching does not have a __context__
+    without a __cause__, and that __suppress_context__ is never set.
+
+    Python 3 will report nested as exceptions as "during the handling of
+    error X, error Y occurred". That's not what we want to do.  we want
+    these exceptions in a cause chain.
+
+    """
+
+    if not util.py3k:
+        return
+
+    if (
+        exception.__context__ is not exception.__cause__
+        and not exception.__suppress_context__
+    ):
+        assert False, (
+            "Exception %r was correctly raised but did not set a cause, "
+            "within context %r as its cause."
+            % (exception, exception.__context__)
+        )
+
+
 def assert_raises(except_cls, callable_, *args, **kw):
-    try:
-        callable_(*args, **kw)
-        success = False
-    except except_cls:
-        success = True
+    _assert_raises(except_cls, callable_, args, kw, check_context=True)
 
-    # assert outside the block so it works for AssertionError too !
-    assert success, "Callable did not raise an exception"
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+    _assert_raises(
+        except_cls, callable_, args, kw,
+    )
 
 
 def assert_raises_return(except_cls, callable_, *args, **kw):
+    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+    _assert_raises(
+        except_cls, callable_, args, kwargs, msg=msg, check_context=True
+    )
+
+
+def assert_raises_message_context_ok(
+    except_cls, msg, callable_, *args, **kwargs
+):
+    _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+    except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
     ret_err = None
+    if check_context:
+        are_we_already_in_a_traceback = sys.exc_info()[0]
     try:
-        callable_(*args, **kw)
+        callable_(*args, **kwargs)
         success = False
     except except_cls as err:
-        success = True
         ret_err = err
+        success = True
+        if msg is not None:
+            assert re.search(
+                msg, util.text_type(err), re.UNICODE
+            ), "%r !~ %s" % (msg, err,)
+        if check_context and not are_we_already_in_a_traceback:
+            _assert_proper_exception_context(err)
+        print(util.text_type(err).encode("utf-8"))
 
     # assert outside the block so it works for AssertionError too !
     assert success, "Callable did not raise an exception"
-    return ret_err
 
-
-def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
-    try:
-        callable_(*args, **kwargs)
-        assert False, "Callable did not raise an exception"
-    except except_cls as e:
-        assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
-            msg,
-            e,
-        )
-        print(util.text_type(e).encode("utf-8"))
+    return ret_err
 
 
 class AssertsCompiledSQL(object):
index e3e3c06b7331fb7e9d587a1e509b891656178dd9..ad09e9df8fa28f2334ed7c51188f3806314c84cb 100644 (file)
@@ -9,6 +9,7 @@
 import contextlib
 import operator
 import re
+import sys
 
 from . import config
 from .. import util
@@ -143,7 +144,7 @@ class compound(object):
                 )
                 break
         else:
-            util.raise_from_cause(ex)
+            util.raise_(ex, with_traceback=sys.exc_info()[2])
 
     def _expect_success(self, config, name="block"):
         if not self.fails:
index f76f90dfa9d267f7eed4ee6e86293ee707fce8f1..57c17717f2c4ecd3126919f4759e239b3e79891b 100644 (file)
@@ -71,6 +71,7 @@ from .compat import py36  # noqa
 from .compat import py3k  # noqa
 from .compat import pypy  # noqa
 from .compat import quote_plus  # noqa
+from .compat import raise_  # noqa
 from .compat import raise_from_cause  # noqa
 from .compat import reduce  # noqa
 from .compat import reraise  # noqa
index 4d234aafd7e18f0c9bf242fdbb64e5ca44cfb381..5609eddabbe8615bd872ac86ddb25431b75c69c8 100644 (file)
@@ -23,6 +23,7 @@ py2k = sys.version_info < (3, 0)
 py265 = sys.version_info >= (2, 6, 5)
 jython = sys.platform.startswith("java")
 pypy = hasattr(sys, "pypy_version_info")
+
 win32 = sys.platform.startswith("win")
 cpython = not pypy and not jython  # TODO: something better for this ?
 
@@ -144,13 +145,42 @@ if py3k:
     def cmp(a, b):
         return (a > b) - (a < b)
 
-    def reraise(tp, value, tb=None, cause=None):
-        if cause is not None:
-            assert cause is not value, "Same cause emitted"
-            value.__cause__ = cause
-        if value.__traceback__ is not tb:
-            raise value.with_traceback(tb)
-        raise value
+    def raise_(
+        exception, with_traceback=None, replace_context=None, from_=False
+    ):
+        r"""implement "raise" with cause support.
+
+        :param exception: exception to raise
+        :param with_traceback: will call exception.with_traceback()
+        :param replace_context: an as-yet-unsupported feature.  This is
+         an exception object which we are "replacing", e.g., it's our
+         "cause" but we don't want it printed.    Basically just what
+         ``__suppress_context__`` does but we don't want to suppress
+         the enclosing context, if any.  So for now we make it the
+         cause.
+        :param from\_: the cause.  this actually sets the cause and doesn't
+         hope to hide it someday.
+
+        """
+        if with_traceback is not None:
+            exception = exception.with_traceback(with_traceback)
+
+        if from_ is not False:
+            exception.__cause__ = from_
+        elif replace_context is not None:
+            # no good solution here, we would like to have the exception
+            # have only the context of replace_context.__context__ so that the
+            # intermediary exception does not change, but we can't figure
+            # that out.
+            exception.__cause__ = replace_context
+
+        try:
+            raise exception
+        finally:
+            # credit to
+            # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
+            # as the __traceback__ object creates a cycle
+            del exception, replace_context, from_, with_traceback
 
     def u(s):
         return s
@@ -256,13 +286,13 @@ else:
         else:
             return text
 
-    # not as nice as that of Py3K, but at least preserves
-    # the code line where the issue occurred
     exec(
-        "def reraise(tp, value, tb=None, cause=None):\n"
-        "    if cause is not None:\n"
-        "        assert cause is not value, 'Same cause emitted'\n"
-        "    raise tp, value, tb\n"
+        "def raise_(exception, with_traceback=None, replace_context=None, "
+        "from_=False):\n"
+        "    if with_traceback:\n"
+        "        raise type(exception), exception, with_traceback\n"
+        "    else:\n"
+        "        raise exception\n"
     )
 
 
@@ -402,6 +432,8 @@ def nested(*managers):
 
 
 def raise_from_cause(exception, exc_info=None):
+    r"""legacy.  use raise\_()"""
+
     if exc_info is None:
         exc_info = sys.exc_info()
     exc_type, exc_value, exc_tb = exc_info
@@ -409,6 +441,12 @@ def raise_from_cause(exception, exc_info=None):
     reraise(type(exception), exception, tb=exc_tb, cause=cause)
 
 
+def reraise(tp, value, tb=None, cause=None):
+    r"""legacy.  use raise\_()"""
+
+    raise_(value, with_traceback=tb, from_=cause)
+
+
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass.
 
index 360c91371c0e55737c27faebea47be42e964fb7c..f2341222c107316c703ed99e003f326d4c6bec38 100644 (file)
@@ -65,7 +65,9 @@ class safe_reraise(object):
             exc_type, exc_value, exc_tb = self._exc_info
             self._exc_info = None  # remove potential circular references
             if not self.warn_only:
-                compat.reraise(exc_type, exc_value, exc_tb)
+                compat.raise_(
+                    exc_value, with_traceback=exc_tb,
+                )
         else:
             if not compat.py3k and self._exc_info and self._exc_info[1]:
                 # emulate Py3K's behavior of telling us when an exception
@@ -76,7 +78,7 @@ class safe_reraise(object):
                     "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])
                 )
             self._exc_info = None  # remove potential circular references
-            compat.reraise(type_, value, traceback)
+            compat.raise_(value, with_traceback=traceback)
 
 
 def clsname_as_plain_name(cls):
index 0b407c11690a418e9b423987db6830ad45a32ba8..455ff81eb2e904e7ac62b4004b3c6499397eeedb 100644 (file)
@@ -1064,6 +1064,20 @@ class CycleTest(_fixtures.FixtureTest):
 
         go()
 
+    def test_raise_from(self):
+        @assert_cycles()
+        def go():
+            try:
+                try:
+                    raise KeyError("foo")
+                except KeyError as ke:
+
+                    util.raise_(Exception("oops"), from_=ke)
+            except Exception as err:  # noqa
+                pass
+
+        go()
+
     def test_query_alias(self):
         User, Address = self.classes("User", "Address")
         configure_mappers()
index 8085d5e23ac83781d38e2d38132ef8b4dcba77b7..39fbc67574b94a7f54883a8f2134d153ff6b9454 100644 (file)
@@ -110,7 +110,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults):
             "some other column", Integer
         )
 
-        @profiling.function_call_count()
+        @profiling.function_call_count(variance=0.10)
         def go():
             c1 in row
 
index 23436a6fd61510efaad15716cf7ba92f34974355..e0944b73ce38b4d9ad0440e097eaacc2dc42f71e 100644 (file)
@@ -3,7 +3,6 @@
 import copy
 import datetime
 import inspect
-import sys
 
 from sqlalchemy import exc
 from sqlalchemy import sql
@@ -2531,20 +2530,29 @@ class ReraiseTest(fixtures.TestBase):
         except MyException as err:
             is_(err.__cause__, None)
 
-    def test_reraise_disallow_same_cause(self):
+    def test_raise_from_cause_legacy(self):
         class MyException(Exception):
             pass
 
+        class MyOtherException(Exception):
+            pass
+
+        me = MyException("exc on")
+
         def go():
             try:
-                raise MyException("exc one")
-            except Exception as err:
-                type_, value, tb = sys.exc_info()
-                util.reraise(type_, err, tb, value)
+                raise me
+            except Exception:
+                util.raise_from_cause(MyOtherException("exc two"))
 
-        assert_raises_message(AssertionError, "Same cause emitted", go)
+        try:
+            go()
+            assert False
+        except MyOtherException as moe:
+            if testing.requires.python3.enabled:
+                is_(moe.__cause__, me)
 
-    def test_raise_from_cause(self):
+    def test_raise_from(self):
         class MyException(Exception):
             pass
 
@@ -2556,8 +2564,8 @@ class ReraiseTest(fixtures.TestBase):
         def go():
             try:
                 raise me
-            except Exception:
-                util.raise_from_cause(MyOtherException("exc two"))
+            except Exception as err:
+                util.raise_(MyOtherException("exc two"), from_=err)
 
         try:
             go()
index d78a669a5c0a7fce94eac4c7bf20ff5d2e7db282..93affccbd0670917c7509afbbe1d1c9400f169f0 100644 (file)
@@ -709,12 +709,13 @@ class ExecuteTest(fixtures.TestBase):
                     return super(MockCursor, self).execute(stmt, params, **kw)
 
         eng = engines.proxying_engine(cursor_cls=MockCursor)
-        assert_raises_message(
-            tsa.exc.SAWarning,
-            "Exception attempting to detect unicode returns",
-            eng.connect,
-        )
-        assert eng.dialect.returns_unicode_strings in (True, False)
+        with testing.expect_warnings(
+            "Exception attempting to detect unicode returns"
+        ):
+            eng.connect()
+
+        # because plain varchar passed, we don't know the correct answer
+        eq_(eng.dialect.returns_unicode_strings, "conditional")
         eng.dispose()
 
     def test_works_after_dispose(self):
index a52e7220fdd7a5278e0de3a0434e0f9009a39380..8b884d1101571e9a815fd3906eb39d84e1815cfe 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import pool
 from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy.testing import assert_raises
+from sqlalchemy.testing import assert_raises_context_ok
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
@@ -1258,7 +1259,7 @@ class QueuePoolTest(PoolTestBase):
             eq_(p.checkedout(), 0)
             eq_(p._overflow, 0)
             dbapi.shutdown(True)
-            assert_raises(Exception, p.connect)
+            assert_raises_context_ok(Exception, p.connect)
             eq_(p._overflow, 0)
             eq_(p.checkedout(), 0)  # and not 1
 
index 45d8827148c638f793b2521b2c5a7e03dadd997d..7c9d479c87fd1f20ca4b985942be240af1010a00 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy import util
 from sqlalchemy.engine import url
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_message_context_ok
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
@@ -193,7 +194,7 @@ class PrePingMockTest(fixtures.TestBase):
 
         self.dbapi.shutdown("execute", stop=True)
 
-        assert_raises_message(
+        assert_raises_message_context_ok(
             MockDisconnect, "database is stopped", pool.connect
         )
 
@@ -757,7 +758,7 @@ class CursorErrTest(fixtures.TestBase):
 
     def test_cursor_shutdown_in_initialize(self):
         db = self._fixture(True, True)
-        assert_raises_message(
+        assert_raises_message_context_ok(
             exc.SAWarning, "Exception attempting to detect", db.connect
         )
         eq_(
index 4959049078c9a1af8137e0b0d8a893a442333c4e..2f943b413376a038a475a579ed156553a1ee0a2c 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy.testing import eq_regex
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import not_in_
@@ -593,13 +594,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
         testing.db.dialect.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
-            assert_raises(sa.exc.SAWarning, Table, "test", m2, autoload=True)
 
-            @testing.emits_warning("Did not recognize type")
-            def warns():
-                m3 = MetaData(testing.db)
-                t3 = Table("test", m3, autoload=True)
-                assert t3.c.foo.type.__class__ == sa.types.NullType
+            with testing.expect_warnings("Did not recognize type"):
+                t3 = Table("test", m2, autoload_with=testing.db)
+                is_(t3.c.foo.type.__class__, sa.types.NullType)
 
         finally:
             testing.db.dialect.ischema_names = ischema_names
index 11ed9d8de54ee1c91791fd6fd32fc3daef6fdc23..9b199c734db260c69e327be9b864243127346390 100644 (file)
@@ -4015,16 +4015,11 @@ class DialectKWArgTest(fixtures.TestBase):
 
     def test_unknown_dialect_warning(self):
         with self._fixture():
-            assert_raises_message(
-                exc.SAWarning,
+            with testing.expect_warnings(
                 "Can't validate argument 'unknown_y'; can't locate "
                 "any SQLAlchemy dialect named 'unknown'",
-                Index,
-                "a",
-                "b",
-                "c",
-                unknown_y=True,
-            )
+            ):
+                Index("a", "b", "c", unknown_y=True)
 
     def test_participating_bad_kw(self):
         with self._fixture():