]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure all nested exception throws have a cause
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Feb 2020 19:40:45 +0000 (14:40 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Mar 2020 22:24:19 +0000 (17:24 -0500)
Applied an explicit "cause" to most if not all internally raised exceptions
that are raised from within an internal exception catch, to avoid
misleading stacktraces that suggest an error within the handling of an
exception.  While it would be preferable to suppress the internally caught
exception in the way that the ``__suppress_context__`` attribute would,
there does not as yet seem to be a way to do this without suppressing an
enclosing user constructed context, so for now it exposes the internally
caught exception as the cause so that full information about the context
of the error is maintained.

Fixes: #4849
Change-Id: I55a86b29023675d9e5e49bc7edc5a2dc0bcd4751

50 files changed:
doc/build/changelog/unreleased_13/4849.rst [new file with mode: 0644]
lib/sqlalchemy/cextension/resultproxy.c
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/engine/row.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/baked.py
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/ext/declarative/clsregistry.py
lib/sqlalchemy/ext/indexable.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/processors.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/langhelpers.py
test/aaa_profiling/test_memusage.py
test/aaa_profiling/test_resultset.py
test/base/test_utils.py
test/engine/test_execute.py
test/engine/test_pool.py
test/engine/test_reconnect.py
test/engine/test_reflection.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_13/4849.rst b/doc/build/changelog/unreleased_13/4849.rst
new file mode 100644 (file)
index 0000000..5a649dc
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, general, py3k
+    :tickets: 4849
+
+    Applied an explicit "cause" to most if not all internally raised exceptions
+    that are raised from within an internal exception catch, to avoid
+    misleading stacktraces that suggest an error within the handling of an
+    exception.  While it would be preferable to suppress the internally caught
+    exception in the way that the ``__suppress_context__`` attribute would,
+    there does not as yet seem to be a way to do this without suppressing an
+    enclosing user constructed context, so for now it exposes the internally
+    caught exception as the cause so that full information about the context
+    of the error is maintained.
index 3c44010b8910f9960aa64af63a030e7203afc9f5..f622e6a28848d8de748633ba529d9e614b5db538 100644 (file)
@@ -288,7 +288,7 @@ BaseRow_getitem_by_object(BaseRow *self, PyObject *key, int asmapping)
 
     if (record == NULL) {
         record = PyObject_CallMethod(self->parent, "_key_fallback",
-                                     "O", key);
+                                     "OO", key, Py_None);
         if (record == NULL)
             return NULL;
         key_fallback = 1;  // boolean to indicate record is a new reference
index e0bf16793118d6e7efe63428ec12d16fd1195dbd..6ea8cbcb81967ffb5e5ce99e4e51e85be2c9708b 100644 (file)
@@ -2968,7 +2968,7 @@ class MySQLDialect(default.DefaultDialect):
             ).execute(st)
         except exc.DBAPIError as e:
             if self._extract_error_code(e.orig) == 1146:
-                raise exc.NoSuchTableError(full_name)
+                util.raise_(exc.NoSuchTableError(full_name), replace_context=e)
             else:
                 raise
         row = self._compat_first(rp, charset=charset)
@@ -2992,11 +2992,16 @@ class MySQLDialect(default.DefaultDialect):
             except exc.DBAPIError as e:
                 code = self._extract_error_code(e.orig)
                 if code == 1146:
-                    raise exc.NoSuchTableError(full_name)
+                    util.raise_(
+                        exc.NoSuchTableError(full_name), replace_context=e
+                    )
                 elif code == 1356:
-                    raise exc.UnreflectableTableError(
-                        "Table or view named %s could not be "
-                        "reflected: %s" % (full_name, e)
+                    util.raise_(
+                        exc.UnreflectableTableError(
+                            "Table or view named %s could not be "
+                            "reflected: %s" % (full_name, e)
+                        ),
+                        replace_context=e,
                     )
                 else:
                     raise
index 0b6afc337de93b52dcca681e43084c93cb1f754a..1b1c9b0ba8e65b7211094b5c468de3393b48138b 100644 (file)
@@ -763,11 +763,14 @@ class PGDialect_psycopg2(PGDialect):
     def set_isolation_level(self, connection, level):
         try:
             level = self._isolation_lookup[level.replace("_", " ")]
-        except KeyError:
-            raise exc.ArgumentError(
-                "Invalid value '%s' for isolation_level. "
-                "Valid isolation levels for %s are %s"
-                % (level, self.name, ", ".join(self._isolation_lookup))
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Invalid value '%s' for isolation_level. "
+                    "Valid isolation levels for %s are %s"
+                    % (level, self.name, ", ".join(self._isolation_lookup))
+                ),
+                replace_context=err,
             )
 
         connection.set_isolation_level(level)
index d04b543cd13c9e3c0cdcbb3c855c51ba5a840b03..b1a83bf9212a791e5b82467bb9f386abf4a24c22 100644 (file)
@@ -997,9 +997,12 @@ class SQLiteCompiler(compiler.SQLCompiler):
                 self.extract_map[extract.field],
                 self.process(extract.expr, **kw),
             )
-        except KeyError:
-            raise exc.CompileError(
-                "%s is not a valid extract argument." % extract.field
+        except KeyError as err:
+            util.raise_(
+                exc.CompileError(
+                    "%s is not a valid extract argument." % extract.field
+                ),
+                replace_context=err,
             )
 
     def limit_clause(self, select, **kw):
@@ -1537,11 +1540,14 @@ class SQLiteDialect(default.DefaultDialect):
     def set_isolation_level(self, connection, level):
         try:
             isolation_level = self._isolation_lookup[level.replace("_", " ")]
-        except KeyError:
-            raise exc.ArgumentError(
-                "Invalid value '%s' for isolation_level. "
-                "Valid isolation levels for %s are %s"
-                % (level, self.name, ", ".join(self._isolation_lookup))
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Invalid value '%s' for isolation_level. "
+                    "Valid isolation levels for %s are %s"
+                    % (level, self.name, ", ".join(self._isolation_lookup))
+                ),
+                replace_context=err,
             )
         cursor = connection.cursor()
         cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
index ce6c2e9c67bffa18da6eb06991ed288342a8dea0..449f386cea54039a7f93d92ea065095c81dea1c0 100644 (file)
@@ -996,8 +996,10 @@ class Connection(Connectable):
             return self._execute_text(object_, multiparams, params)
         try:
             meth = object_._execute_on_connection
-        except AttributeError:
-            raise exc.ObjectNotExecutableError(object_)
+        except AttributeError as err:
+            util.raise_(
+                exc.ObjectNotExecutableError(object_), replace_context=err
+            )
         else:
             return meth(self, multiparams, params)
 
@@ -1400,7 +1402,7 @@ class Connection(Connectable):
         invalidate_pool_on_disconnect = not is_exit_exception
 
         if self._reentrant_error:
-            util.raise_from_cause(
+            util.raise_(
                 exc.DBAPIError.instance(
                     statement,
                     parameters,
@@ -1412,7 +1414,8 @@ class Connection(Connectable):
                     if context is not None
                     else None,
                 ),
-                exc_info,
+                with_traceback=exc_info[2],
+                from_=e,
             )
         self._reentrant_error = True
         try:
@@ -1502,11 +1505,13 @@ class Connection(Connectable):
                     self._autorollback()
 
             if newraise:
-                util.raise_from_cause(newraise, exc_info)
+                util.raise_(newraise, with_traceback=exc_info[2], from_=e)
             elif should_wrap:
-                util.raise_from_cause(sqlalchemy_exception, exc_info)
+                util.raise_(
+                    sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+                )
             else:
-                util.reraise(*exc_info)
+                util.raise_(exc_info[1], with_traceback=exc_info[2])
 
         finally:
             del self._reentrant_error
@@ -1573,11 +1578,13 @@ class Connection(Connectable):
                 ) = ctx.is_disconnect
 
         if newraise:
-            util.raise_from_cause(newraise, exc_info)
+            util.raise_(newraise, with_traceback=exc_info[2], from_=e)
         elif should_wrap:
-            util.raise_from_cause(sqlalchemy_exception, exc_info)
+            util.raise_(
+                sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+            )
         else:
-            util.reraise(*exc_info)
+            util.raise_(exc_info[1], with_traceback=exc_info[2])
 
     def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
         """run a DDL visitor.
@@ -2329,7 +2336,9 @@ class Engine(Connectable, log.Identified):
                     e, dialect, self
                 )
             else:
-                util.reraise(*sys.exc_info())
+                util.raise_(
+                    sys.exc_info()[1], with_traceback=sys.exc_info()[2]
+                )
 
     def raw_connection(self, _connection=None):
         """Return a "raw" DBAPI connection from the connection pool.
index 1a63c307bce6752a83e2b7d1218aae23d12fcecd..7db9eecaea5d30da4084eaaed349687b6f8a3ac1 100644 (file)
@@ -53,11 +53,11 @@ class ResultMetaData(object):
     def _has_key(self, key):
         return key in self._keymap
 
-    def _key_fallback(self, key):
+    def _key_fallback(self, key, err):
         if isinstance(key, int):
-            raise IndexError(key)
+            util.raise_(IndexError(key), replace_context=err)
         else:
-            raise KeyError(key)
+            util.raise_(KeyError(key), replace_context=err)
 
 
 class SimpleResultMetaData(ResultMetaData):
@@ -546,11 +546,14 @@ class CursorResultMetaData(ResultMetaData):
         ) in self._colnames_from_description(context, cursor_description):
             yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated
 
-    def _key_fallback(self, key, raiseerr=True):
+    def _key_fallback(self, key, err, raiseerr=True):
         if raiseerr:
-            raise exc.NoSuchColumnError(
-                "Could not locate column in row for column '%s'"
-                % util.string_or_unprintable(key)
+            util.raise_(
+                exc.NoSuchColumnError(
+                    "Could not locate column in row for column '%s'"
+                    % util.string_or_unprintable(key)
+                ),
+                replace_context=err,
             )
         else:
             return None
@@ -570,8 +573,8 @@ class CursorResultMetaData(ResultMetaData):
     def _getter(self, key, raiseerr=True):
         try:
             rec = self._keymap[key]
-        except KeyError:
-            rec = self._key_fallback(key, raiseerr)
+        except KeyError as ke:
+            rec = self._key_fallback(key, ke, raiseerr)
             if rec is None:
                 return None
 
@@ -598,8 +601,8 @@ class CursorResultMetaData(ResultMetaData):
         for key in keys:
             try:
                 rec = self._keymap[key]
-            except KeyError:
-                rec = self._key_fallback(key, raiseerr)
+            except KeyError as ke:
+                rec = self._key_fallback(key, ke, raiseerr)
                 if rec is None:
                     return None
 
@@ -656,9 +659,9 @@ class LegacyCursorResultMetaData(CursorResultMetaData):
             )
             return True
         else:
-            return self._key_fallback(key, False) is not None
+            return self._key_fallback(key, None, False) is not None
 
-    def _key_fallback(self, key, raiseerr=True):
+    def _key_fallback(self, key, err, raiseerr=True):
         map_ = self._keymap
         result = None
 
@@ -714,9 +717,12 @@ class LegacyCursorResultMetaData(CursorResultMetaData):
                     )
         if result is None:
             if raiseerr:
-                raise exc.NoSuchColumnError(
-                    "Could not locate column in row for column '%s'"
-                    % util.string_or_unprintable(key)
+                util.raise_(
+                    exc.NoSuchColumnError(
+                        "Could not locate column in row for column '%s'"
+                        % util.string_or_unprintable(key)
+                    ),
+                    replace_context=err,
                 )
             else:
                 return None
@@ -736,7 +742,7 @@ class LegacyCursorResultMetaData(CursorResultMetaData):
         if key in self._keymap:
             return True
         else:
-            return self._key_fallback(key, False) is not None
+            return self._key_fallback(key, None, False) is not None
 
 
 class CursorFetchStrategy(object):
@@ -807,9 +813,12 @@ class NoCursorDQLFetchStrategy(CursorFetchStrategy):
     def fetchall(self):
         return self._non_result([])
 
-    def _non_result(self, default):
+    def _non_result(self, default, err=None):
         if self.closed:
-            raise exc.ResourceClosedError("This result object is closed.")
+            util.raise_(
+                exc.ResourceClosedError("This result object is closed."),
+                replace_context=err,
+            )
         else:
             return default
 
@@ -843,10 +852,13 @@ class NoCursorDMLFetchStrategy(CursorFetchStrategy):
     def fetchall(self):
         return self._non_result([])
 
-    def _non_result(self, default):
-        raise exc.ResourceClosedError(
-            "This result object does not return rows. "
-            "It has been closed automatically."
+    def _non_result(self, default, err=None):
+        util.raise_(
+            exc.ResourceClosedError(
+                "This result object does not return rows. "
+                "It has been closed automatically."
+            ),
+            replace_context=err,
         )
 
 
@@ -1123,24 +1135,24 @@ class BaseResult(object):
     def _getter(self, key, raiseerr=True):
         try:
             getter = self._metadata._getter
-        except AttributeError:
-            return self.cursor_strategy._non_result(None)
+        except AttributeError as err:
+            return self.cursor_strategy._non_result(None, err)
         else:
             return getter(key, raiseerr)
 
     def _tuple_getter(self, key, raiseerr=True):
         try:
             getter = self._metadata._tuple_getter
-        except AttributeError:
-            return self.cursor_strategy._non_result(None)
+        except AttributeError as err:
+            return self.cursor_strategy._non_result(None, err)
         else:
             return getter(key, raiseerr)
 
     def _has_key(self, key):
         try:
             has_key = self._metadata._has_key
-        except AttributeError:
-            return self.cursor_strategy._non_result(None)
+        except AttributeError as err:
+            return self.cursor_strategy._non_result(None, err)
         else:
             return has_key(key)
 
index 55d8c2249dce29b02a6e5719e7965381caf4290b..b58b350e25a7f4ff2f4bbbd4ba6b5c507c0afb31 100644 (file)
@@ -84,8 +84,8 @@ except ImportError:
         def _subscript_impl(self, key, ismapping):
             try:
                 rec = self._keymap[key]
-            except KeyError:
-                rec = self._parent._key_fallback(key)
+            except KeyError as ke:
+                rec = self._parent._key_fallback(key, ke)
             except TypeError:
                 # the non-C version detects a slice using TypeError.
                 # this is pretty inefficient for the slice use case
@@ -119,7 +119,7 @@ except ImportError:
             try:
                 return self._get_by_key_impl_mapping(name)
             except KeyError as e:
-                raise AttributeError(e.args[0])
+                util.raise_(AttributeError(e.args[0]), replace_context=e)
 
 
 class Row(BaseRow, collections_abc.Sequence):
index 41346fc4e0c1a8acb8af28f5c0d22dc9acd076f8..f00b642dbdc9a6084528e3f67b4662e04fc3210b 100644 (file)
@@ -244,6 +244,10 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         try:
             inst = class_.__dict__[self.key + "_inst"]
         except KeyError:
+            inst = None
+
+        # avoid exception context
+        if inst is None:
             owner = self._calc_owner(class_)
             if owner is not None:
                 inst = AssociationProxyInstance.for_proxy(self, owner, obj)
@@ -358,9 +362,12 @@ class AssociationProxyInstance(object):
 
         # this was never asserted before but this should be made clear.
         if not isinstance(prop, orm.RelationshipProperty):
-            raise NotImplementedError(
-                "association proxy to a non-relationship "
-                "intermediary is not supported"
+            util.raise_(
+                NotImplementedError(
+                    "association proxy to a non-relationship "
+                    "intermediary is not supported"
+                ),
+                replace_context=None,
             )
 
         target_class = prop.mapper.class_
@@ -1323,10 +1330,13 @@ class _AssociationDict(_AssociationCollection):
                 try:
                     for k, v in seq_or_map:
                         self[k] = v
-                except ValueError:
-                    raise ValueError(
-                        "dictionary update sequence "
-                        "requires 2-element tuples"
+                except ValueError as err:
+                    util.raise_(
+                        ValueError(
+                            "dictionary update sequence "
+                            "requires 2-element tuples"
+                        ),
+                        replace_context=err,
                     )
 
         for key, value in kw:
index cafe69093a47dd419a4d7187f1f352c408ee3b88..cf67387e4384bcc13c3bafc96e6a5f1197e56c52 100644 (file)
@@ -504,9 +504,12 @@ class Result(object):
         """
         try:
             ret = self.one_or_none()
-        except orm_exc.MultipleResultsFound:
-            raise orm_exc.MultipleResultsFound(
-                "Multiple rows were found for one()"
+        except orm_exc.MultipleResultsFound as err:
+            util.raise_(
+                orm_exc.MultipleResultsFound(
+                    "Multiple rows were found for one()"
+                ),
+                replace_context=err,
             )
         else:
             if ret is None:
index c27907cdcf6bd5fb4341440fc4e23af64f872a4d..b8b6f8dc0dcd2c531540ca635b79a610d509e534 100644 (file)
@@ -398,6 +398,7 @@ Example usage::
 
 """
 from .. import exc
+from .. import util
 from ..sql import sqltypes
 from ..sql import visitors
 
@@ -422,10 +423,13 @@ def compiles(class_, *specs):
                 def _wrap_existing_dispatch(element, compiler, **kw):
                     try:
                         return existing_dispatch(element, compiler, **kw)
-                    except exc.UnsupportedCompilationError:
-                        raise exc.CompileError(
-                            "%s construct has no default "
-                            "compilation handler." % type(element)
+                    except exc.UnsupportedCompilationError as uce:
+                        util.raise_(
+                            exc.CompileError(
+                                "%s construct has no default "
+                                "compilation handler." % type(element)
+                            ),
+                            from_=uce,
                         )
 
                 existing.specs["default"] = _wrap_existing_dispatch
@@ -470,10 +474,13 @@ class _dispatcher(object):
         if not fn:
             try:
                 fn = self.specs["default"]
-            except KeyError:
-                raise exc.CompileError(
-                    "%s construct has no default "
-                    "compilation handler." % type(element)
+            except KeyError as ke:
+                util.raise_(
+                    exc.CompileError(
+                        "%s construct has no default "
+                        "compilation handler." % type(element)
+                    ),
+                    replace_context=ke,
                 )
 
         # if compilation includes add_to_result_map, collect add_to_result_map
index 7ff30b807f549c55158dde35f0d85af2ec483911..93e643cf5c3670cf649c61d8fe08794052f894c9 100644 (file)
@@ -298,12 +298,15 @@ class _class_resolver(object):
             else:
                 return x
         except NameError as n:
-            raise exc.InvalidRequestError(
-                "When initializing mapper %s, expression %r failed to "
-                "locate a name (%r). If this is a class name, consider "
-                "adding this relationship() to the %r class after "
-                "both dependent classes have been defined."
-                % (self.prop.parent, self.arg, n.args[0], self.cls)
+            util.raise_(
+                exc.InvalidRequestError(
+                    "When initializing mapper %s, expression %r failed to "
+                    "locate a name (%r). If this is a class name, consider "
+                    "adding this relationship() to the %r class after "
+                    "both dependent classes have been defined."
+                    % (self.prop.parent, self.arg, n.args[0], self.cls)
+                ),
+                from_=n,
             )
 
 
index f2e0501bb34b87b310b62f73373b6b5f5113422b..6eb7e11850e5d96c89166326c3cf6bc1e35e28cb 100644 (file)
@@ -223,7 +223,8 @@ The above query will render::
 """  # noqa
 from __future__ import absolute_import
 
-from sqlalchemy import inspect
+from .. import inspect
+from .. import util
 from ..ext.hybrid import hybrid_property
 from ..orm.attributes import flag_modified
 
@@ -301,9 +302,9 @@ class index_property(hybrid_property):  # noqa
                 self.datatype = dict
         self.onebased = onebased
 
-    def _fget_default(self):
+    def _fget_default(self, err=None):
         if self.default == self._NO_DEFAULT_ARGUMENT:
-            raise AttributeError(self.attr_name)
+            util.raise_(AttributeError(self.attr_name), replace_context=err)
         else:
             return self.default
 
@@ -314,8 +315,8 @@ class index_property(hybrid_property):  # noqa
             return self._fget_default()
         try:
             value = column_value[self.index]
-        except (KeyError, IndexError):
-            return self._fget_default()
+        except (KeyError, IndexError) as err:
+            return self._fget_default(err)
         else:
             return value
 
@@ -337,8 +338,8 @@ class index_property(hybrid_property):  # noqa
             raise AttributeError(self.attr_name)
         try:
             del column_value[self.index]
-        except KeyError:
-            raise AttributeError(self.attr_name)
+        except KeyError as err:
+            util.raise_(AttributeError(self.attr_name), replace_context=err)
         else:
             setattr(instance, attr_name, column_value)
             flag_modified(instance, attr_name)
index 66a18da9923de39f94da849a714cf6dbc2abd7a9..a959b0a4008da9efd80d3c86f3d6c6c04d957de0 100644 (file)
@@ -231,16 +231,19 @@ class QueryableAttribute(
     def __getattr__(self, key):
         try:
             return getattr(self.comparator, key)
-        except AttributeError:
-            raise AttributeError(
-                "Neither %r object nor %r object associated with %s "
-                "has an attribute %r"
-                % (
-                    type(self).__name__,
-                    type(self.comparator).__name__,
-                    self,
-                    key,
-                )
+        except AttributeError as err:
+            util.raise_(
+                AttributeError(
+                    "Neither %r object nor %r object associated with %s "
+                    "has an attribute %r"
+                    % (
+                        type(self).__name__,
+                        type(self.comparator).__name__,
+                        self,
+                        key,
+                    )
+                ),
+                replace_context=err,
             )
 
     def __str__(self):
@@ -373,31 +376,39 @@ def create_proxied_attribute(descriptor):
             comparator."""
             try:
                 return getattr(descriptor, attribute)
-            except AttributeError:
+            except AttributeError as err:
                 if attribute == "comparator":
-                    raise AttributeError("comparator")
+                    util.raise_(
+                        AttributeError("comparator"), replace_context=err
+                    )
                 try:
                     # comparator itself might be unreachable
                     comparator = self.comparator
-                except AttributeError:
-                    raise AttributeError(
-                        "Neither %r object nor unconfigured comparator "
-                        "object associated with %s has an attribute %r"
-                        % (type(descriptor).__name__, self, attribute)
+                except AttributeError as err2:
+                    util.raise_(
+                        AttributeError(
+                            "Neither %r object nor unconfigured comparator "
+                            "object associated with %s has an attribute %r"
+                            % (type(descriptor).__name__, self, attribute)
+                        ),
+                        replace_context=err2,
                     )
                 else:
                     try:
                         return getattr(comparator, attribute)
-                    except AttributeError:
-                        raise AttributeError(
-                            "Neither %r object nor %r object "
-                            "associated with %s has an attribute %r"
-                            % (
-                                type(descriptor).__name__,
-                                type(comparator).__name__,
-                                self,
-                                attribute,
-                            )
+                    except AttributeError as err3:
+                        util.raise_(
+                            AttributeError(
+                                "Neither %r object nor %r object "
+                                "associated with %s has an attribute %r"
+                                % (
+                                    type(descriptor).__name__,
+                                    type(comparator).__name__,
+                                    self,
+                                    attribute,
+                                )
+                            ),
+                            replace_context=err3,
                         )
 
     Proxy.__name__ = type(descriptor).__name__ + "Proxy"
@@ -713,12 +724,15 @@ class AttributeImpl(object):
                 elif value is ATTR_WAS_SET:
                     try:
                         return dict_[key]
-                    except KeyError:
+                    except KeyError as err:
                         # TODO: no test coverage here.
-                        raise KeyError(
-                            "Deferred loader for attribute "
-                            "%r failed to populate "
-                            "correctly" % key
+                        util.raise_(
+                            KeyError(
+                                "Deferred loader for attribute "
+                                "%r failed to populate "
+                                "correctly" % key
+                            ),
+                            replace_context=err,
                         )
                 elif value is not ATTR_EMPTY:
                     return self.set_committed_value(state, dict_, value)
index 571107a380a2228b3675190a224e621bf6a0ca90..a31745aec278f3cfa0254ce4adf255aa0bed9c3e 100644 (file)
@@ -387,9 +387,12 @@ def _entity_descriptor(entity, key):
 
     try:
         return getattr(entity, key)
-    except AttributeError:
-        raise sa_exc.InvalidRequestError(
-            "Entity '%s' has no property '%s'" % (description, key)
+    except AttributeError as err:
+        util.raise_(
+            sa_exc.InvalidRequestError(
+                "Entity '%s' has no property '%s'" % (description, key)
+            ),
+            replace_context=err,
         )
 
 
index f75c7d3bac27ad5c770139212e74ee632a6e7f46..57c192a5d3d97d4d356a7a0de5cf563b6f564f4d 100644 (file)
@@ -557,12 +557,15 @@ class StrategizedProperty(MapperProperty):
         try:
             return self._strategies[key]
         except KeyError:
-            cls = self._strategy_lookup(self, *key)
-            # this previously was setting self._strategies[cls], that's
-            # a bad idea; should use strategy key at all times because every
-            # strategy has multiple keys at this point
-            self._strategies[key] = strategy = cls(self, key)
-            return strategy
+            pass
+
+        # run outside to prevent transfer of exception context
+        cls = self._strategy_lookup(self, *key)
+        # this previously was setting self._strategies[cls], that's
+        # a bad idea; should use strategy key at all times because every
+        # strategy has multiple keys at this point
+        self._strategies[key] = strategy = cls(self, key)
+        return strategy
 
     def setup(self, context, query_entity, path, adapter, **kwargs):
         loader = self._get_context_loader(context, path)
index 193980e6c32a5a19f581b3b433a58a5eb52c5d15..d943ebb19097c62028ca640e66114858d377aaf9 100644 (file)
@@ -99,9 +99,9 @@ def instances(query, cursor, context):
 
             if not query._yield_per:
                 break
-    except Exception as err:
-        cursor.close()
-        util.raise_from_cause(err)
+    except Exception:
+        with util.safe_reraise():
+            cursor.close()
 
 
 @util.dependencies("sqlalchemy.orm.query")
index 0d87a9c406cb6d3f20e5e400ee73d22043840397..91e3251e2c29e8694cf06c0f2908b8877a4d6ad7 100644 (file)
@@ -1483,11 +1483,14 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr):
                 # it to mapped ColumnProperty
                 try:
                     self.polymorphic_on = self._props[self.polymorphic_on]
-                except KeyError:
-                    raise sa_exc.ArgumentError(
-                        "Can't determine polymorphic_on "
-                        "value '%s' - no attribute is "
-                        "mapped to this name." % self.polymorphic_on
+                except KeyError as err:
+                    util.raise_(
+                        sa_exc.ArgumentError(
+                            "Can't determine polymorphic_on "
+                            "value '%s' - no attribute is "
+                            "mapped to this name." % self.polymorphic_on
+                        ),
+                        replace_context=err,
                     )
 
             if self.polymorphic_on in self._columntoproperty:
@@ -1987,9 +1990,12 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr):
 
         try:
             return self._props[key]
-        except KeyError:
-            raise sa_exc.InvalidRequestError(
-                "Mapper '%s' has no property '%s'" % (self, key)
+        except KeyError as err:
+            util.raise_(
+                sa_exc.InvalidRequestError(
+                    "Mapper '%s' has no property '%s'" % (self, key)
+                ),
+                replace_context=err,
             )
 
     def get_property_by_column(self, column):
index 3b274a3893938fb862aba9b497ba9711dadaec5b..46c84d4bda826e5403f382aa5d4fbdcc9ad6f793 100644 (file)
@@ -1635,9 +1635,12 @@ def _sort_states(mapper, states):
             persistent, key=mapper._persistent_sortkey_fn
         )
     except TypeError as err:
-        raise sa_exc.InvalidRequestError(
-            "Could not sort objects by primary key; primary key "
-            "values must be sortable in Python (was: %s)" % err
+        util.raise_(
+            sa_exc.InvalidRequestError(
+                "Could not sort objects by primary key; primary key "
+                "values must be sortable in Python (was: %s)" % err
+            ),
+            replace_context=err,
         )
     return (
         sorted(pending, key=operator.attrgetter("insert_order"))
@@ -1681,10 +1684,13 @@ class BulkUD(object):
     def _factory(cls, lookup, synchronize_session, *arg):
         try:
             klass = lookup[synchronize_session]
-        except KeyError:
-            raise sa_exc.ArgumentError(
-                "Valid strategies for session synchronization "
-                "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+        except KeyError as err:
+            util.raise_(
+                sa_exc.ArgumentError(
+                    "Valid strategies for session synchronization "
+                    "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+                ),
+                replace_context=err,
             )
         else:
             return klass(*arg)
@@ -1768,10 +1774,13 @@ class BulkEvaluate(BulkUD):
             self._additional_evaluators(evaluator_compiler)
 
         except evaluator.UnevaluatableError as err:
-            raise sa_exc.InvalidRequestError(
-                'Could not evaluate current criteria in Python: "%s". '
-                "Specify 'fetch' or False for the "
-                "synchronize_session parameter." % err
+            util.raise_(
+                sa_exc.InvalidRequestError(
+                    'Could not evaluate current criteria in Python: "%s". '
+                    "Specify 'fetch' or False for the "
+                    "synchronize_session parameter." % err
+                ),
+                from_=err,
             )
 
         # TODO: detect when the where clause is a trivial primary key match
index d237aa3bf244a3caa63841e0aff6d8938bf4e4ba..e29e6eeeeb03a480fd12315a0f5813aa9f75e0d1 100644 (file)
@@ -1019,15 +1019,18 @@ class Query(Generative):
                     for prop in mapper._identity_key_props
                 )
 
-            except KeyError:
-                raise sa_exc.InvalidRequestError(
-                    "Incorrect names of values in identifier to formulate "
-                    "primary key for query.get(); primary key attribute names"
-                    " are %s"
-                    % ",".join(
-                        "'%s'" % prop.key
-                        for prop in mapper._identity_key_props
-                    )
+            except KeyError as err:
+                util.raise_(
+                    sa_exc.InvalidRequestError(
+                        "Incorrect names of values in identifier to formulate "
+                        "primary key for query.get(); primary key attribute "
+                        "names are %s"
+                        % ",".join(
+                            "'%s'" % prop.key
+                            for prop in mapper._identity_key_props
+                        )
+                    ),
+                    replace_context=err,
                 )
 
         if (
@@ -3292,9 +3295,12 @@ class Query(Generative):
         """
         try:
             ret = self.one_or_none()
-        except orm_exc.MultipleResultsFound:
-            raise orm_exc.MultipleResultsFound(
-                "Multiple rows were found for one()"
+        except orm_exc.MultipleResultsFound as err:
+            util.raise_(
+                orm_exc.MultipleResultsFound(
+                    "Multiple rows were found for one()"
+                ),
+                replace_context=err,
             )
         else:
             if ret is None:
index b82a3d2712efe78bf2d5a65a8d8152fe35ecc687..2995baf5fc6ecd1111b438fa4c8dfa6e3e1be8ec 100644 (file)
@@ -2484,50 +2484,64 @@ class JoinCondition(object):
                         a_subset=self.parent_local_selectable,
                         consider_as_foreign_keys=consider_as_foreign_keys,
                     )
-        except sa_exc.NoForeignKeysError:
+        except sa_exc.NoForeignKeysError as nfe:
             if self.secondary is not None:
-                raise sa_exc.NoForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are no foreign keys "
-                    "linking these tables via secondary table '%s'.  "
-                    "Ensure that referencing columns are associated "
-                    "with a ForeignKey or ForeignKeyConstraint, or "
-                    "specify 'primaryjoin' and 'secondaryjoin' "
-                    "expressions." % (self.prop, self.secondary)
+                util.raise_(
+                    sa_exc.NoForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are no foreign keys "
+                        "linking these tables via secondary table '%s'.  "
+                        "Ensure that referencing columns are associated "
+                        "with a ForeignKey or ForeignKeyConstraint, or "
+                        "specify 'primaryjoin' and 'secondaryjoin' "
+                        "expressions." % (self.prop, self.secondary)
+                    ),
+                    from_=nfe,
                 )
             else:
-                raise sa_exc.NoForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are no foreign keys "
-                    "linking these tables.  "
-                    "Ensure that referencing columns are associated "
-                    "with a ForeignKey or ForeignKeyConstraint, or "
-                    "specify a 'primaryjoin' expression." % self.prop
+                util.raise_(
+                    sa_exc.NoForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are no foreign keys "
+                        "linking these tables.  "
+                        "Ensure that referencing columns are associated "
+                        "with a ForeignKey or ForeignKeyConstraint, or "
+                        "specify a 'primaryjoin' expression." % self.prop
+                    ),
+                    from_=nfe,
                 )
-        except sa_exc.AmbiguousForeignKeysError:
+        except sa_exc.AmbiguousForeignKeysError as afe:
             if self.secondary is not None:
-                raise sa_exc.AmbiguousForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are multiple foreign key "
-                    "paths linking the tables via secondary table '%s'.  "
-                    "Specify the 'foreign_keys' "
-                    "argument, providing a list of those columns which "
-                    "should be counted as containing a foreign key "
-                    "reference from the secondary table to each of the "
-                    "parent and child tables." % (self.prop, self.secondary)
+                util.raise_(
+                    sa_exc.AmbiguousForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are multiple foreign key "
+                        "paths linking the tables via secondary table '%s'.  "
+                        "Specify the 'foreign_keys' "
+                        "argument, providing a list of those columns which "
+                        "should be counted as containing a foreign key "
+                        "reference from the secondary table to each of the "
+                        "parent and child tables."
+                        % (self.prop, self.secondary)
+                    ),
+                    from_=afe,
                 )
             else:
-                raise sa_exc.AmbiguousForeignKeysError(
-                    "Could not determine join "
-                    "condition between parent/child tables on "
-                    "relationship %s - there are multiple foreign key "
-                    "paths linking the tables.  Specify the "
-                    "'foreign_keys' argument, providing a list of those "
-                    "columns which should be counted as containing a "
-                    "foreign key reference to the parent table." % self.prop
+                util.raise_(
+                    sa_exc.AmbiguousForeignKeysError(
+                        "Could not determine join "
+                        "condition between parent/child tables on "
+                        "relationship %s - there are multiple foreign key "
+                        "paths linking the tables.  Specify the "
+                        "'foreign_keys' argument, providing a list of those "
+                        "columns which should be counted as containing a "
+                        "foreign key reference to the parent table."
+                        % self.prop
+                    ),
+                    from_=afe,
                 )
 
     @property
index 0950339516cf64d77d3c37234ae86e7d2acccbfe..74e5464835899fca7880f462f7019e43c5f81c34 100644 (file)
@@ -575,7 +575,7 @@ class SessionTransaction(object):
             self._parent._rollback_exception = sys.exc_info()[1]
 
         if rollback_err:
-            util.reraise(*rollback_err)
+            util.raise_(rollback_err[1], with_traceback=rollback_err[2])
 
         sess.dispatch.after_soft_rollback(sess, self)
 
@@ -1362,10 +1362,13 @@ class Session(_SessionClassMethods):
     def _add_bind(self, key, bind):
         try:
             insp = inspect(key)
-        except sa_exc.NoInspectionAvailable:
+        except sa_exc.NoInspectionAvailable as err:
             if not isinstance(key, type):
-                raise sa_exc.ArgumentError(
-                    "Not an acceptable bind target: %s" % key
+                util.raise_(
+                    sa_exc.ArgumentError(
+                        "Not an acceptable bind target: %s" % key
+                    ),
+                    replace_context=err,
                 )
             else:
                 self.__binds[key] = bind
@@ -1515,9 +1518,11 @@ class Session(_SessionClassMethods):
         if mapper is not None:
             try:
                 mapper = inspect(mapper)
-            except sa_exc.NoInspectionAvailable:
+            except sa_exc.NoInspectionAvailable as err:
                 if isinstance(mapper, type):
-                    raise exc.UnmappedClassError(mapper)
+                    util.raise_(
+                        exc.UnmappedClassError(mapper), replace_context=err,
+                    )
                 else:
                     raise
 
@@ -1656,7 +1661,7 @@ class Session(_SessionClassMethods):
                     "consider using a session.no_autoflush block if this "
                     "flush is occurring prematurely"
                 )
-                util.raise_from_cause(e)
+                util.raise_(e, with_traceback=sys.exc_info[2])
 
     def refresh(
         self,
@@ -1711,8 +1716,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
 
         self._expire_state(state, attribute_names)
 
@@ -1817,8 +1824,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
         self._expire_state(state, attribute_names)
 
     def _expire_state(self, state, attribute_names):
@@ -1872,8 +1881,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
         if state.session_id is not self.hash_key:
             raise sa_exc.InvalidRequestError(
                 "Instance %s is not present in this Session" % state_str(state)
@@ -2024,8 +2035,10 @@ class Session(_SessionClassMethods):
 
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
 
         self._save_or_update_state(state)
 
@@ -2059,8 +2072,10 @@ class Session(_SessionClassMethods):
 
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
 
         self._delete_impl(state, instance, head=True)
 
@@ -2490,8 +2505,10 @@ class Session(_SessionClassMethods):
         """
         try:
             state = attributes.instance_state(instance)
-        except exc.NO_STATE:
-            raise exc.UnmappedInstanceError(instance)
+        except exc.NO_STATE as err:
+            util.raise_(
+                exc.UnmappedInstanceError(instance), replace_context=err,
+            )
         return self._contains_state(state)
 
     def __iter__(self):
@@ -2586,8 +2603,11 @@ class Session(_SessionClassMethods):
             for o in objects:
                 try:
                     state = attributes.instance_state(o)
-                except exc.NO_STATE:
-                    raise exc.UnmappedInstanceError(o)
+
+                except exc.NO_STATE as err:
+                    util.raise_(
+                        exc.UnmappedInstanceError(o), replace_context=err,
+                    )
                 objset.add(state)
         else:
             objset = None
@@ -3450,8 +3470,10 @@ def object_session(instance):
 
     try:
         state = attributes.instance_state(instance)
-    except exc.NO_STATE:
-        raise exc.UnmappedInstanceError(instance)
+    except exc.NO_STATE as err:
+        util.raise_(
+            exc.UnmappedInstanceError(instance), replace_context=err,
+        )
     else:
         return _state_session(state)
 
index 0c72f3b37dcec3bceb5ac8059a1611e6ff8f909d..4f7d996d4f0871cecc16ba05e88ac55dc79fdd5c 100644 (file)
@@ -252,11 +252,14 @@ class Load(HasCacheKey, Generative, MapperOption):
                 # use getattr on the class to work around
                 # synonyms, hybrids, etc.
                 attr = getattr(ent.class_, attr)
-            except AttributeError:
+            except AttributeError as err:
                 if raiseerr:
-                    raise sa_exc.ArgumentError(
-                        'Can\'t find property named "%s" on '
-                        "%s in this Query." % (attr, ent)
+                    util.raise_(
+                        sa_exc.ArgumentError(
+                            'Can\'t find property named "%s" on '
+                            "%s in this Query." % (attr, ent)
+                        ),
+                        replace_context=err,
                     )
                 else:
                     return None
index 198e64f4f2a42989720a31441240c82adbe9bba9..ceaf54e5d332e1fe3436cd81ce9d0b0013afa2f9 100644 (file)
@@ -13,6 +13,7 @@ between instances based on join conditions.
 from . import attributes
 from . import exc
 from . import util as orm_util
+from .. import util
 
 
 def populate(
@@ -34,15 +35,15 @@ def populate(
             value = source.manager[prop.key].impl.get(
                 source, source_dict, attributes.PASSIVE_OFF
             )
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
 
         try:
             # inline of dest_mapper._set_state_attr_by_column
             prop = dest_mapper._columntoproperty[r]
             dest.manager[prop.key].impl.set(dest, dest_dict, value, None)
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err)
 
         # technically the "r.primary_key" check isn't
         # needed here, but we check for this condition to limit
@@ -64,8 +65,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
         try:
             prop = source_mapper._columntoproperty[l]
             value = source_dict[prop.key]
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, source_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err)
 
         try:
             prop = source_mapper._columntoproperty[r]
@@ -88,8 +89,8 @@ def clear(dest, dest_mapper, synchronize_pairs):
             )
         try:
             dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, None, l, dest_mapper, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, None, l, dest_mapper, r, err)
 
 
 def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
@@ -101,8 +102,8 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
             value = source_mapper._get_state_attr_by_column(
                 source, source.dict, l, passive=attributes.PASSIVE_OFF
             )
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, None, r, err)
         dest[r.key] = value
         dest[old_prefix + r.key] = oldvalue
 
@@ -113,8 +114,8 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
             value = source_mapper._get_state_attr_by_column(
                 source, source.dict, l, passive=attributes.PASSIVE_OFF
             )
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, None, r, err)
 
         dict_[r.key] = value
 
@@ -127,8 +128,8 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             prop = source_mapper._columntoproperty[l]
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(False, source_mapper, l, None, r, err)
         history = uowcommit.get_attribute_history(
             source, prop.key, attributes.PASSIVE_NO_INITIALIZE
         )
@@ -139,22 +140,28 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
 
 
 def _raise_col_to_prop(
-    isdest, source_mapper, source_column, dest_mapper, dest_column
+    isdest, source_mapper, source_column, dest_mapper, dest_column, err
 ):
     if isdest:
-        raise exc.UnmappedColumnError(
-            "Can't execute sync rule for "
-            "destination column '%s'; mapper '%s' does not map "
-            "this column.  Try using an explicit `foreign_keys` "
-            "collection which does not include this column (or use "
-            "a viewonly=True relation)." % (dest_column, dest_mapper)
+        util.raise_(
+            exc.UnmappedColumnError(
+                "Can't execute sync rule for "
+                "destination column '%s'; mapper '%s' does not map "
+                "this column.  Try using an explicit `foreign_keys` "
+                "collection which does not include this column (or use "
+                "a viewonly=True relation)." % (dest_column, dest_mapper)
+            ),
+            replace_context=err,
         )
     else:
-        raise exc.UnmappedColumnError(
-            "Can't execute sync rule for "
-            "source column '%s'; mapper '%s' does not map this "
-            "column.  Try using an explicit `foreign_keys` "
-            "collection which does not include destination column "
-            "'%s' (or use a viewonly=True relation)."
-            % (source_column, source_mapper, dest_column)
+        util.raise_(
+            exc.UnmappedColumnError(
+                "Can't execute sync rule for "
+                "source column '%s'; mapper '%s' does not map this "
+                "column.  Try using an explicit `foreign_keys` "
+                "collection which does not include destination column "
+                "'%s' (or use a viewonly=True relation)."
+                % (source_column, source_mapper, dest_column)
+            ),
+            replace_context=err,
         )
index b53f0d7ddb05e08e2b2b83bd58447f8ae5709392..17d5ba15fde2c8fff7dd0d31261792207fefa593 100644 (file)
@@ -578,8 +578,8 @@ class _ConnectionRecord(object):
             self.connection = connection
             self.fresh = True
         except Exception as e:
-            pool.logger.debug("Error on connect(): %s", e)
-            raise
+            with util.safe_reraise():
+                pool.logger.debug("Error on connect(): %s", e)
         else:
             if first_connect_check:
                 pool.dispatch.first_connect.for_modify(
index 67f1564ec3c0e0957521c22cd872de39245b112d..8618d5e2aa60a1fb4ca4fa4f33a5443d4a1af6f2 100644 (file)
@@ -32,10 +32,13 @@ def str_to_datetime_processor_factory(regexp, type_):
         else:
             try:
                 m = rmatch(value)
-            except TypeError:
-                raise ValueError(
-                    "Couldn't parse %s string '%r' "
-                    "- value is not a string." % (type_.__name__, value)
+            except TypeError as err:
+                util.raise_(
+                    ValueError(
+                        "Couldn't parse %s string '%r' "
+                        "- value is not a string." % (type_.__name__, value)
+                    ),
+                    from_=err,
                 )
             if m is None:
                 raise ValueError(
index a7324c45fa7c053fdd61335f46302170f72681dd..2d336360f9e00e430d20fe6a042552fc759d39bf 100644 (file)
@@ -128,8 +128,8 @@ class _DialectArgView(util.collections_abc.MutableMapping):
     def _key(self, key):
         try:
             dialect, value_key = key.split("_", 1)
-        except ValueError:
-            raise KeyError(key)
+        except ValueError as err:
+            util.raise_(KeyError(key), replace_context=err)
         else:
             return dialect, value_key
 
@@ -138,17 +138,20 @@ class _DialectArgView(util.collections_abc.MutableMapping):
 
         try:
             opt = self.obj.dialect_options[dialect]
-        except exc.NoSuchModuleError:
-            raise KeyError(key)
+        except exc.NoSuchModuleError as err:
+            util.raise_(KeyError(key), replace_context=err)
         else:
             return opt[value_key]
 
     def __setitem__(self, key, value):
         try:
             dialect, value_key = self._key(key)
-        except KeyError:
-            raise exc.ArgumentError(
-                "Keys must be of the form <dialectname>_<argname>"
+        except KeyError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "Keys must be of the form <dialectname>_<argname>"
+                ),
+                replace_context=err,
             )
         else:
             self.obj.dialect_options[dialect][value_key] = value
@@ -634,17 +637,17 @@ class ColumnCollection(object):
     def __getitem__(self, key):
         try:
             return self._index[key]
-        except KeyError:
+        except KeyError as err:
             if isinstance(key, util.int_types):
-                raise IndexError(key)
+                util.raise_(IndexError(key), replace_context=err)
             else:
                 raise
 
     def __getattr__(self, key):
         try:
             return self._index[key]
-        except KeyError:
-            raise AttributeError(key)
+        except KeyError as err:
+            util.raise_(AttributeError(key), replace_context=err)
 
     def __contains__(self, key):
         if key not in self._index:
index b3bf4e93b99b320f37faef0dc47ad1012c2c102b..fc841bb4bea7d609b2900de50c0c2638d11394ba 100644 (file)
@@ -133,7 +133,13 @@ class RoleImpl(object):
         self._raise_for_expected(element, argname, resolved)
 
     def _raise_for_expected(
-        self, element, argname=None, resolved=None, advice=None, code=None
+        self,
+        element,
+        argname=None,
+        resolved=None,
+        advice=None,
+        code=None,
+        err=None,
     ):
         if argname:
             msg = "%s expected for argument %r; got %r." % (
@@ -147,7 +153,7 @@ class RoleImpl(object):
         if advice:
             msg += " " + advice
 
-        raise exc.ArgumentError(msg, code=code)
+        util.raise_(exc.ArgumentError(msg, code=code), replace_context=err)
 
 
 class _Deannotate(object):
@@ -201,16 +207,19 @@ class _ColumnCoercions(object):
 
 
 def _no_text_coercion(
-    element, argname=None, exc_cls=exc.ArgumentError, extra=None
+    element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
 ):
-    raise exc_cls(
-        "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
-        "explicitly declared as text(%(expr)r)"
-        % {
-            "expr": util.ellipses_string(element),
-            "argname": "for argument %s" % (argname,) if argname else "",
-            "extra": "%s " % extra if extra else "",
-        }
+    util.raise_(
+        exc_cls(
+            "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
+            "explicitly declared as text(%(expr)r)"
+            % {
+                "expr": util.ellipses_string(element),
+                "argname": "for argument %s" % (argname,) if argname else "",
+                "extra": "%s " % extra if extra else "",
+            }
+        ),
+        replace_context=err,
     )
 
 
@@ -290,8 +299,8 @@ class ExpressionElementImpl(
                 return elements.BindParameter(
                     name, element, type_, unique=True
                 )
-            except exc.ArgumentError:
-                self._raise_for_expected(element)
+            except exc.ArgumentError as err:
+                self._raise_for_expected(element, err=err)
 
 
 class BinaryElementImpl(
@@ -302,8 +311,8 @@ class BinaryElementImpl(
     ):
         try:
             return expr._bind_param(operator, element, type_=bindparam_type)
-        except exc.ArgumentError:
-            self._raise_for_expected(element)
+        except exc.ArgumentError as err:
+            self._raise_for_expected(element, err=err)
 
     def _post_coercion(self, resolved, expr, **kw):
         if (
index 9c1f50ce13e2dd789c1f096dc053dd27e8e08377..d31cf67f88000e3deec9b226e2f74843254804c9 100644 (file)
@@ -1074,7 +1074,7 @@ class SQLCompiler(Compiled):
                 col = only_froms[element.element]
             else:
                 col = with_cols[element.element]
-        except KeyError:
+        except KeyError as err:
             coercions._no_text_coercion(
                 element.element,
                 extra=(
@@ -1082,6 +1082,7 @@ class SQLCompiler(Compiled):
                     "GROUP BY / DISTINCT etc."
                 ),
                 exc_cls=exc.CompileError,
+                err=err,
             )
         else:
             kwargs["render_label_as_label"] = col
@@ -1671,8 +1672,11 @@ class SQLCompiler(Compiled):
         else:
             try:
                 opstring = OPERATORS[operator_]
-            except KeyError:
-                raise exc.UnsupportedCompilationError(self, operator_)
+            except KeyError as err:
+                util.raise_(
+                    exc.UnsupportedCompilationError(self, operator_),
+                    replace_context=err,
+                )
             else:
                 return self._generate_generic_binary(
                     binary, opstring, from_linter=from_linter, **kw
@@ -3286,11 +3290,12 @@ class DDLCompiler(Compiled):
                 if column.primary_key:
                     first_pk = True
             except exc.CompileError as ce:
-                util.raise_from_cause(
+                util.raise_(
                     exc.CompileError(
                         util.u("(in table '%s', column '%s'): %s")
                         % (table.description, column.name, ce.args[0])
-                    )
+                    ),
+                    from_=ce,
                 )
 
         const = self.create_table_constraints(
index 31bcc34a400ef981f12e6a161a5673a83d37242c..5a2095604cf049cdfb9a730617fc6facf6c48bdd 100644 (file)
@@ -801,7 +801,7 @@ class SchemaDropper(DDLBase):
                 )
                 collection = [(t, ()) for t in unsorted_tables]
             else:
-                util.raise_from_cause(
+                util.raise_(
                     exc.CircularDependencyError(
                         err2.args[0],
                         err2.cycles,
@@ -818,7 +818,8 @@ class SchemaDropper(DDLBase):
                                 sorted([t.fullname for t in err2.cycles])
                             )
                         ),
-                    )
+                    ),
+                    from_=err2,
                 )
 
         seq_coll = [
index df690c383b861cb42b238480bec934b66e4c1568..d0babb1be0b154249c4ee1fba1fbed84b4835f84 100644 (file)
@@ -747,10 +747,13 @@ class ColumnElement(
     def comparator(self):
         try:
             comparator_factory = self.type.comparator_factory
-        except AttributeError:
-            raise TypeError(
-                "Object %r associated with '.type' attribute "
-                "is not a TypeEngine class or object" % self.type
+        except AttributeError as err:
+            util.raise_(
+                TypeError(
+                    "Object %r associated with '.type' attribute "
+                    "is not a TypeEngine class or object" % self.type
+                ),
+                replace_context=err,
             )
         else:
             return comparator_factory(self)
@@ -758,10 +761,17 @@ class ColumnElement(
     def __getattr__(self, key):
         try:
             return getattr(self.comparator, key)
-        except AttributeError:
-            raise AttributeError(
-                "Neither %r object nor %r object has an attribute %r"
-                % (type(self).__name__, type(self.comparator).__name__, key)
+        except AttributeError as err:
+            util.raise_(
+                AttributeError(
+                    "Neither %r object nor %r object has an attribute %r"
+                    % (
+                        type(self).__name__,
+                        type(self.comparator).__name__,
+                        key,
+                    )
+                ),
+                replace_context=err,
             )
 
     def operate(self, op, *other, **kwargs):
@@ -1742,10 +1752,13 @@ class TextClause(
                 # a unique/anonymous key in any case, so use the _orig_key
                 # so that a text() construct can support unique parameters
                 existing = new_params[bind._orig_key]
-            except KeyError:
-                raise exc.ArgumentError(
-                    "This text() construct doesn't define a "
-                    "bound parameter named %r" % bind._orig_key
+            except KeyError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "This text() construct doesn't define a "
+                        "bound parameter named %r" % bind._orig_key
+                    ),
+                    replace_context=err,
                 )
             else:
                 new_params[existing._orig_key] = bind
@@ -1753,10 +1766,13 @@ class TextClause(
         for key, value in names_to_values.items():
             try:
                 existing = new_params[key]
-            except KeyError:
-                raise exc.ArgumentError(
-                    "This text() construct doesn't define a "
-                    "bound parameter named %r" % key
+            except KeyError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "This text() construct doesn't define a "
+                        "bound parameter named %r" % key
+                    ),
+                    replace_context=err,
                 )
             else:
                 new_params[key] = existing._with_value(value)
@@ -3665,9 +3681,12 @@ class Over(ColumnElement):
         else:
             try:
                 lower = int(range_[0])
-            except ValueError:
-                raise exc.ArgumentError(
-                    "Integer or None expected for range value"
+            except ValueError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "Integer or None expected for range value"
+                    ),
+                    replace_context=err,
                 )
             else:
                 if lower == 0:
@@ -3678,9 +3697,12 @@ class Over(ColumnElement):
         else:
             try:
                 upper = int(range_[1])
-            except ValueError:
-                raise exc.ArgumentError(
-                    "Integer or None expected for range value"
+            except ValueError as err:
+                util.raise_(
+                    exc.ArgumentError(
+                        "Integer or None expected for range value"
+                    ),
+                    replace_context=err,
                 )
             else:
                 if upper == 0:
index e6d3a6b059c25c7e2d24b059da3f0307e2a67916..5445a1bceabfb598598b36fb11ee62fc2ba9c659 100644 (file)
@@ -107,12 +107,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
             if item is not None:
                 try:
                     spwd = item._set_parent_with_dispatch
-                except AttributeError:
-                    util.raise_from_cause(
+                except AttributeError as err:
+                    util.raise_(
                         exc.ArgumentError(
                             "'SchemaItem' object, such as a 'Column' or a "
                             "'Constraint' expected, got %r" % item
-                        )
+                        ),
+                        replace_context=err,
                     )
                 else:
                     spwd(self)
@@ -1569,15 +1570,16 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
                 _proxies=[self],
                 *fk
             )
-        except TypeError:
-            util.raise_from_cause(
+        except TypeError as err:
+            util.raise_(
                 TypeError(
                     "Could not create a copy of this %r object.  "
                     "Ensure the class includes a _constructor() "
                     "attribute or method which accepts the "
                     "standard Column constructor arguments, or "
                     "references the Column class itself." % self.__class__
-                )
+                ),
+                from_=err,
             )
 
         c.table = selectable
@@ -3187,10 +3189,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
         try:
             ColumnCollectionConstraint._set_parent(self, table)
         except KeyError as ke:
-            raise exc.ArgumentError(
-                "Can't create ForeignKeyConstraint "
-                "on table '%s': no column "
-                "named '%s' is present." % (table.description, ke.args[0])
+            util.raise_(
+                exc.ArgumentError(
+                    "Can't create ForeignKeyConstraint "
+                    "on table '%s': no column "
+                    "named '%s' is present." % (table.description, ke.args[0])
+                ),
+                from_=ke,
             )
 
         for col, fk in zip(self.columns, self.elements):
index b8d88e160e19fc2957652b0d514ecded8c984208..b972c13be63c96f010889e72d15fe3010d6a6f82 100644 (file)
@@ -2620,10 +2620,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
             return None
         try:
             value = clause._limit_offset_value
-        except AttributeError:
-            raise exc.CompileError(
-                "This SELECT structure does not use a simple "
-                "integer value for %s" % attrname
+        except AttributeError as err:
+            util.raise_(
+                exc.CompileError(
+                    "This SELECT structure does not use a simple "
+                    "integer value for %s" % attrname
+                ),
+                replace_context=err,
             )
         else:
             return util.asint(value)
@@ -3489,10 +3492,13 @@ class Select(
 
         try:
             cols_present = bool(columns)
-        except TypeError:
-            raise exc.ArgumentError(
-                "columns argument to select() must "
-                "be a Python list or other iterable"
+        except TypeError as err:
+            util.raise_(
+                exc.ArgumentError(
+                    "columns argument to select() must "
+                    "be a Python list or other iterable"
+                ),
+                from_=err,
             )
 
         if cols_present:
index 22c80cc91ee6fda00dda3dd98da03fbf55b5bc95..e4a029a3e39b2f6a7c5de25fbe5259cb499c6e85 100644 (file)
@@ -1462,7 +1462,7 @@ class Enum(Emulated, String, SchemaType):
     def _db_value_for_elem(self, elem):
         try:
             return self._valid_lookup[elem]
-        except KeyError:
+        except KeyError as err:
             # for unknown string values, we return as is.  While we can
             # validate these if we wanted, that does not allow for lesser-used
             # end-user use cases, such as using a LIKE comparison with an enum,
@@ -1476,8 +1476,11 @@ class Enum(Emulated, String, SchemaType):
             ):
                 return elem
             else:
-                raise LookupError(
-                    '"%s" is not among the defined enum values' % elem
+                util.raise_(
+                    LookupError(
+                        '"%s" is not among the defined enum values' % elem
+                    ),
+                    replace_context=err,
                 )
 
     class Comparator(String.Comparator):
@@ -1496,9 +1499,12 @@ class Enum(Emulated, String, SchemaType):
     def _object_value_for_elem(self, elem):
         try:
             return self._object_lookup[elem]
-        except KeyError:
-            raise LookupError(
-                '"%s" is not among the defined enum values' % elem
+        except KeyError as err:
+            util.raise_(
+                LookupError(
+                    '"%s" is not among the defined enum values' % elem
+                ),
+                replace_context=err,
             )
 
     def __repr__(self):
index c6c860844a56e2dce8d4ec7075d27f8eb951f31d..739f9619549c91d25327a48f02a089caedacaba2 100644 (file)
@@ -479,9 +479,12 @@ class TypeEngine(Traversible):
         try:
             return dialect._type_memos[self]["literal"]
         except KeyError:
-            d = self._dialect_info(dialect)
-            d["literal"] = lp = d["impl"].literal_processor(dialect)
-            return lp
+            pass
+        # avoid KeyError context coming into literal_processor() function
+        # raises
+        d = self._dialect_info(dialect)
+        d["literal"] = lp = d["impl"].literal_processor(dialect)
+        return lp
 
     def _cached_bind_processor(self, dialect):
         """Return a dialect-specific bind processor for this type."""
@@ -489,9 +492,12 @@ class TypeEngine(Traversible):
         try:
             return dialect._type_memos[self]["bind"]
         except KeyError:
-            d = self._dialect_info(dialect)
-            d["bind"] = bp = d["impl"].bind_processor(dialect)
-            return bp
+            pass
+        # avoid KeyError context coming into bind_processor() function
+        # raises
+        d = self._dialect_info(dialect)
+        d["bind"] = bp = d["impl"].bind_processor(dialect)
+        return bp
 
     def _cached_result_processor(self, dialect, coltype):
         """Return a dialect-specific result processor for this type."""
@@ -499,21 +505,27 @@ class TypeEngine(Traversible):
         try:
             return dialect._type_memos[self][coltype]
         except KeyError:
-            d = self._dialect_info(dialect)
-            # key assumption: DBAPI type codes are
-            # constants.  Else this dictionary would
-            # grow unbounded.
-            d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
-            return rp
+            pass
+        # avoid KeyError context coming into result_processor() function
+        # raises
+        d = self._dialect_info(dialect)
+        # key assumption: DBAPI type codes are
+        # constants.  Else this dictionary would
+        # grow unbounded.
+        d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
+        return rp
 
     def _cached_custom_processor(self, dialect, key, fn):
         try:
             return dialect._type_memos[self][key]
         except KeyError:
-            d = self._dialect_info(dialect)
-            impl = d["impl"]
-            d[key] = result = fn(impl)
-            return result
+            pass
+        # avoid KeyError context coming into fn() function
+        # raises
+        d = self._dialect_info(dialect)
+        impl = d["impl"]
+        d[key] = result = fn(impl)
+        return result
 
     def _dialect_info(self, dialect):
         """Return a dialect-specific registry which
index 77e6b53a8ad9a6deaaf97cd820e058eed8b6eb3b..fda48c65743e6e78a80d163325f591fa4d952706 100644 (file)
@@ -62,9 +62,10 @@ def _generate_compiler_dispatch(cls):
         "def _compiler_dispatch(self, visitor, **kw):\n"
         "    try:\n"
         "        meth = visitor.visit_%(name)s\n"
-        "    except AttributeError:\n"
-        "        util.raise_from_cause(\n"
-        "            exc.UnsupportedCompilationError(visitor, cls))\n"
+        "    except AttributeError as err:\n"
+        "        util.raise_(\n"
+        "            exc.UnsupportedCompilationError(visitor, cls), \n"
+        "            replace_context=err)\n"
         "    else:\n"
         "        return meth(self, **kw)\n"
     ) % {"name": visit_name}
index 5829015798af0febb3075ee19c87ecf487c2cf3d..79b7f9eb3d6226ff6b43a71339dc789f2a1d97cd 100644 (file)
@@ -9,7 +9,9 @@
 from . import config  # noqa
 from . import mock  # noqa
 from .assertions import assert_raises  # noqa
+from .assertions import assert_raises_context_ok  # noqa
 from .assertions import assert_raises_message  # noqa
+from .assertions import assert_raises_message_context_ok  # noqa
 from .assertions import assert_raises_return  # noqa
 from .assertions import AssertsCompiledSQL  # noqa
 from .assertions import AssertsExecutionResults  # noqa
index f5325b0cb15cf015ebdaed1499d25a914592fed9..c97202516bf083a58b08b9e0955b7e890b9b012e 100644 (file)
@@ -9,6 +9,7 @@ from __future__ import absolute_import
 
 import contextlib
 import re
+import sys
 import warnings
 
 from . import assertsql
@@ -258,41 +259,80 @@ def eq_ignore_whitespace(a, b, msg=None):
     assert a == b, msg or "%r != %r" % (a, b)
 
 
+def _assert_proper_exception_context(exception):
+    """assert that any exception we're catching does not have a __context__
+    without a __cause__, and that __suppress_context__ is never set.
+
+    Python 3 will report nested as exceptions as "during the handling of
+    error X, error Y occurred". That's not what we want to do.  we want
+    these exceptions in a cause chain.
+
+    """
+
+    if not util.py3k:
+        return
+
+    if (
+        exception.__context__ is not exception.__cause__
+        and not exception.__suppress_context__
+    ):
+        assert False, (
+            "Exception %r was correctly raised but did not set a cause, "
+            "within context %r as its cause."
+            % (exception, exception.__context__)
+        )
+
+
 def assert_raises(except_cls, callable_, *args, **kw):
-    try:
-        callable_(*args, **kw)
-        success = False
-    except except_cls:
-        success = True
+    _assert_raises(except_cls, callable_, args, kw, check_context=True)
 
-    # assert outside the block so it works for AssertionError too !
-    assert success, "Callable did not raise an exception"
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+    _assert_raises(
+        except_cls, callable_, args, kw,
+    )
 
 
 def assert_raises_return(except_cls, callable_, *args, **kw):
+    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+    _assert_raises(
+        except_cls, callable_, args, kwargs, msg=msg, check_context=True
+    )
+
+
+def assert_raises_message_context_ok(
+    except_cls, msg, callable_, *args, **kwargs
+):
+    _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+    except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
     ret_err = None
+    if check_context:
+        are_we_already_in_a_traceback = sys.exc_info()[0]
     try:
-        callable_(*args, **kw)
+        callable_(*args, **kwargs)
         success = False
     except except_cls as err:
-        success = True
         ret_err = err
+        success = True
+        if msg is not None:
+            assert re.search(
+                msg, util.text_type(err), re.UNICODE
+            ), "%r !~ %s" % (msg, err,)
+        if check_context and not are_we_already_in_a_traceback:
+            _assert_proper_exception_context(err)
+        print(util.text_type(err).encode("utf-8"))
 
     # assert outside the block so it works for AssertionError too !
     assert success, "Callable did not raise an exception"
-    return ret_err
 
-
-def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
-    try:
-        callable_(*args, **kwargs)
-        assert False, "Callable did not raise an exception"
-    except except_cls as e:
-        assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
-            msg,
-            e,
-        )
-        print(util.text_type(e).encode("utf-8"))
+    return ret_err
 
 
 class AssertsCompiledSQL(object):
index 0c05bf9e9b2390e81e064a4961605ba4dfe887ad..1a23ebf416ec9922690cc7850c6f1c9c8d092b70 100644 (file)
@@ -9,6 +9,7 @@
 import contextlib
 import operator
 import re
+import sys
 
 from . import config
 from .. import util
@@ -145,7 +146,7 @@ class compound(object):
                 )
                 break
         else:
-            util.raise_from_cause(ex)
+            util.raise_(ex, with_traceback=sys.exc_info()[2])
 
     def _expect_success(self, config, name="block"):
         if not self.fails:
index b0ceb802a453dfd8957ff07a38cd8efdf761f8fe..660a0e97668d2b7b4cf6e42a16d6d2c8baa4f725 100644 (file)
@@ -68,6 +68,7 @@ from .compat import py33  # noqa
 from .compat import py36  # noqa
 from .compat import py3k  # noqa
 from .compat import quote_plus  # noqa
+from .compat import raise_  # noqa
 from .compat import raise_from_cause  # noqa
 from .compat import reduce  # noqa
 from .compat import reraise  # noqa
index 8967955cd7b5e471508c831cb61b06b8186608ef..004b4687a638fd0ee08514b5258e3dfcc0a02cd7 100644 (file)
@@ -147,13 +147,42 @@ if py3k:
     def cmp(a, b):
         return (a > b) - (a < b)
 
-    def reraise(tp, value, tb=None, cause=None):
-        if cause is not None:
-            assert cause is not value, "Same cause emitted"
-            value.__cause__ = cause
-        if value.__traceback__ is not tb:
-            raise value.with_traceback(tb)
-        raise value
+    def raise_(
+        exception, with_traceback=None, replace_context=None, from_=False
+    ):
+        r"""implement "raise" with cause support.
+
+        :param exception: exception to raise
+        :param with_traceback: will call exception.with_traceback()
+        :param replace_context: an as-yet-unsupported feature.  This is
+         an exception object which we are "replacing", e.g., it's our
+         "cause" but we don't want it printed.    Basically just what
+         ``__suppress_context__`` does but we don't want to suppress
+         the enclosing context, if any.  So for now we make it the
+         cause.
+        :param from\_: the cause.  this actually sets the cause and doesn't
+         hope to hide it someday.
+
+        """
+        if with_traceback is not None:
+            exception = exception.with_traceback(with_traceback)
+
+        if from_ is not False:
+            exception.__cause__ = from_
+        elif replace_context is not None:
+            # no good solution here, we would like to have the exception
+            # have only the context of replace_context.__context__ so that the
+            # intermediary exception does not change, but we can't figure
+            # that out.
+            exception.__cause__ = replace_context
+
+        try:
+            raise exception
+        finally:
+            # credit to
+            # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
+            # as the __traceback__ object creates a cycle
+            del exception, replace_context, from_, with_traceback
 
     def u(s):
         return s
@@ -257,13 +286,13 @@ else:
         else:
             return text
 
-    # not as nice as that of Py3K, but at least preserves
-    # the code line where the issue occurred
     exec(
-        "def reraise(tp, value, tb=None, cause=None):\n"
-        "    if cause is not None:\n"
-        "        assert cause is not value, 'Same cause emitted'\n"
-        "    raise tp, value, tb\n"
+        "def raise_(exception, with_traceback=None, replace_context=None, "
+        "from_=False):\n"
+        "    if with_traceback:\n"
+        "        raise type(exception), exception, with_traceback\n"
+        "    else:\n"
+        "        raise exception\n"
     )
 
     TYPE_CHECKING = False
@@ -405,6 +434,8 @@ def nested(*managers):
 
 
 def raise_from_cause(exception, exc_info=None):
+    r"""legacy.  use raise\_()"""
+
     if exc_info is None:
         exc_info = sys.exc_info()
     exc_type, exc_value, exc_tb = exc_info
@@ -412,6 +443,12 @@ def raise_from_cause(exception, exc_info=None):
     reraise(type(exception), exception, tb=exc_tb, cause=cause)
 
 
+def reraise(tp, value, tb=None, cause=None):
+    r"""legacy.  use raise\_()"""
+
+    raise_(value, with_traceback=tb, from_=cause)
+
+
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass.
 
index 41a9698c7d16cebf31e8a2e21b0479ec62c5d842..09aa94bf2bcfd22d137861ebe44ea9bfe6853f88 100644 (file)
@@ -65,7 +65,9 @@ class safe_reraise(object):
             exc_type, exc_value, exc_tb = self._exc_info
             self._exc_info = None  # remove potential circular references
             if not self.warn_only:
-                compat.reraise(exc_type, exc_value, exc_tb)
+                compat.raise_(
+                    exc_value, with_traceback=exc_tb,
+                )
         else:
             if not compat.py3k and self._exc_info and self._exc_info[1]:
                 # emulate Py3K's behavior of telling us when an exception
@@ -76,7 +78,7 @@ class safe_reraise(object):
                     "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])
                 )
             self._exc_info = None  # remove potential circular references
-            compat.reraise(type_, value, traceback)
+            compat.raise_(value, with_traceback=traceback)
 
 
 def string_or_unprintable(element):
index 55890cd064982a0761000c292531cc9dcdbc5880..8f84acde8a346f48965e8c53410c3c1aabaf71fb 100644 (file)
@@ -1124,6 +1124,20 @@ class CycleTest(_fixtures.FixtureTest):
 
         go()
 
+    def test_raise_from(self):
+        @assert_cycles()
+        def go():
+            try:
+                try:
+                    raise KeyError("foo")
+                except KeyError as ke:
+
+                    util.raise_(Exception("oops"), from_=ke)
+            except Exception as err:  # noqa
+                pass
+
+        go()
+
     def test_query_alias(self):
         User, Address = self.classes("User", "Address")
         configure_mappers()
index 87908f016a5c74e811cd05154e6435cde25fd275..73a1a8b6fe090d1f704ec784789ad59cee309884 100644 (file)
@@ -111,7 +111,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults):
             "some other column", Integer
         )
 
-        @profiling.function_call_count()
+        @profiling.function_call_count(variance=0.10)
         def go():
             c1 in row
 
index 48e464a01e5f4c93870df04e1eba5e16443e85ce..183e157e5e6997368ce98f039fbb6677d4b0435b 100644 (file)
@@ -3,7 +3,6 @@
 import copy
 import datetime
 import inspect
-import sys
 
 from sqlalchemy import exc
 from sqlalchemy import sql
@@ -2899,20 +2898,29 @@ class ReraiseTest(fixtures.TestBase):
         except MyException as err:
             is_(err.__cause__, None)
 
-    def test_reraise_disallow_same_cause(self):
+    def test_raise_from_cause_legacy(self):
         class MyException(Exception):
             pass
 
+        class MyOtherException(Exception):
+            pass
+
+        me = MyException("exc on")
+
         def go():
             try:
-                raise MyException("exc one")
-            except Exception as err:
-                type_, value, tb = sys.exc_info()
-                util.reraise(type_, err, tb, value)
+                raise me
+            except Exception:
+                util.raise_from_cause(MyOtherException("exc two"))
 
-        assert_raises_message(AssertionError, "Same cause emitted", go)
+        try:
+            go()
+            assert False
+        except MyOtherException as moe:
+            if testing.requires.python3.enabled:
+                is_(moe.__cause__, me)
 
-    def test_raise_from_cause(self):
+    def test_raise_from(self):
         class MyException(Exception):
             pass
 
@@ -2924,8 +2932,8 @@ class ReraiseTest(fixtures.TestBase):
         def go():
             try:
                 raise me
-            except Exception:
-                util.raise_from_cause(MyOtherException("exc two"))
+            except Exception as err:
+                util.raise_(MyOtherException("exc two"), from_=err)
 
         try:
             go()
index 5acd14177e1d76ead2ff83efa04a30dffe4ad186..cf262a5738f864f4536ca11b4dd13cbb8de137cc 100644 (file)
@@ -712,12 +712,13 @@ class ExecuteTest(fixtures.TestBase):
                     return super(MockCursor, self).execute(stmt, params, **kw)
 
         eng = engines.proxying_engine(cursor_cls=MockCursor)
-        assert_raises_message(
-            tsa.exc.SAWarning,
-            "Exception attempting to detect unicode returns",
-            eng.connect,
-        )
-        assert eng.dialect.returns_unicode_strings in (True, False)
+        with testing.expect_warnings(
+            "Exception attempting to detect unicode returns"
+        ):
+            eng.connect()
+
+        # because plain varchar passed, we don't know the correct answer
+        eq_(eng.dialect.returns_unicode_strings, "conditional")
         eng.dispose()
 
     def test_works_after_dispose(self):
index cfe20f5ec07a6891309538290c25d44eece90a61..72e0fa1865a53355ce9ca927ef016156a3b37acb 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import pool
 from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy.testing import assert_raises
+from sqlalchemy.testing import assert_raises_context_ok
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
@@ -1256,7 +1257,7 @@ class QueuePoolTest(PoolTestBase):
             eq_(p.checkedout(), 0)
             eq_(p._overflow, 0)
             dbapi.shutdown(True)
-            assert_raises(Exception, p.connect)
+            assert_raises_context_ok(Exception, p.connect)
             eq_(p._overflow, 0)
             eq_(p.checkedout(), 0)  # and not 1
 
index 205c1fb310402c31fbea66f9e33d3da2393afc50..000be1a7019b9ecb930d1bf6500a9a4ba5eaf9da 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy import util
 from sqlalchemy.engine import url
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_message_context_ok
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
@@ -255,7 +256,7 @@ class PrePingMockTest(fixtures.TestBase):
 
         self.dbapi.shutdown("execute", stop=True)
 
-        assert_raises_message(
+        assert_raises_message_context_ok(
             MockDisconnect, "database is stopped", pool.connect
         )
 
@@ -835,7 +836,7 @@ class CursorErrTest(fixtures.TestBase):
 
     def test_cursor_shutdown_in_initialize(self):
         db = self._fixture(True, True)
-        assert_raises_message(
+        assert_raises_message_context_ok(
             exc.SAWarning, "Exception attempting to detect", db.connect
         )
         eq_(
index 301614061e56f7df90243879147c880eb4167580..579f1aecec67b57824cf26b14eb51966c54c0a83 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy.testing import eq_regex
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
@@ -596,13 +597,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
         testing.db.dialect.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
-            assert_raises(sa.exc.SAWarning, Table, "test", m2, autoload=True)
 
-            @testing.emits_warning("Did not recognize type")
-            def warns():
-                m3 = MetaData(testing.db)
-                t3 = Table("test", m3, autoload=True)
-                assert t3.c.foo.type.__class__ == sa.types.NullType
+            with testing.expect_warnings("Did not recognize type"):
+                t3 = Table("test", m2, autoload_with=testing.db)
+                is_(t3.c.foo.type.__class__, sa.types.NullType)
 
         finally:
             testing.db.dialect.ischema_names = ischema_names
index 3f4333750f34f59a8430f92b96e73ffcc6d9765b..8ef272a9ef13babdc954e2217208c6444bffb6ad 100644 (file)
@@ -4093,16 +4093,11 @@ class DialectKWArgTest(fixtures.TestBase):
 
     def test_unknown_dialect_warning(self):
         with self._fixture():
-            assert_raises_message(
-                exc.SAWarning,
+            with testing.expect_warnings(
                 "Can't validate argument 'unknown_y'; can't locate "
                 "any SQLAlchemy dialect named 'unknown'",
-                Index,
-                "a",
-                "b",
-                "c",
-                unknown_y=True,
-            )
+            ):
+                Index("a", "b", "c", unknown_y=True)
 
     def test_participating_bad_kw(self):
         with self._fixture():