"""SQLAlchemy ORM exceptions."""
-import sqlalchemy.exceptions as sa_exc
+import sqlalchemy as sa
-class ConcurrentModificationError(sa_exc.SQLAlchemyError):
+NO_STATE = (AttributeError, KeyError)
+"""Exception types that may be raised by instrumentation implementations."""
+
+class ConcurrentModificationError(sa.exc.SQLAlchemyError):
"""Rows have been modified outside of the unit of work."""
-class FlushError(sa_exc.SQLAlchemyError):
+class FlushError(sa.exc.SQLAlchemyError):
"""A invalid condition was detected during flush()."""
-class ObjectDeletedError(sa_exc.InvalidRequestError):
+class UnmappedError(sa.exc.InvalidRequestError):
+ """TODO"""
+
+
+class UnmappedInstanceError(UnmappedError):
+ """An mapping operation was requested for an unknown instance."""
+
+ def __init__(self, obj, entity_name=None, msg=None):
+ if not msg:
+ try:
+ mapper = sa.orm.class_mapper(type(obj), entity_name)
+ name = _safe_cls_name(type(obj))
+ msg = ("Class %r is mapped, but this instance lacks "
+ "instrumentation. Possible causes: instance created "
+ "before sqlalchemy.orm.mapper(%s) was called, or "
+ "instance was pickled/depickled without instrumentation"
+ "information." % (name, name))
+ except UnmappedClassError:
+ msg = _default_unmapped(type(obj), entity_name)
+ if isinstance(obj, type):
+ msg += (
+ '; was a class (%s) supplied where an instance was '
+ 'required?' % _safe_cls_name(obj))
+ UnmappedError.__init__(self, msg)
+
+
+class UnmappedClassError(UnmappedError):
+ """An mapping operation was requested for an unknown class."""
+
+ def __init__(self, cls, entity_name=None, msg=None):
+ if not msg:
+ msg = _default_unmapped(cls, entity_name)
+ UnmappedError.__init__(self, msg)
+
+
+class ObjectDeletedError(sa.exc.InvalidRequestError):
"""An refresh() operation failed to re-retrieve an object's row."""
-class UnmappedColumnError(sa_exc.InvalidRequestError):
+class UnmappedColumnError(sa.exc.InvalidRequestError):
"""Mapping operation was requested on an unknown column."""
-class NoResultFound(sa_exc.InvalidRequestError):
+
+class NoResultFound(sa.exc.InvalidRequestError):
"""A database result was required but none was found."""
-class MultipleResultsFound(sa_exc.InvalidRequestError):
+
+class MultipleResultsFound(sa.exc.InvalidRequestError):
"""A single database result was required but more than one were found."""
+
# Legacy compat until 0.6.
-sa_exc.ConcurrentModificationError = ConcurrentModificationError
-sa_exc.FlushError = FlushError
-sa_exc.UnmappedColumnError
+sa.exc.ConcurrentModificationError = ConcurrentModificationError
+sa.exc.FlushError = FlushError
+sa.exc.UnmappedColumnError
+
+def _safe_cls_name(cls):
+ try:
+ cls_name = '.'.join((cls.__module__, cls.__name__))
+ except AttributeError:
+ cls_name = getattr(cls, '__name__', None)
+ if cls_name is None:
+ cls_name = repr(cls)
+ return cls_name
+
+def _default_unmapped(cls, entity_name):
+ try:
+ mappers = sa.orm.attributes.manager_of_class(cls).mappers
+ except NO_STATE:
+ mappers = {}
+ except TypeError:
+ mappers = {}
+ name = _safe_cls_name(cls)
+
+ if not mappers and entity_name is None:
+ return "Class '%s' is not mapped" % name
+ else:
+ return "Class '%s' is not mapped with entity_name %r" % (
+ name, entity_name)
qualifies the underlying Mapper used to perform the query.
"""
+ return self.query(_class_to_mapper(class_, entity_name)).get(ident)
+
return self.query(class_, entity_name=entity_name).get(ident)
def load(self, class_, ident, entity_name=None):
qualifies the underlying ``Mapper`` used to perform the query.
"""
- return self.query(class_, entity_name=entity_name).load(ident)
+ return self.query(_class_to_mapper(class_, entity_name)).load(ident)
def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
attribute names indicating a subset of attributes to be refreshed.
"""
- state = attributes.instance_state(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
self._validate_persistent(state)
if self.query(_object_mapper(instance))._get(
state.key, refresh_instance=state,
expired.
"""
- state = attributes.instance_state(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
self._validate_persistent(state)
if attribute_names:
_expire_state(state, attribute_names=attribute_names)
will be applied according to the *expunge* cascade rule.
"""
- state = attributes.instance_state(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
if state.session_id is not self.hash_key:
raise sa_exc.InvalidRequestError(
"Instance %s is not present in this Session" %
def _save_without_cascade(self, instance, entity_name=None):
"""Used by scoping.py to save on init without cascade."""
- state = _state_for_unsaved_instance(instance, entity_name)
+ state = _state_for_unsaved_instance(instance, entity_name, create=True)
self._save_impl(state)
def update(self, instance, entity_name=None):
instances if the relation is mapped with ``cascade="save-update"``.
"""
- state = attributes.instance_state(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance, entity_name)
self._update_impl(state)
self._cascade_save_or_update(state, entity_name)
update = util.pending_deprecation('0.5.x', "Use session.add()")(update)
The database delete operation occurs upon ``flush()``.
"""
- state = attributes.instance_state(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
self._delete_impl(state)
for state, m in _cascade_state_iterator('delete', state):
self._delete_impl(state, ignore_transient=True)
result of True.
"""
- return self._contains_state(attributes.instance_state(instance))
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
+ return self._contains_state(state)
def __iter__(self):
"""Iterate over all pending or persistent instances within this Session."""
# create the set of all objects we want to operate upon
if objects:
# specific list passed in
- objset = util.Set([attributes.instance_state(o) for o in objects])
+ objset = util.Set()
+ for o in objects:
+ try:
+ state = attributes.instance_state(o)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(o)
+ objset.add(state)
else:
# or just everything
objset = util.Set(self.identity_map.all_states()).union(new)
should not be loaded in the course of performing this test.
"""
- for attr in attributes.manager_of_class(instance.__class__).attributes:
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance)
+ for attr in state.manager.attributes:
if not include_collections and hasattr(attr.impl, 'get_collection'):
continue
(added, unchanged, deleted) = attr.get_history(instance)
for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
yield _state_for_unknown_persistence_instance(o, m.entity_name), m
-def _state_for_unsaved_instance(instance, entity_name):
+def _state_for_unsaved_instance(instance, entity_name, create=False):
manager = attributes.manager_of_class(instance.__class__)
if manager is None:
- raise "FIXME unmapped instance"
+ raise exc.UnmappedInstanceError(instance, entity_name)
if manager.has_state(instance):
state = manager.state_of(instance)
if state.key is not None:
raise sa_exc.InvalidRequestError(
"Instance '%s' is already persistent" %
mapperutil.state_str(state))
- else:
+ elif create:
state = manager.setup_instance(instance)
+ else:
+ raise exc.UnmappedInstanceError(instance, entity_name)
state.entity_name = entity_name
return state
def _state_for_unknown_persistence_instance(instance, entity_name):
- state = attributes.instance_state(instance)
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE:
+ raise exc.UnmappedInstanceError(instance, entity_name)
state.entity_name = entity_name
return state
from sqlalchemy import sql, util
from sqlalchemy.sql import expression, util as sql_util, operators
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty
-from sqlalchemy.orm import attributes
+from sqlalchemy.orm import attributes, exc
all_cascades = util.FrozenSet(["delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire", "none"])
except (KeyError, AttributeError):
if not raiseerror:
return None
- raise sa_exc.InvalidRequestError(
- "FIXME Instance %r with entity name '%s' has no mapper associated with it" %
- (object, entity_name))
+ raise exc.UnmappedInstanceError(object, entity_name)
if state.entity_name is not attributes.NO_ENTITY_NAME:
# Override the given entity name if the object is not transient.
entity_name = state.entity_name
except (KeyError, AttributeError):
if not raiseerror:
return
+ raise exc.UnmappedClassError(class_, entity_name)
raise sa_exc.InvalidRequestError(
"Class '%s' entity name '%s' has no mapper associated with it" %
(class_.__name__, entity_name))
return class_or_mapper._AliasedClass__mapper
elif isinstance(class_or_mapper, type):
return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile)
- else:
+ elif hasattr(class_or_mapper, 'compile'):
if compile:
return class_or_mapper.compile()
else:
return class_or_mapper
+ else:
+ raise exc.UnmappedClassError(class_or_mapper, entity_name)
def has_identity(object):
state = attributes.instance_state(object)
import testenv; testenv.configure_for_tests()
import gc
+import inspect
import pickle
from sqlalchemy.orm import create_session, sessionmaker
from testlib import engines, sa, testing
assert len(list(sess)) == 1
+class SessionInterface(testing.TestBase):
+ """Bogus args to Session methods produce actionable exceptions."""
+
+ # TODO: expand with message body assertions.
+
+ _class_methods = set(('get', 'load'))
+
+ def _public_session_methods(self):
+ Session = sa.orm.session.Session
+
+ blacklist = set(('begin', 'query'))
+
+ ok = set()
+ for meth in Session.public_methods:
+ if meth in blacklist:
+ continue
+ spec = inspect.getargspec(getattr(Session, meth))
+ if len(spec[0]) > 1 or spec[1]:
+ ok.add(meth)
+ return ok
+
+ def _map_it(self, cls):
+ return mapper(cls, Table('t', sa.MetaData(),
+ Column('id', Integer, primary_key=True)))
+
+ def _test_instance_guards(self, user_arg):
+ watchdog = set()
+
+ def x_raises_(obj, method, *args, **kw):
+ watchdog.add(method)
+ callable_ = getattr(obj, method)
+ self.assertRaises(sa.orm.exc.UnmappedInstanceError,
+ callable_, *args, **kw)
+
+ def raises_(method, *args, **kw):
+ x_raises_(create_session(), method, *args, **kw)
+
+ raises_('__contains__', user_arg)
+
+ raises_('add', user_arg)
+
+ raises_('add_all', (user_arg,))
+
+ raises_('connection', instance=user_arg)
+
+ raises_('delete', user_arg)
+
+ raises_('execute', 'SELECT 1', instance=user_arg)
+
+ raises_('expire', user_arg)
+
+ raises_('expunge', user_arg)
+
+ # flush will no-op without something in the unit of work
+ def _():
+ class OK(object):
+ pass
+ self._map_it(OK)
+
+ s = create_session()
+ s.add(OK())
+ x_raises_(s, 'flush', (user_arg,))
+ _()
+
+ raises_('get_bind', instance=user_arg)
+
+ raises_('is_modified', user_arg)
+
+ raises_('merge', user_arg)
+
+ raises_('refresh', user_arg)
+
+ raises_('save', user_arg)
+
+ raises_('save_or_update', user_arg)
+
+ raises_('scalar', 'SELECT 1', instance=user_arg)
+
+ raises_('update', user_arg)
+
+ instance_methods = self._public_session_methods() - self._class_methods
+
+ eq_(watchdog, instance_methods,
+ watchdog.symmetric_difference(instance_methods))
+
+ def _test_class_guards(self, user_arg):
+ watchdog = set()
+
+ def raises_(method, *args, **kw):
+ watchdog.add(method)
+ callable_ = getattr(create_session(), method)
+ self.assertRaises(sa.orm.exc.UnmappedClassError,
+ callable_, *args, **kw)
+
+ raises_('get', user_arg, 1)
+
+ raises_('load', user_arg, 1)
+
+ eq_(watchdog, self._class_methods,
+ watchdog.symmetric_difference(self._class_methods))
+
+ def test_unmapped_instance(self):
+ class Unmapped(object):
+ pass
+
+ self._test_instance_guards(Unmapped())
+ self._test_class_guards(Unmapped)
+
+ def test_unmapped_primitives(self):
+ for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')):
+ self._test_instance_guards(prim)
+ self._test_class_guards(prim)
+
+ def test_unmapped_class_for_instance(self):
+ class Unmapped(object):
+ pass
+
+ self._test_instance_guards(Unmapped)
+ self._test_class_guards(Unmapped)
+
+ def test_mapped_class_for_instance(self):
+ class Mapped(object):
+ pass
+ self._map_it(Mapped)
+
+ self._test_instance_guards(Mapped)
+ # no class guards- it would pass.
+
+ def test_missing_state(self):
+ class Mapped(object):
+ pass
+ early = Mapped()
+ self._map_it(Mapped)
+
+ self._test_instance_guards(early)
+ self._test_class_guards(early)
+
+
class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest):
def create_engine(self):
return engines.testing_engine(options=dict(strategy='threadlocal'))