From: Mike Bayer Date: Sat, 29 Feb 2020 19:40:45 +0000 (-0500) Subject: Ensure all nested exception throws have a cause X-Git-Tag: rel_1_4_0b1~493^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=57dc36a01b2b334a996f73f6a78b3bfbe4d9f2ec;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Ensure all nested exception throws have a cause Applied an explicit "cause" to most if not all internally raised exceptions that are raised from within an internal exception catch, to avoid misleading stacktraces that suggest an error within the handling of an exception. While it would be preferable to suppress the internally caught exception in the way that the ``__suppress_context__`` attribute would, there does not as yet seem to be a way to do this without suppressing an enclosing user constructed context, so for now it exposes the internally caught exception as the cause so that full information about the context of the error is maintained. Fixes: #4849 Change-Id: I55a86b29023675d9e5e49bc7edc5a2dc0bcd4751 --- diff --git a/doc/build/changelog/unreleased_13/4849.rst b/doc/build/changelog/unreleased_13/4849.rst new file mode 100644 index 0000000000..5a649dc331 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4849.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, general, py3k + :tickets: 4849 + + Applied an explicit "cause" to most if not all internally raised exceptions + that are raised from within an internal exception catch, to avoid + misleading stacktraces that suggest an error within the handling of an + exception. While it would be preferable to suppress the internally caught + exception in the way that the ``__suppress_context__`` attribute would, + there does not as yet seem to be a way to do this without suppressing an + enclosing user constructed context, so for now it exposes the internally + caught exception as the cause so that full information about the context + of the error is maintained. diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c index 3c44010b89..f622e6a288 100644 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -288,7 +288,7 @@ BaseRow_getitem_by_object(BaseRow *self, PyObject *key, int asmapping) if (record == NULL) { record = PyObject_CallMethod(self->parent, "_key_fallback", - "O", key); + "OO", key, Py_None); if (record == NULL) return NULL; key_fallback = 1; // boolean to indicate record is a new reference diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e0bf167931..6ea8cbcb81 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2968,7 +2968,7 @@ class MySQLDialect(default.DefaultDialect): ).execute(st) except exc.DBAPIError as e: if self._extract_error_code(e.orig) == 1146: - raise exc.NoSuchTableError(full_name) + util.raise_(exc.NoSuchTableError(full_name), replace_context=e) else: raise row = self._compat_first(rp, charset=charset) @@ -2992,11 +2992,16 @@ class MySQLDialect(default.DefaultDialect): except exc.DBAPIError as e: code = self._extract_error_code(e.orig) if code == 1146: - raise exc.NoSuchTableError(full_name) + util.raise_( + exc.NoSuchTableError(full_name), replace_context=e + ) elif code == 1356: - raise exc.UnreflectableTableError( - "Table or view named %s could not be " - "reflected: %s" % (full_name, e) + util.raise_( + exc.UnreflectableTableError( + "Table or view named %s could not be " + "reflected: %s" % (full_name, e) + ), + replace_context=e, ) else: raise diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 0b6afc337d..1b1c9b0ba8 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -763,11 +763,14 @@ class PGDialect_psycopg2(PGDialect): def set_isolation_level(self, connection, level): try: level = self._isolation_lookup[level.replace("_", " ")] - except KeyError: - raise exc.ArgumentError( - "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" - % (level, self.name, ", ".join(self._isolation_lookup)) + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, ) connection.set_isolation_level(level) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index d04b543cd1..b1a83bf921 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -997,9 +997,12 @@ class SQLiteCompiler(compiler.SQLCompiler): self.extract_map[extract.field], self.process(extract.expr, **kw), ) - except KeyError: - raise exc.CompileError( - "%s is not a valid extract argument." % extract.field + except KeyError as err: + util.raise_( + exc.CompileError( + "%s is not a valid extract argument." % extract.field + ), + replace_context=err, ) def limit_clause(self, select, **kw): @@ -1537,11 +1540,14 @@ class SQLiteDialect(default.DefaultDialect): def set_isolation_level(self, connection, level): try: isolation_level = self._isolation_lookup[level.replace("_", " ")] - except KeyError: - raise exc.ArgumentError( - "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" - % (level, self.name, ", ".join(self._isolation_lookup)) + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, ) cursor = connection.cursor() cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ce6c2e9c67..449f386cea 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -996,8 +996,10 @@ class Connection(Connectable): return self._execute_text(object_, multiparams, params) try: meth = object_._execute_on_connection - except AttributeError: - raise exc.ObjectNotExecutableError(object_) + except AttributeError as err: + util.raise_( + exc.ObjectNotExecutableError(object_), replace_context=err + ) else: return meth(self, multiparams, params) @@ -1400,7 +1402,7 @@ class Connection(Connectable): invalidate_pool_on_disconnect = not is_exit_exception if self._reentrant_error: - util.raise_from_cause( + util.raise_( exc.DBAPIError.instance( statement, parameters, @@ -1412,7 +1414,8 @@ class Connection(Connectable): if context is not None else None, ), - exc_info, + with_traceback=exc_info[2], + from_=e, ) self._reentrant_error = True try: @@ -1502,11 +1505,13 @@ class Connection(Connectable): self._autorollback() if newraise: - util.raise_from_cause(newraise, exc_info) + util.raise_(newraise, with_traceback=exc_info[2], from_=e) elif should_wrap: - util.raise_from_cause(sqlalchemy_exception, exc_info) + util.raise_( + sqlalchemy_exception, with_traceback=exc_info[2], from_=e + ) else: - util.reraise(*exc_info) + util.raise_(exc_info[1], with_traceback=exc_info[2]) finally: del self._reentrant_error @@ -1573,11 +1578,13 @@ class Connection(Connectable): ) = ctx.is_disconnect if newraise: - util.raise_from_cause(newraise, exc_info) + util.raise_(newraise, with_traceback=exc_info[2], from_=e) elif should_wrap: - util.raise_from_cause(sqlalchemy_exception, exc_info) + util.raise_( + sqlalchemy_exception, with_traceback=exc_info[2], from_=e + ) else: - util.reraise(*exc_info) + util.raise_(exc_info[1], with_traceback=exc_info[2]) def _run_ddl_visitor(self, visitorcallable, element, **kwargs): """run a DDL visitor. @@ -2329,7 +2336,9 @@ class Engine(Connectable, log.Identified): e, dialect, self ) else: - util.reraise(*sys.exc_info()) + util.raise_( + sys.exc_info()[1], with_traceback=sys.exc_info()[2] + ) def raw_connection(self, _connection=None): """Return a "raw" DBAPI connection from the connection pool. diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 1a63c307bc..7db9eecaea 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -53,11 +53,11 @@ class ResultMetaData(object): def _has_key(self, key): return key in self._keymap - def _key_fallback(self, key): + def _key_fallback(self, key, err): if isinstance(key, int): - raise IndexError(key) + util.raise_(IndexError(key), replace_context=err) else: - raise KeyError(key) + util.raise_(KeyError(key), replace_context=err) class SimpleResultMetaData(ResultMetaData): @@ -546,11 +546,14 @@ class CursorResultMetaData(ResultMetaData): ) in self._colnames_from_description(context, cursor_description): yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated - def _key_fallback(self, key, raiseerr=True): + def _key_fallback(self, key, err, raiseerr=True): if raiseerr: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" - % util.string_or_unprintable(key) + util.raise_( + exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ), + replace_context=err, ) else: return None @@ -570,8 +573,8 @@ class CursorResultMetaData(ResultMetaData): def _getter(self, key, raiseerr=True): try: rec = self._keymap[key] - except KeyError: - rec = self._key_fallback(key, raiseerr) + except KeyError as ke: + rec = self._key_fallback(key, ke, raiseerr) if rec is None: return None @@ -598,8 +601,8 @@ class CursorResultMetaData(ResultMetaData): for key in keys: try: rec = self._keymap[key] - except KeyError: - rec = self._key_fallback(key, raiseerr) + except KeyError as ke: + rec = self._key_fallback(key, ke, raiseerr) if rec is None: return None @@ -656,9 +659,9 @@ class LegacyCursorResultMetaData(CursorResultMetaData): ) return True else: - return self._key_fallback(key, False) is not None + return self._key_fallback(key, None, False) is not None - def _key_fallback(self, key, raiseerr=True): + def _key_fallback(self, key, err, raiseerr=True): map_ = self._keymap result = None @@ -714,9 +717,12 @@ class LegacyCursorResultMetaData(CursorResultMetaData): ) if result is None: if raiseerr: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" - % util.string_or_unprintable(key) + util.raise_( + exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ), + replace_context=err, ) else: return None @@ -736,7 +742,7 @@ class LegacyCursorResultMetaData(CursorResultMetaData): if key in self._keymap: return True else: - return self._key_fallback(key, False) is not None + return self._key_fallback(key, None, False) is not None class CursorFetchStrategy(object): @@ -807,9 +813,12 @@ class NoCursorDQLFetchStrategy(CursorFetchStrategy): def fetchall(self): return self._non_result([]) - def _non_result(self, default): + def _non_result(self, default, err=None): if self.closed: - raise exc.ResourceClosedError("This result object is closed.") + util.raise_( + exc.ResourceClosedError("This result object is closed."), + replace_context=err, + ) else: return default @@ -843,10 +852,13 @@ class NoCursorDMLFetchStrategy(CursorFetchStrategy): def fetchall(self): return self._non_result([]) - def _non_result(self, default): - raise exc.ResourceClosedError( - "This result object does not return rows. " - "It has been closed automatically." + def _non_result(self, default, err=None): + util.raise_( + exc.ResourceClosedError( + "This result object does not return rows. " + "It has been closed automatically." + ), + replace_context=err, ) @@ -1123,24 +1135,24 @@ class BaseResult(object): def _getter(self, key, raiseerr=True): try: getter = self._metadata._getter - except AttributeError: - return self.cursor_strategy._non_result(None) + except AttributeError as err: + return self.cursor_strategy._non_result(None, err) else: return getter(key, raiseerr) def _tuple_getter(self, key, raiseerr=True): try: getter = self._metadata._tuple_getter - except AttributeError: - return self.cursor_strategy._non_result(None) + except AttributeError as err: + return self.cursor_strategy._non_result(None, err) else: return getter(key, raiseerr) def _has_key(self, key): try: has_key = self._metadata._has_key - except AttributeError: - return self.cursor_strategy._non_result(None) + except AttributeError as err: + return self.cursor_strategy._non_result(None, err) else: return has_key(key) diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 55d8c2249d..b58b350e25 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -84,8 +84,8 @@ except ImportError: def _subscript_impl(self, key, ismapping): try: rec = self._keymap[key] - except KeyError: - rec = self._parent._key_fallback(key) + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) except TypeError: # the non-C version detects a slice using TypeError. # this is pretty inefficient for the slice use case @@ -119,7 +119,7 @@ except ImportError: try: return self._get_by_key_impl_mapping(name) except KeyError as e: - raise AttributeError(e.args[0]) + util.raise_(AttributeError(e.args[0]), replace_context=e) class Row(BaseRow, collections_abc.Sequence): diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 41346fc4e0..f00b642dbd 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -244,6 +244,10 @@ class AssociationProxy(interfaces.InspectionAttrInfo): try: inst = class_.__dict__[self.key + "_inst"] except KeyError: + inst = None + + # avoid exception context + if inst is None: owner = self._calc_owner(class_) if owner is not None: inst = AssociationProxyInstance.for_proxy(self, owner, obj) @@ -358,9 +362,12 @@ class AssociationProxyInstance(object): # this was never asserted before but this should be made clear. if not isinstance(prop, orm.RelationshipProperty): - raise NotImplementedError( - "association proxy to a non-relationship " - "intermediary is not supported" + util.raise_( + NotImplementedError( + "association proxy to a non-relationship " + "intermediary is not supported" + ), + replace_context=None, ) target_class = prop.mapper.class_ @@ -1323,10 +1330,13 @@ class _AssociationDict(_AssociationCollection): try: for k, v in seq_or_map: self[k] = v - except ValueError: - raise ValueError( - "dictionary update sequence " - "requires 2-element tuples" + except ValueError as err: + util.raise_( + ValueError( + "dictionary update sequence " + "requires 2-element tuples" + ), + replace_context=err, ) for key, value in kw: diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index cafe69093a..cf67387e43 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -504,9 +504,12 @@ class Result(object): """ try: ret = self.one_or_none() - except orm_exc.MultipleResultsFound: - raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()" + except orm_exc.MultipleResultsFound as err: + util.raise_( + orm_exc.MultipleResultsFound( + "Multiple rows were found for one()" + ), + replace_context=err, ) else: if ret is None: diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index c27907cdcf..b8b6f8dc0d 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -398,6 +398,7 @@ Example usage:: """ from .. import exc +from .. import util from ..sql import sqltypes from ..sql import visitors @@ -422,10 +423,13 @@ def compiles(class_, *specs): def _wrap_existing_dispatch(element, compiler, **kw): try: return existing_dispatch(element, compiler, **kw) - except exc.UnsupportedCompilationError: - raise exc.CompileError( - "%s construct has no default " - "compilation handler." % type(element) + except exc.UnsupportedCompilationError as uce: + util.raise_( + exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element) + ), + from_=uce, ) existing.specs["default"] = _wrap_existing_dispatch @@ -470,10 +474,13 @@ class _dispatcher(object): if not fn: try: fn = self.specs["default"] - except KeyError: - raise exc.CompileError( - "%s construct has no default " - "compilation handler." % type(element) + except KeyError as ke: + util.raise_( + exc.CompileError( + "%s construct has no default " + "compilation handler." % type(element) + ), + replace_context=ke, ) # if compilation includes add_to_result_map, collect add_to_result_map diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 7ff30b807f..93e643cf5c 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -298,12 +298,15 @@ class _class_resolver(object): else: return x except NameError as n: - raise exc.InvalidRequestError( - "When initializing mapper %s, expression %r failed to " - "locate a name (%r). If this is a class name, consider " - "adding this relationship() to the %r class after " - "both dependent classes have been defined." - % (self.prop.parent, self.arg, n.args[0], self.cls) + util.raise_( + exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." + % (self.prop.parent, self.arg, n.args[0], self.cls) + ), + from_=n, ) diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index f2e0501bb3..6eb7e11850 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -223,7 +223,8 @@ The above query will render:: """ # noqa from __future__ import absolute_import -from sqlalchemy import inspect +from .. import inspect +from .. import util from ..ext.hybrid import hybrid_property from ..orm.attributes import flag_modified @@ -301,9 +302,9 @@ class index_property(hybrid_property): # noqa self.datatype = dict self.onebased = onebased - def _fget_default(self): + def _fget_default(self, err=None): if self.default == self._NO_DEFAULT_ARGUMENT: - raise AttributeError(self.attr_name) + util.raise_(AttributeError(self.attr_name), replace_context=err) else: return self.default @@ -314,8 +315,8 @@ class index_property(hybrid_property): # noqa return self._fget_default() try: value = column_value[self.index] - except (KeyError, IndexError): - return self._fget_default() + except (KeyError, IndexError) as err: + return self._fget_default(err) else: return value @@ -337,8 +338,8 @@ class index_property(hybrid_property): # noqa raise AttributeError(self.attr_name) try: del column_value[self.index] - except KeyError: - raise AttributeError(self.attr_name) + except KeyError as err: + util.raise_(AttributeError(self.attr_name), replace_context=err) else: setattr(instance, attr_name, column_value) flag_modified(instance, attr_name) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 66a18da992..a959b0a400 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -231,16 +231,19 @@ class QueryableAttribute( def __getattr__(self, key): try: return getattr(self.comparator, key) - except AttributeError: - raise AttributeError( - "Neither %r object nor %r object associated with %s " - "has an attribute %r" - % ( - type(self).__name__, - type(self.comparator).__name__, - self, - key, - ) + except AttributeError as err: + util.raise_( + AttributeError( + "Neither %r object nor %r object associated with %s " + "has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + self, + key, + ) + ), + replace_context=err, ) def __str__(self): @@ -373,31 +376,39 @@ def create_proxied_attribute(descriptor): comparator.""" try: return getattr(descriptor, attribute) - except AttributeError: + except AttributeError as err: if attribute == "comparator": - raise AttributeError("comparator") + util.raise_( + AttributeError("comparator"), replace_context=err + ) try: # comparator itself might be unreachable comparator = self.comparator - except AttributeError: - raise AttributeError( - "Neither %r object nor unconfigured comparator " - "object associated with %s has an attribute %r" - % (type(descriptor).__name__, self, attribute) + except AttributeError as err2: + util.raise_( + AttributeError( + "Neither %r object nor unconfigured comparator " + "object associated with %s has an attribute %r" + % (type(descriptor).__name__, self, attribute) + ), + replace_context=err2, ) else: try: return getattr(comparator, attribute) - except AttributeError: - raise AttributeError( - "Neither %r object nor %r object " - "associated with %s has an attribute %r" - % ( - type(descriptor).__name__, - type(comparator).__name__, - self, - attribute, - ) + except AttributeError as err3: + util.raise_( + AttributeError( + "Neither %r object nor %r object " + "associated with %s has an attribute %r" + % ( + type(descriptor).__name__, + type(comparator).__name__, + self, + attribute, + ) + ), + replace_context=err3, ) Proxy.__name__ = type(descriptor).__name__ + "Proxy" @@ -713,12 +724,15 @@ class AttributeImpl(object): elif value is ATTR_WAS_SET: try: return dict_[key] - except KeyError: + except KeyError as err: # TODO: no test coverage here. - raise KeyError( - "Deferred loader for attribute " - "%r failed to populate " - "correctly" % key + util.raise_( + KeyError( + "Deferred loader for attribute " + "%r failed to populate " + "correctly" % key + ), + replace_context=err, ) elif value is not ATTR_EMPTY: return self.set_committed_value(state, dict_, value) diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 571107a380..a31745aec2 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -387,9 +387,12 @@ def _entity_descriptor(entity, key): try: return getattr(entity, key) - except AttributeError: - raise sa_exc.InvalidRequestError( - "Entity '%s' has no property '%s'" % (description, key) + except AttributeError as err: + util.raise_( + sa_exc.InvalidRequestError( + "Entity '%s' has no property '%s'" % (description, key) + ), + replace_context=err, ) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index f75c7d3bac..57c192a5d3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -557,12 +557,15 @@ class StrategizedProperty(MapperProperty): try: return self._strategies[key] except KeyError: - cls = self._strategy_lookup(self, *key) - # this previously was setting self._strategies[cls], that's - # a bad idea; should use strategy key at all times because every - # strategy has multiple keys at this point - self._strategies[key] = strategy = cls(self, key) - return strategy + pass + + # run outside to prevent transfer of exception context + cls = self._strategy_lookup(self, *key) + # this previously was setting self._strategies[cls], that's + # a bad idea; should use strategy key at all times because every + # strategy has multiple keys at this point + self._strategies[key] = strategy = cls(self, key) + return strategy def setup(self, context, query_entity, path, adapter, **kwargs): loader = self._get_context_loader(context, path) diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 193980e6c3..d943ebb190 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -99,9 +99,9 @@ def instances(query, cursor, context): if not query._yield_per: break - except Exception as err: - cursor.close() - util.raise_from_cause(err) + except Exception: + with util.safe_reraise(): + cursor.close() @util.dependencies("sqlalchemy.orm.query") diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 0d87a9c406..91e3251e2c 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1483,11 +1483,14 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): # it to mapped ColumnProperty try: self.polymorphic_on = self._props[self.polymorphic_on] - except KeyError: - raise sa_exc.ArgumentError( - "Can't determine polymorphic_on " - "value '%s' - no attribute is " - "mapped to this name." % self.polymorphic_on + except KeyError as err: + util.raise_( + sa_exc.ArgumentError( + "Can't determine polymorphic_on " + "value '%s' - no attribute is " + "mapped to this name." % self.polymorphic_on + ), + replace_context=err, ) if self.polymorphic_on in self._columntoproperty: @@ -1987,9 +1990,12 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): try: return self._props[key] - except KeyError: - raise sa_exc.InvalidRequestError( - "Mapper '%s' has no property '%s'" % (self, key) + except KeyError as err: + util.raise_( + sa_exc.InvalidRequestError( + "Mapper '%s' has no property '%s'" % (self, key) + ), + replace_context=err, ) def get_property_by_column(self, column): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 3b274a3893..46c84d4bda 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1635,9 +1635,12 @@ def _sort_states(mapper, states): persistent, key=mapper._persistent_sortkey_fn ) except TypeError as err: - raise sa_exc.InvalidRequestError( - "Could not sort objects by primary key; primary key " - "values must be sortable in Python (was: %s)" % err + util.raise_( + sa_exc.InvalidRequestError( + "Could not sort objects by primary key; primary key " + "values must be sortable in Python (was: %s)" % err + ), + replace_context=err, ) return ( sorted(pending, key=operator.attrgetter("insert_order")) @@ -1681,10 +1684,13 @@ class BulkUD(object): def _factory(cls, lookup, synchronize_session, *arg): try: klass = lookup[synchronize_session] - except KeyError: - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + except KeyError as err: + util.raise_( + sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + ), + replace_context=err, ) else: return klass(*arg) @@ -1768,10 +1774,13 @@ class BulkEvaluate(BulkUD): self._additional_evaluators(evaluator_compiler) except evaluator.UnevaluatableError as err: - raise sa_exc.InvalidRequestError( - 'Could not evaluate current criteria in Python: "%s". ' - "Specify 'fetch' or False for the " - "synchronize_session parameter." % err + util.raise_( + sa_exc.InvalidRequestError( + 'Could not evaluate current criteria in Python: "%s". ' + "Specify 'fetch' or False for the " + "synchronize_session parameter." % err + ), + from_=err, ) # TODO: detect when the where clause is a trivial primary key match diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d237aa3bf2..e29e6eeeeb 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1019,15 +1019,18 @@ class Query(Generative): for prop in mapper._identity_key_props ) - except KeyError: - raise sa_exc.InvalidRequestError( - "Incorrect names of values in identifier to formulate " - "primary key for query.get(); primary key attribute names" - " are %s" - % ",".join( - "'%s'" % prop.key - for prop in mapper._identity_key_props - ) + except KeyError as err: + util.raise_( + sa_exc.InvalidRequestError( + "Incorrect names of values in identifier to formulate " + "primary key for query.get(); primary key attribute " + "names are %s" + % ",".join( + "'%s'" % prop.key + for prop in mapper._identity_key_props + ) + ), + replace_context=err, ) if ( @@ -3292,9 +3295,12 @@ class Query(Generative): """ try: ret = self.one_or_none() - except orm_exc.MultipleResultsFound: - raise orm_exc.MultipleResultsFound( - "Multiple rows were found for one()" + except orm_exc.MultipleResultsFound as err: + util.raise_( + orm_exc.MultipleResultsFound( + "Multiple rows were found for one()" + ), + replace_context=err, ) else: if ret is None: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index b82a3d2712..2995baf5fc 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2484,50 +2484,64 @@ class JoinCondition(object): a_subset=self.parent_local_selectable, consider_as_foreign_keys=consider_as_foreign_keys, ) - except sa_exc.NoForeignKeysError: + except sa_exc.NoForeignKeysError as nfe: if self.secondary is not None: - raise sa_exc.NoForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are no foreign keys " - "linking these tables via secondary table '%s'. " - "Ensure that referencing columns are associated " - "with a ForeignKey or ForeignKeyConstraint, or " - "specify 'primaryjoin' and 'secondaryjoin' " - "expressions." % (self.prop, self.secondary) + util.raise_( + sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables via secondary table '%s'. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify 'primaryjoin' and 'secondaryjoin' " + "expressions." % (self.prop, self.secondary) + ), + from_=nfe, ) else: - raise sa_exc.NoForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are no foreign keys " - "linking these tables. " - "Ensure that referencing columns are associated " - "with a ForeignKey or ForeignKeyConstraint, or " - "specify a 'primaryjoin' expression." % self.prop + util.raise_( + sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify a 'primaryjoin' expression." % self.prop + ), + from_=nfe, ) - except sa_exc.AmbiguousForeignKeysError: + except sa_exc.AmbiguousForeignKeysError as afe: if self.secondary is not None: - raise sa_exc.AmbiguousForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are multiple foreign key " - "paths linking the tables via secondary table '%s'. " - "Specify the 'foreign_keys' " - "argument, providing a list of those columns which " - "should be counted as containing a foreign key " - "reference from the secondary table to each of the " - "parent and child tables." % (self.prop, self.secondary) + util.raise_( + sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables via secondary table '%s'. " + "Specify the 'foreign_keys' " + "argument, providing a list of those columns which " + "should be counted as containing a foreign key " + "reference from the secondary table to each of the " + "parent and child tables." + % (self.prop, self.secondary) + ), + from_=afe, ) else: - raise sa_exc.AmbiguousForeignKeysError( - "Could not determine join " - "condition between parent/child tables on " - "relationship %s - there are multiple foreign key " - "paths linking the tables. Specify the " - "'foreign_keys' argument, providing a list of those " - "columns which should be counted as containing a " - "foreign key reference to the parent table." % self.prop + util.raise_( + sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables. Specify the " + "'foreign_keys' argument, providing a list of those " + "columns which should be counted as containing a " + "foreign key reference to the parent table." + % self.prop + ), + from_=afe, ) @property diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 0950339516..74e5464835 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -575,7 +575,7 @@ class SessionTransaction(object): self._parent._rollback_exception = sys.exc_info()[1] if rollback_err: - util.reraise(*rollback_err) + util.raise_(rollback_err[1], with_traceback=rollback_err[2]) sess.dispatch.after_soft_rollback(sess, self) @@ -1362,10 +1362,13 @@ class Session(_SessionClassMethods): def _add_bind(self, key, bind): try: insp = inspect(key) - except sa_exc.NoInspectionAvailable: + except sa_exc.NoInspectionAvailable as err: if not isinstance(key, type): - raise sa_exc.ArgumentError( - "Not an acceptable bind target: %s" % key + util.raise_( + sa_exc.ArgumentError( + "Not an acceptable bind target: %s" % key + ), + replace_context=err, ) else: self.__binds[key] = bind @@ -1515,9 +1518,11 @@ class Session(_SessionClassMethods): if mapper is not None: try: mapper = inspect(mapper) - except sa_exc.NoInspectionAvailable: + except sa_exc.NoInspectionAvailable as err: if isinstance(mapper, type): - raise exc.UnmappedClassError(mapper) + util.raise_( + exc.UnmappedClassError(mapper), replace_context=err, + ) else: raise @@ -1656,7 +1661,7 @@ class Session(_SessionClassMethods): "consider using a session.no_autoflush block if this " "flush is occurring prematurely" ) - util.raise_from_cause(e) + util.raise_(e, with_traceback=sys.exc_info[2]) def refresh( self, @@ -1711,8 +1716,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._expire_state(state, attribute_names) @@ -1817,8 +1824,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._expire_state(state, attribute_names) def _expire_state(self, state, attribute_names): @@ -1872,8 +1881,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) if state.session_id is not self.hash_key: raise sa_exc.InvalidRequestError( "Instance %s is not present in this Session" % state_str(state) @@ -2024,8 +2035,10 @@ class Session(_SessionClassMethods): try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._save_or_update_state(state) @@ -2059,8 +2072,10 @@ class Session(_SessionClassMethods): try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) self._delete_impl(state, instance, head=True) @@ -2490,8 +2505,10 @@ class Session(_SessionClassMethods): """ try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) return self._contains_state(state) def __iter__(self): @@ -2586,8 +2603,11 @@ class Session(_SessionClassMethods): for o in objects: try: state = attributes.instance_state(o) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(o) + + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(o), replace_context=err, + ) objset.add(state) else: objset = None @@ -3450,8 +3470,10 @@ def object_session(instance): try: state = attributes.instance_state(instance) - except exc.NO_STATE: - raise exc.UnmappedInstanceError(instance) + except exc.NO_STATE as err: + util.raise_( + exc.UnmappedInstanceError(instance), replace_context=err, + ) else: return _state_session(state) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 0c72f3b37d..4f7d996d4f 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -252,11 +252,14 @@ class Load(HasCacheKey, Generative, MapperOption): # use getattr on the class to work around # synonyms, hybrids, etc. attr = getattr(ent.class_, attr) - except AttributeError: + except AttributeError as err: if raiseerr: - raise sa_exc.ArgumentError( - 'Can\'t find property named "%s" on ' - "%s in this Query." % (attr, ent) + util.raise_( + sa_exc.ArgumentError( + 'Can\'t find property named "%s" on ' + "%s in this Query." % (attr, ent) + ), + replace_context=err, ) else: return None diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 198e64f4f2..ceaf54e5d3 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -13,6 +13,7 @@ between instances based on join conditions. from . import attributes from . import exc from . import util as orm_util +from .. import util def populate( @@ -34,15 +35,15 @@ def populate( value = source.manager[prop.key].impl.get( source, source_dict, attributes.PASSIVE_OFF ) - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err) try: # inline of dest_mapper._set_state_attr_by_column prop = dest_mapper._columntoproperty[r] dest.manager[prop.key].impl.set(dest, dest_dict, value, None) - except exc.UnmappedColumnError: - _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err) # technically the "r.primary_key" check isn't # needed here, but we check for this condition to limit @@ -64,8 +65,8 @@ def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs): try: prop = source_mapper._columntoproperty[l] value = source_dict[prop.key] - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, source_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err) try: prop = source_mapper._columntoproperty[r] @@ -88,8 +89,8 @@ def clear(dest, dest_mapper, synchronize_pairs): ) try: dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) - except exc.UnmappedColumnError: - _raise_col_to_prop(True, None, l, dest_mapper, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, None, l, dest_mapper, r, err) def update(source, source_mapper, dest, old_prefix, synchronize_pairs): @@ -101,8 +102,8 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): value = source_mapper._get_state_attr_by_column( source, source.dict, l, passive=attributes.PASSIVE_OFF ) - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) dest[r.key] = value dest[old_prefix + r.key] = oldvalue @@ -113,8 +114,8 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs): value = source_mapper._get_state_attr_by_column( source, source.dict, l, passive=attributes.PASSIVE_OFF ) - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) dict_[r.key] = value @@ -127,8 +128,8 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs): for l, r in synchronize_pairs: try: prop = source_mapper._columntoproperty[l] - except exc.UnmappedColumnError: - _raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) history = uowcommit.get_attribute_history( source, prop.key, attributes.PASSIVE_NO_INITIALIZE ) @@ -139,22 +140,28 @@ def source_modified(uowcommit, source, source_mapper, synchronize_pairs): def _raise_col_to_prop( - isdest, source_mapper, source_column, dest_mapper, dest_column + isdest, source_mapper, source_column, dest_mapper, dest_column, err ): if isdest: - raise exc.UnmappedColumnError( - "Can't execute sync rule for " - "destination column '%s'; mapper '%s' does not map " - "this column. Try using an explicit `foreign_keys` " - "collection which does not include this column (or use " - "a viewonly=True relation)." % (dest_column, dest_mapper) + util.raise_( + exc.UnmappedColumnError( + "Can't execute sync rule for " + "destination column '%s'; mapper '%s' does not map " + "this column. Try using an explicit `foreign_keys` " + "collection which does not include this column (or use " + "a viewonly=True relation)." % (dest_column, dest_mapper) + ), + replace_context=err, ) else: - raise exc.UnmappedColumnError( - "Can't execute sync rule for " - "source column '%s'; mapper '%s' does not map this " - "column. Try using an explicit `foreign_keys` " - "collection which does not include destination column " - "'%s' (or use a viewonly=True relation)." - % (source_column, source_mapper, dest_column) + util.raise_( + exc.UnmappedColumnError( + "Can't execute sync rule for " + "source column '%s'; mapper '%s' does not map this " + "column. Try using an explicit `foreign_keys` " + "collection which does not include destination column " + "'%s' (or use a viewonly=True relation)." + % (source_column, source_mapper, dest_column) + ), + replace_context=err, ) diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index b53f0d7ddb..17d5ba15fd 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -578,8 +578,8 @@ class _ConnectionRecord(object): self.connection = connection self.fresh = True except Exception as e: - pool.logger.debug("Error on connect(): %s", e) - raise + with util.safe_reraise(): + pool.logger.debug("Error on connect(): %s", e) else: if first_connect_check: pool.dispatch.first_connect.for_modify( diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py index 67f1564ec3..8618d5e2aa 100644 --- a/lib/sqlalchemy/processors.py +++ b/lib/sqlalchemy/processors.py @@ -32,10 +32,13 @@ def str_to_datetime_processor_factory(regexp, type_): else: try: m = rmatch(value) - except TypeError: - raise ValueError( - "Couldn't parse %s string '%r' " - "- value is not a string." % (type_.__name__, value) + except TypeError as err: + util.raise_( + ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." % (type_.__name__, value) + ), + from_=err, ) if m is None: raise ValueError( diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a7324c45fa..2d336360f9 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -128,8 +128,8 @@ class _DialectArgView(util.collections_abc.MutableMapping): def _key(self, key): try: dialect, value_key = key.split("_", 1) - except ValueError: - raise KeyError(key) + except ValueError as err: + util.raise_(KeyError(key), replace_context=err) else: return dialect, value_key @@ -138,17 +138,20 @@ class _DialectArgView(util.collections_abc.MutableMapping): try: opt = self.obj.dialect_options[dialect] - except exc.NoSuchModuleError: - raise KeyError(key) + except exc.NoSuchModuleError as err: + util.raise_(KeyError(key), replace_context=err) else: return opt[value_key] def __setitem__(self, key, value): try: dialect, value_key = self._key(key) - except KeyError: - raise exc.ArgumentError( - "Keys must be of the form _" + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Keys must be of the form _" + ), + replace_context=err, ) else: self.obj.dialect_options[dialect][value_key] = value @@ -634,17 +637,17 @@ class ColumnCollection(object): def __getitem__(self, key): try: return self._index[key] - except KeyError: + except KeyError as err: if isinstance(key, util.int_types): - raise IndexError(key) + util.raise_(IndexError(key), replace_context=err) else: raise def __getattr__(self, key): try: return self._index[key] - except KeyError: - raise AttributeError(key) + except KeyError as err: + util.raise_(AttributeError(key), replace_context=err) def __contains__(self, key): if key not in self._index: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index b3bf4e93b9..fc841bb4be 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -133,7 +133,13 @@ class RoleImpl(object): self._raise_for_expected(element, argname, resolved) def _raise_for_expected( - self, element, argname=None, resolved=None, advice=None, code=None + self, + element, + argname=None, + resolved=None, + advice=None, + code=None, + err=None, ): if argname: msg = "%s expected for argument %r; got %r." % ( @@ -147,7 +153,7 @@ class RoleImpl(object): if advice: msg += " " + advice - raise exc.ArgumentError(msg, code=code) + util.raise_(exc.ArgumentError(msg, code=code), replace_context=err) class _Deannotate(object): @@ -201,16 +207,19 @@ class _ColumnCoercions(object): def _no_text_coercion( - element, argname=None, exc_cls=exc.ArgumentError, extra=None + element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None ): - raise exc_cls( - "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " - "explicitly declared as text(%(expr)r)" - % { - "expr": util.ellipses_string(element), - "argname": "for argument %s" % (argname,) if argname else "", - "extra": "%s " % extra if extra else "", - } + util.raise_( + exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ), + replace_context=err, ) @@ -290,8 +299,8 @@ class ExpressionElementImpl( return elements.BindParameter( name, element, type_, unique=True ) - except exc.ArgumentError: - self._raise_for_expected(element) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) class BinaryElementImpl( @@ -302,8 +311,8 @@ class BinaryElementImpl( ): try: return expr._bind_param(operator, element, type_=bindparam_type) - except exc.ArgumentError: - self._raise_for_expected(element) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) def _post_coercion(self, resolved, expr, **kw): if ( diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9c1f50ce13..d31cf67f88 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1074,7 +1074,7 @@ class SQLCompiler(Compiled): col = only_froms[element.element] else: col = with_cols[element.element] - except KeyError: + except KeyError as err: coercions._no_text_coercion( element.element, extra=( @@ -1082,6 +1082,7 @@ class SQLCompiler(Compiled): "GROUP BY / DISTINCT etc." ), exc_cls=exc.CompileError, + err=err, ) else: kwargs["render_label_as_label"] = col @@ -1671,8 +1672,11 @@ class SQLCompiler(Compiled): else: try: opstring = OPERATORS[operator_] - except KeyError: - raise exc.UnsupportedCompilationError(self, operator_) + except KeyError as err: + util.raise_( + exc.UnsupportedCompilationError(self, operator_), + replace_context=err, + ) else: return self._generate_generic_binary( binary, opstring, from_linter=from_linter, **kw @@ -3286,11 +3290,12 @@ class DDLCompiler(Compiled): if column.primary_key: first_pk = True except exc.CompileError as ce: - util.raise_from_cause( + util.raise_( exc.CompileError( util.u("(in table '%s', column '%s'): %s") % (table.description, column.name, ce.args[0]) - ) + ), + from_=ce, ) const = self.create_table_constraints( diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 31bcc34a40..5a2095604c 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -801,7 +801,7 @@ class SchemaDropper(DDLBase): ) collection = [(t, ()) for t in unsorted_tables] else: - util.raise_from_cause( + util.raise_( exc.CircularDependencyError( err2.args[0], err2.cycles, @@ -818,7 +818,8 @@ class SchemaDropper(DDLBase): sorted([t.fullname for t in err2.cycles]) ) ), - ) + ), + from_=err2, ) seq_coll = [ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index df690c383b..d0babb1be0 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -747,10 +747,13 @@ class ColumnElement( def comparator(self): try: comparator_factory = self.type.comparator_factory - except AttributeError: - raise TypeError( - "Object %r associated with '.type' attribute " - "is not a TypeEngine class or object" % self.type + except AttributeError as err: + util.raise_( + TypeError( + "Object %r associated with '.type' attribute " + "is not a TypeEngine class or object" % self.type + ), + replace_context=err, ) else: return comparator_factory(self) @@ -758,10 +761,17 @@ class ColumnElement( def __getattr__(self, key): try: return getattr(self.comparator, key) - except AttributeError: - raise AttributeError( - "Neither %r object nor %r object has an attribute %r" - % (type(self).__name__, type(self.comparator).__name__, key) + except AttributeError as err: + util.raise_( + AttributeError( + "Neither %r object nor %r object has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + key, + ) + ), + replace_context=err, ) def operate(self, op, *other, **kwargs): @@ -1742,10 +1752,13 @@ class TextClause( # a unique/anonymous key in any case, so use the _orig_key # so that a text() construct can support unique parameters existing = new_params[bind._orig_key] - except KeyError: - raise exc.ArgumentError( - "This text() construct doesn't define a " - "bound parameter named %r" % bind._orig_key + except KeyError as err: + util.raise_( + exc.ArgumentError( + "This text() construct doesn't define a " + "bound parameter named %r" % bind._orig_key + ), + replace_context=err, ) else: new_params[existing._orig_key] = bind @@ -1753,10 +1766,13 @@ class TextClause( for key, value in names_to_values.items(): try: existing = new_params[key] - except KeyError: - raise exc.ArgumentError( - "This text() construct doesn't define a " - "bound parameter named %r" % key + except KeyError as err: + util.raise_( + exc.ArgumentError( + "This text() construct doesn't define a " + "bound parameter named %r" % key + ), + replace_context=err, ) else: new_params[key] = existing._with_value(value) @@ -3665,9 +3681,12 @@ class Over(ColumnElement): else: try: lower = int(range_[0]) - except ValueError: - raise exc.ArgumentError( - "Integer or None expected for range value" + except ValueError as err: + util.raise_( + exc.ArgumentError( + "Integer or None expected for range value" + ), + replace_context=err, ) else: if lower == 0: @@ -3678,9 +3697,12 @@ class Over(ColumnElement): else: try: upper = int(range_[1]) - except ValueError: - raise exc.ArgumentError( - "Integer or None expected for range value" + except ValueError as err: + util.raise_( + exc.ArgumentError( + "Integer or None expected for range value" + ), + replace_context=err, ) else: if upper == 0: diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index e6d3a6b059..5445a1bcea 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -107,12 +107,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): if item is not None: try: spwd = item._set_parent_with_dispatch - except AttributeError: - util.raise_from_cause( + except AttributeError as err: + util.raise_( exc.ArgumentError( "'SchemaItem' object, such as a 'Column' or a " "'Constraint' expected, got %r" % item - ) + ), + replace_context=err, ) else: spwd(self) @@ -1569,15 +1570,16 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): _proxies=[self], *fk ) - except TypeError: - util.raise_from_cause( + except TypeError as err: + util.raise_( TypeError( "Could not create a copy of this %r object. " "Ensure the class includes a _constructor() " "attribute or method which accepts the " "standard Column constructor arguments, or " "references the Column class itself." % self.__class__ - ) + ), + from_=err, ) c.table = selectable @@ -3187,10 +3189,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): try: ColumnCollectionConstraint._set_parent(self, table) except KeyError as ke: - raise exc.ArgumentError( - "Can't create ForeignKeyConstraint " - "on table '%s': no column " - "named '%s' is present." % (table.description, ke.args[0]) + util.raise_( + exc.ArgumentError( + "Can't create ForeignKeyConstraint " + "on table '%s': no column " + "named '%s' is present." % (table.description, ke.args[0]) + ), + from_=ke, ) for col, fk in zip(self.columns, self.elements): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index b8d88e160e..b972c13be6 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2620,10 +2620,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): return None try: value = clause._limit_offset_value - except AttributeError: - raise exc.CompileError( - "This SELECT structure does not use a simple " - "integer value for %s" % attrname + except AttributeError as err: + util.raise_( + exc.CompileError( + "This SELECT structure does not use a simple " + "integer value for %s" % attrname + ), + replace_context=err, ) else: return util.asint(value) @@ -3489,10 +3492,13 @@ class Select( try: cols_present = bool(columns) - except TypeError: - raise exc.ArgumentError( - "columns argument to select() must " - "be a Python list or other iterable" + except TypeError as err: + util.raise_( + exc.ArgumentError( + "columns argument to select() must " + "be a Python list or other iterable" + ), + from_=err, ) if cols_present: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 22c80cc91e..e4a029a3e3 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1462,7 +1462,7 @@ class Enum(Emulated, String, SchemaType): def _db_value_for_elem(self, elem): try: return self._valid_lookup[elem] - except KeyError: + except KeyError as err: # for unknown string values, we return as is. While we can # validate these if we wanted, that does not allow for lesser-used # end-user use cases, such as using a LIKE comparison with an enum, @@ -1476,8 +1476,11 @@ class Enum(Emulated, String, SchemaType): ): return elem else: - raise LookupError( - '"%s" is not among the defined enum values' % elem + util.raise_( + LookupError( + '"%s" is not among the defined enum values' % elem + ), + replace_context=err, ) class Comparator(String.Comparator): @@ -1496,9 +1499,12 @@ class Enum(Emulated, String, SchemaType): def _object_value_for_elem(self, elem): try: return self._object_lookup[elem] - except KeyError: - raise LookupError( - '"%s" is not among the defined enum values' % elem + except KeyError as err: + util.raise_( + LookupError( + '"%s" is not among the defined enum values' % elem + ), + replace_context=err, ) def __repr__(self): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index c6c860844a..739f961954 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -479,9 +479,12 @@ class TypeEngine(Traversible): try: return dialect._type_memos[self]["literal"] except KeyError: - d = self._dialect_info(dialect) - d["literal"] = lp = d["impl"].literal_processor(dialect) - return lp + pass + # avoid KeyError context coming into literal_processor() function + # raises + d = self._dialect_info(dialect) + d["literal"] = lp = d["impl"].literal_processor(dialect) + return lp def _cached_bind_processor(self, dialect): """Return a dialect-specific bind processor for this type.""" @@ -489,9 +492,12 @@ class TypeEngine(Traversible): try: return dialect._type_memos[self]["bind"] except KeyError: - d = self._dialect_info(dialect) - d["bind"] = bp = d["impl"].bind_processor(dialect) - return bp + pass + # avoid KeyError context coming into bind_processor() function + # raises + d = self._dialect_info(dialect) + d["bind"] = bp = d["impl"].bind_processor(dialect) + return bp def _cached_result_processor(self, dialect, coltype): """Return a dialect-specific result processor for this type.""" @@ -499,21 +505,27 @@ class TypeEngine(Traversible): try: return dialect._type_memos[self][coltype] except KeyError: - d = self._dialect_info(dialect) - # key assumption: DBAPI type codes are - # constants. Else this dictionary would - # grow unbounded. - d[coltype] = rp = d["impl"].result_processor(dialect, coltype) - return rp + pass + # avoid KeyError context coming into result_processor() function + # raises + d = self._dialect_info(dialect) + # key assumption: DBAPI type codes are + # constants. Else this dictionary would + # grow unbounded. + d[coltype] = rp = d["impl"].result_processor(dialect, coltype) + return rp def _cached_custom_processor(self, dialect, key, fn): try: return dialect._type_memos[self][key] except KeyError: - d = self._dialect_info(dialect) - impl = d["impl"] - d[key] = result = fn(impl) - return result + pass + # avoid KeyError context coming into fn() function + # raises + d = self._dialect_info(dialect) + impl = d["impl"] + d[key] = result = fn(impl) + return result def _dialect_info(self, dialect): """Return a dialect-specific registry which diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 77e6b53a8a..fda48c6574 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -62,9 +62,10 @@ def _generate_compiler_dispatch(cls): "def _compiler_dispatch(self, visitor, **kw):\n" " try:\n" " meth = visitor.visit_%(name)s\n" - " except AttributeError:\n" - " util.raise_from_cause(\n" - " exc.UnsupportedCompilationError(visitor, cls))\n" + " except AttributeError as err:\n" + " util.raise_(\n" + " exc.UnsupportedCompilationError(visitor, cls), \n" + " replace_context=err)\n" " else:\n" " return meth(self, **kw)\n" ) % {"name": visit_name} diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 5829015798..79b7f9eb3d 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -9,7 +9,9 @@ from . import config # noqa from . import mock # noqa from .assertions import assert_raises # noqa +from .assertions import assert_raises_context_ok # noqa from .assertions import assert_raises_message # noqa +from .assertions import assert_raises_message_context_ok # noqa from .assertions import assert_raises_return # noqa from .assertions import AssertsCompiledSQL # noqa from .assertions import AssertsExecutionResults # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index f5325b0cb1..c97202516b 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,6 +9,7 @@ from __future__ import absolute_import import contextlib import re +import sys import warnings from . import assertsql @@ -258,41 +259,80 @@ def eq_ignore_whitespace(a, b, msg=None): assert a == b, msg or "%r != %r" % (a, b) +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if not util.py3k: + return + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + def assert_raises(except_cls, callable_, *args, **kw): - try: - callable_(*args, **kw) - success = False - except except_cls: - success = True + _assert_raises(except_cls, callable_, args, kw, check_context=True) - # assert outside the block so it works for AssertionError too ! - assert success, "Callable did not raise an exception" + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + _assert_raises( + except_cls, callable_, args, kw, + ) def assert_raises_return(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): ret_err = None + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] try: - callable_(*args, **kw) + callable_(*args, **kwargs) success = False except except_cls as err: - success = True ret_err = err + success = True + if msg is not None: + assert re.search( + msg, util.text_type(err), re.UNICODE + ), "%r !~ %s" % (msg, err,) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(util.text_type(err).encode("utf-8")) # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" - return ret_err - -def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): - try: - callable_(*args, **kwargs) - assert False, "Callable did not raise an exception" - except except_cls as e: - assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( - msg, - e, - ) - print(util.text_type(e).encode("utf-8")) + return ret_err class AssertsCompiledSQL(object): diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 0c05bf9e9b..1a23ebf416 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -9,6 +9,7 @@ import contextlib import operator import re +import sys from . import config from .. import util @@ -145,7 +146,7 @@ class compound(object): ) break else: - util.raise_from_cause(ex) + util.raise_(ex, with_traceback=sys.exc_info()[2]) def _expect_success(self, config, name="block"): if not self.fails: diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index b0ceb802a4..660a0e9766 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -68,6 +68,7 @@ from .compat import py33 # noqa from .compat import py36 # noqa from .compat import py3k # noqa from .compat import quote_plus # noqa +from .compat import raise_ # noqa from .compat import raise_from_cause # noqa from .compat import reduce # noqa from .compat import reraise # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 8967955cd7..004b4687a6 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -147,13 +147,42 @@ if py3k: def cmp(a, b): return (a > b) - (a < b) - def reraise(tp, value, tb=None, cause=None): - if cause is not None: - assert cause is not value, "Same cause emitted" - value.__cause__ = cause - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value + def raise_( + exception, with_traceback=None, replace_context=None, from_=False + ): + r"""implement "raise" with cause support. + + :param exception: exception to raise + :param with_traceback: will call exception.with_traceback() + :param replace_context: an as-yet-unsupported feature. This is + an exception object which we are "replacing", e.g., it's our + "cause" but we don't want it printed. Basically just what + ``__suppress_context__`` does but we don't want to suppress + the enclosing context, if any. So for now we make it the + cause. + :param from\_: the cause. this actually sets the cause and doesn't + hope to hide it someday. + + """ + if with_traceback is not None: + exception = exception.with_traceback(with_traceback) + + if from_ is not False: + exception.__cause__ = from_ + elif replace_context is not None: + # no good solution here, we would like to have the exception + # have only the context of replace_context.__context__ so that the + # intermediary exception does not change, but we can't figure + # that out. + exception.__cause__ = replace_context + + try: + raise exception + finally: + # credit to + # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/ + # as the __traceback__ object creates a cycle + del exception, replace_context, from_, with_traceback def u(s): return s @@ -257,13 +286,13 @@ else: else: return text - # not as nice as that of Py3K, but at least preserves - # the code line where the issue occurred exec( - "def reraise(tp, value, tb=None, cause=None):\n" - " if cause is not None:\n" - " assert cause is not value, 'Same cause emitted'\n" - " raise tp, value, tb\n" + "def raise_(exception, with_traceback=None, replace_context=None, " + "from_=False):\n" + " if with_traceback:\n" + " raise type(exception), exception, with_traceback\n" + " else:\n" + " raise exception\n" ) TYPE_CHECKING = False @@ -405,6 +434,8 @@ def nested(*managers): def raise_from_cause(exception, exc_info=None): + r"""legacy. use raise\_()""" + if exc_info is None: exc_info = sys.exc_info() exc_type, exc_value, exc_tb = exc_info @@ -412,6 +443,12 @@ def raise_from_cause(exception, exc_info=None): reraise(type(exception), exception, tb=exc_tb, cause=cause) +def reraise(tp, value, tb=None, cause=None): + r"""legacy. use raise\_()""" + + raise_(value, with_traceback=tb, from_=cause) + + def with_metaclass(meta, *bases): """Create a base class with a metaclass. diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 41a9698c7d..09aa94bf2b 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -65,7 +65,9 @@ class safe_reraise(object): exc_type, exc_value, exc_tb = self._exc_info self._exc_info = None # remove potential circular references if not self.warn_only: - compat.reraise(exc_type, exc_value, exc_tb) + compat.raise_( + exc_value, with_traceback=exc_tb, + ) else: if not compat.py3k and self._exc_info and self._exc_info[1]: # emulate Py3K's behavior of telling us when an exception @@ -76,7 +78,7 @@ class safe_reraise(object): "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]) ) self._exc_info = None # remove potential circular references - compat.reraise(type_, value, traceback) + compat.raise_(value, with_traceback=traceback) def string_or_unprintable(element): diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 55890cd064..8f84acde8a 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1124,6 +1124,20 @@ class CycleTest(_fixtures.FixtureTest): go() + def test_raise_from(self): + @assert_cycles() + def go(): + try: + try: + raise KeyError("foo") + except KeyError as ke: + + util.raise_(Exception("oops"), from_=ke) + except Exception as err: # noqa + pass + + go() + def test_query_alias(self): User, Address = self.classes("User", "Address") configure_mappers() diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 87908f016a..73a1a8b6fe 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -111,7 +111,7 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): "some other column", Integer ) - @profiling.function_call_count() + @profiling.function_call_count(variance=0.10) def go(): c1 in row diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 48e464a01e..183e157e5e 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -3,7 +3,6 @@ import copy import datetime import inspect -import sys from sqlalchemy import exc from sqlalchemy import sql @@ -2899,20 +2898,29 @@ class ReraiseTest(fixtures.TestBase): except MyException as err: is_(err.__cause__, None) - def test_reraise_disallow_same_cause(self): + def test_raise_from_cause_legacy(self): class MyException(Exception): pass + class MyOtherException(Exception): + pass + + me = MyException("exc on") + def go(): try: - raise MyException("exc one") - except Exception as err: - type_, value, tb = sys.exc_info() - util.reraise(type_, err, tb, value) + raise me + except Exception: + util.raise_from_cause(MyOtherException("exc two")) - assert_raises_message(AssertionError, "Same cause emitted", go) + try: + go() + assert False + except MyOtherException as moe: + if testing.requires.python3.enabled: + is_(moe.__cause__, me) - def test_raise_from_cause(self): + def test_raise_from(self): class MyException(Exception): pass @@ -2924,8 +2932,8 @@ class ReraiseTest(fixtures.TestBase): def go(): try: raise me - except Exception: - util.raise_from_cause(MyOtherException("exc two")) + except Exception as err: + util.raise_(MyOtherException("exc two"), from_=err) try: go() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 5acd14177e..cf262a5738 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -712,12 +712,13 @@ class ExecuteTest(fixtures.TestBase): return super(MockCursor, self).execute(stmt, params, **kw) eng = engines.proxying_engine(cursor_cls=MockCursor) - assert_raises_message( - tsa.exc.SAWarning, - "Exception attempting to detect unicode returns", - eng.connect, - ) - assert eng.dialect.returns_unicode_strings in (True, False) + with testing.expect_warnings( + "Exception attempting to detect unicode returns" + ): + eng.connect() + + # because plain varchar passed, we don't know the correct answer + eq_(eng.dialect.returns_unicode_strings, "conditional") eng.dispose() def test_works_after_dispose(self): diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index cfe20f5ec0..72e0fa1865 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -10,6 +10,7 @@ from sqlalchemy import pool from sqlalchemy import select from sqlalchemy import testing from sqlalchemy.testing import assert_raises +from sqlalchemy.testing import assert_raises_context_ok from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -1256,7 +1257,7 @@ class QueuePoolTest(PoolTestBase): eq_(p.checkedout(), 0) eq_(p._overflow, 0) dbapi.shutdown(True) - assert_raises(Exception, p.connect) + assert_raises_context_ok(Exception, p.connect) eq_(p._overflow, 0) eq_(p.checkedout(), 0) # and not 1 diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 205c1fb310..000be1a701 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -14,6 +14,7 @@ from sqlalchemy import util from sqlalchemy.engine import url from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import assert_raises_message_context_ok from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -255,7 +256,7 @@ class PrePingMockTest(fixtures.TestBase): self.dbapi.shutdown("execute", stop=True) - assert_raises_message( + assert_raises_message_context_ok( MockDisconnect, "database is stopped", pool.connect ) @@ -835,7 +836,7 @@ class CursorErrTest(fixtures.TestBase): def test_cursor_shutdown_in_initialize(self): db = self._fixture(True, True) - assert_raises_message( + assert_raises_message_context_ok( exc.SAWarning, "Exception attempting to detect", db.connect ) eq_( diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 301614061e..579f1aecec 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -24,6 +24,7 @@ from sqlalchemy.testing import eq_regex from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ +from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import mock @@ -596,13 +597,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): testing.db.dialect.ischema_names = {} try: m2 = MetaData(testing.db) - assert_raises(sa.exc.SAWarning, Table, "test", m2, autoload=True) - @testing.emits_warning("Did not recognize type") - def warns(): - m3 = MetaData(testing.db) - t3 = Table("test", m3, autoload=True) - assert t3.c.foo.type.__class__ == sa.types.NullType + with testing.expect_warnings("Did not recognize type"): + t3 = Table("test", m2, autoload_with=testing.db) + is_(t3.c.foo.type.__class__, sa.types.NullType) finally: testing.db.dialect.ischema_names = ischema_names diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 3f4333750f..8ef272a9ef 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -4093,16 +4093,11 @@ class DialectKWArgTest(fixtures.TestBase): def test_unknown_dialect_warning(self): with self._fixture(): - assert_raises_message( - exc.SAWarning, + with testing.expect_warnings( "Can't validate argument 'unknown_y'; can't locate " "any SQLAlchemy dialect named 'unknown'", - Index, - "a", - "b", - "c", - unknown_y=True, - ) + ): + Index("a", "b", "c", unknown_y=True) def test_participating_bad_kw(self): with self._fixture():