]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Centralized 'x is not mapped' reporting into sa.orm.exc.
authorJason Kirtland <jek@discorporate.us>
Wed, 21 May 2008 03:39:06 +0000 (03:39 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 21 May 2008 03:39:06 +0000 (03:39 +0000)
- Guards are now present on all public Session methods and passing in an
  unmapped hoho anywhere yields helpful exception messages, going to some
  effort to provide hints for debugging situations that would otherwise seem
  hopeless, such as broken user instrumentation or half-pickles.

lib/sqlalchemy/orm/exc.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/util.py
test/orm/session.py

index ea7efd3fb858fab99f07e5cc4fa905404ee55ee6..d97a54621ec2fa2b1bbd1e43e8ba9391b18343e5 100644 (file)
@@ -6,31 +6,96 @@
 
 """SQLAlchemy ORM exceptions."""
 
-import sqlalchemy.exceptions as sa_exc
+import sqlalchemy as sa
 
 
-class ConcurrentModificationError(sa_exc.SQLAlchemyError):
+NO_STATE = (AttributeError, KeyError)
+"""Exception types that may be raised by instrumentation implementations."""
+
+class ConcurrentModificationError(sa.exc.SQLAlchemyError):
     """Rows have been modified outside of the unit of work."""
 
 
-class FlushError(sa_exc.SQLAlchemyError):
+class FlushError(sa.exc.SQLAlchemyError):
     """A invalid condition was detected during flush()."""
 
 
-class ObjectDeletedError(sa_exc.InvalidRequestError):
+class UnmappedError(sa.exc.InvalidRequestError):
+    """TODO"""
+
+
+class UnmappedInstanceError(UnmappedError):
+    """An mapping operation was requested for an unknown instance."""
+
+    def __init__(self, obj, entity_name=None, msg=None):
+        if not msg:
+            try:
+                mapper = sa.orm.class_mapper(type(obj), entity_name)
+                name = _safe_cls_name(type(obj))
+                msg = ("Class %r is mapped, but this instance lacks "
+                       "instrumentation.  Possible causes: instance created "
+                       "before sqlalchemy.orm.mapper(%s) was called, or "
+                       "instance was pickled/depickled without instrumentation"
+                       "information." % (name, name))
+            except UnmappedClassError:
+                msg = _default_unmapped(type(obj), entity_name)
+                if isinstance(obj, type):
+                    msg += (
+                        '; was a class (%s) supplied where an instance was '
+                        'required?' % _safe_cls_name(obj))
+        UnmappedError.__init__(self, msg)
+
+
+class UnmappedClassError(UnmappedError):
+    """An mapping operation was requested for an unknown class."""
+
+    def __init__(self, cls, entity_name=None, msg=None):
+        if not msg:
+            msg = _default_unmapped(cls, entity_name)
+        UnmappedError.__init__(self, msg)
+
+
+class ObjectDeletedError(sa.exc.InvalidRequestError):
     """An refresh() operation failed to re-retrieve an object's row."""
 
 
-class UnmappedColumnError(sa_exc.InvalidRequestError):
+class UnmappedColumnError(sa.exc.InvalidRequestError):
     """Mapping operation was requested on an unknown column."""
 
-class NoResultFound(sa_exc.InvalidRequestError):
+
+class NoResultFound(sa.exc.InvalidRequestError):
     """A database result was required but none was found."""
 
-class MultipleResultsFound(sa_exc.InvalidRequestError):
+
+class MultipleResultsFound(sa.exc.InvalidRequestError):
     """A single database result was required but more than one were found."""
 
+
 # Legacy compat until 0.6.
-sa_exc.ConcurrentModificationError = ConcurrentModificationError
-sa_exc.FlushError = FlushError
-sa_exc.UnmappedColumnError
+sa.exc.ConcurrentModificationError = ConcurrentModificationError
+sa.exc.FlushError = FlushError
+sa.exc.UnmappedColumnError
+
+def _safe_cls_name(cls):
+    try:
+        cls_name = '.'.join((cls.__module__, cls.__name__))
+    except AttributeError:
+        cls_name = getattr(cls, '__name__', None)
+        if cls_name is None:
+            cls_name = repr(cls)
+    return cls_name
+
+def _default_unmapped(cls, entity_name):
+    try:
+        mappers = sa.orm.attributes.manager_of_class(cls).mappers
+    except NO_STATE:
+        mappers = {}
+    except TypeError:
+        mappers = {}
+    name = _safe_cls_name(cls)
+
+    if not mappers and entity_name is None:
+        return "Class '%s' is not mapped" % name
+    else:
+        return "Class '%s' is not mapped with entity_name %r" % (
+            name, entity_name)
index e7020354c31335958d763a80b4aca0e1425c9742..ab622926342602400b89a087e90a978c92d34110 100644 (file)
@@ -917,6 +917,8 @@ class Session(object):
         qualifies the underlying Mapper used to perform the query.
 
         """
+        return self.query(_class_to_mapper(class_, entity_name)).get(ident)
+
         return self.query(class_, entity_name=entity_name).get(ident)
 
     def load(self, class_, ident, entity_name=None):
@@ -931,7 +933,7 @@ class Session(object):
         qualifies the underlying ``Mapper`` used to perform the query.
 
         """
-        return self.query(class_, entity_name=entity_name).load(ident)
+        return self.query(_class_to_mapper(class_, entity_name)).load(ident)
 
     def refresh(self, instance, attribute_names=None):
         """Refresh the attributes on the given instance.
@@ -950,7 +952,10 @@ class Session(object):
         attribute names indicating a subset of attributes to be refreshed.
 
         """
-        state = attributes.instance_state(instance)
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
         self._validate_persistent(state)
         if self.query(_object_mapper(instance))._get(
                 state.key, refresh_instance=state,
@@ -977,7 +982,10 @@ class Session(object):
         expired.
 
         """
-        state = attributes.instance_state(instance)
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
         self._validate_persistent(state)
         if attribute_names:
             _expire_state(state, attribute_names=attribute_names)
@@ -1009,7 +1017,10 @@ class Session(object):
         will be applied according to the *expunge* cascade rule.
 
         """
-        state = attributes.instance_state(instance)
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
         if state.session_id is not self.hash_key:
             raise sa_exc.InvalidRequestError(
                 "Instance %s is not present in this Session" %
@@ -1081,7 +1092,7 @@ class Session(object):
     def _save_without_cascade(self, instance, entity_name=None):
         """Used by scoping.py to save on init without cascade."""
 
-        state = _state_for_unsaved_instance(instance, entity_name)
+        state = _state_for_unsaved_instance(instance, entity_name, create=True)
         self._save_impl(state)
 
     def update(self, instance, entity_name=None):
@@ -1095,7 +1106,10 @@ class Session(object):
         instances if the relation is mapped with ``cascade="save-update"``.
 
         """
-        state = attributes.instance_state(instance)
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance, entity_name)
         self._update_impl(state)
         self._cascade_save_or_update(state, entity_name)
     update = util.pending_deprecation('0.5.x', "Use session.add()")(update)
@@ -1136,7 +1150,10 @@ class Session(object):
         The database delete operation occurs upon ``flush()``.
 
         """
-        state = attributes.instance_state(instance)
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
         self._delete_impl(state)
         for state, m in _cascade_state_iterator('delete', state):
             self._delete_impl(state, ignore_transient=True)
@@ -1304,7 +1321,11 @@ class Session(object):
         result of True.
 
         """
-        return self._contains_state(attributes.instance_state(instance))
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
+        return self._contains_state(state)
 
     def __iter__(self):
         """Iterate over all pending or persistent instances within this Session."""
@@ -1359,7 +1380,13 @@ class Session(object):
         # create the set of all objects we want to operate upon
         if objects:
             # specific list passed in
-            objset = util.Set([attributes.instance_state(o) for o in objects])
+            objset = util.Set()
+            for o in objects:
+                try:
+                    state = attributes.instance_state(o)
+                except exc.NO_STATE:
+                    raise exc.UnmappedInstanceError(o)
+                objset.add(state)
         else:
             # or just everything
             objset = util.Set(self.identity_map.all_states()).union(new)
@@ -1430,7 +1457,11 @@ class Session(object):
         should not be loaded in the course of performing this test.
 
         """
-        for attr in attributes.manager_of_class(instance.__class__).attributes:
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
+        for attr in state.manager.attributes:
             if not include_collections and hasattr(attr.impl, 'get_collection'):
                 continue
             (added, unchanged, deleted) = attr.get_history(instance)
@@ -1513,23 +1544,28 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs):
     for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
         yield _state_for_unknown_persistence_instance(o, m.entity_name), m
 
-def _state_for_unsaved_instance(instance, entity_name):
+def _state_for_unsaved_instance(instance, entity_name, create=False):
     manager = attributes.manager_of_class(instance.__class__)
     if manager is None:
-        raise "FIXME unmapped instance"
+        raise exc.UnmappedInstanceError(instance, entity_name)
     if manager.has_state(instance):
         state = manager.state_of(instance)
         if state.key is not None:
             raise sa_exc.InvalidRequestError(
                 "Instance '%s' is already persistent" %
                 mapperutil.state_str(state))
-    else:
+    elif create:
         state = manager.setup_instance(instance)
+    else:
+        raise exc.UnmappedInstanceError(instance, entity_name)
     state.entity_name = entity_name
     return state
 
 def _state_for_unknown_persistence_instance(instance, entity_name):
-    state = attributes.instance_state(instance)
+    try:
+        state = attributes.instance_state(instance)
+    except exc.NO_STATE:
+        raise exc.UnmappedInstanceError(instance, entity_name)
     state.entity_name = entity_name
     return state
 
index 485d60cd1b2f78ec99b6b5311d1de3fa6433d22d..3b910c22ba5c7aa63569f4c8463c4a8bfd19909a 100644 (file)
@@ -10,7 +10,7 @@ import sqlalchemy.exceptions as sa_exc
 from sqlalchemy import sql, util
 from sqlalchemy.sql import expression, util as sql_util, operators
 from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty
-from sqlalchemy.orm import attributes
+from sqlalchemy.orm import attributes, exc
 
 all_cascades = util.FrozenSet(["delete", "delete-orphan", "all", "merge",
                          "expunge", "save-update", "refresh-expire", "none"])
@@ -457,9 +457,7 @@ def object_mapper(object, entity_name=None, raiseerror=True):
     except (KeyError, AttributeError):
         if not raiseerror:
             return None
-        raise sa_exc.InvalidRequestError(
-            "FIXME Instance %r with entity name '%s' has no mapper associated with it" %
-            (object, entity_name))
+        raise exc.UnmappedInstanceError(object, entity_name)
     if state.entity_name is not attributes.NO_ENTITY_NAME:
         # Override the given entity name if the object is not transient.
         entity_name = state.entity_name
@@ -482,6 +480,7 @@ def class_mapper(class_, entity_name=None, compile=True, raiseerror=True):
     except (KeyError, AttributeError):
         if not raiseerror:
             return
+        raise exc.UnmappedClassError(class_, entity_name)
         raise sa_exc.InvalidRequestError(
             "Class '%s' entity name '%s' has no mapper associated with it" %
             (class_.__name__, entity_name))
@@ -494,11 +493,13 @@ def _class_to_mapper(class_or_mapper, entity_name=None, compile=True):
         return class_or_mapper._AliasedClass__mapper
     elif isinstance(class_or_mapper, type):
         return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile)
-    else:
+    elif hasattr(class_or_mapper, 'compile'):
         if compile:
             return class_or_mapper.compile()
         else:
             return class_or_mapper
+    else:
+        raise exc.UnmappedClassError(class_or_mapper, entity_name)
 
 def has_identity(object):
     state = attributes.instance_state(object)
index 6f6dfb6b8a9a126741042fab55bf9f694ea4d34b..9821221afa9f11996040380520cfb24403e86a87 100644 (file)
@@ -1,5 +1,6 @@
 import testenv; testenv.configure_for_tests()
 import gc
+import inspect
 import pickle
 from sqlalchemy.orm import create_session, sessionmaker
 from testlib import engines, sa, testing
@@ -980,6 +981,144 @@ class SessionTest(_fixtures.FixtureTest):
         assert len(list(sess)) == 1
 
 
+class SessionInterface(testing.TestBase):
+    """Bogus args to Session methods produce actionable exceptions."""
+
+    # TODO: expand with message body assertions.
+
+    _class_methods = set(('get', 'load'))
+
+    def _public_session_methods(self):
+        Session = sa.orm.session.Session
+
+        blacklist = set(('begin', 'query'))
+
+        ok = set()
+        for meth in Session.public_methods:
+            if meth in blacklist:
+                continue
+            spec = inspect.getargspec(getattr(Session, meth))
+            if len(spec[0]) > 1 or spec[1]:
+                ok.add(meth)
+        return ok
+
+    def _map_it(self, cls):
+        return mapper(cls, Table('t', sa.MetaData(),
+                                 Column('id', Integer, primary_key=True)))
+
+    def _test_instance_guards(self, user_arg):
+        watchdog = set()
+
+        def x_raises_(obj, method, *args, **kw):
+            watchdog.add(method)
+            callable_ = getattr(obj, method)
+            self.assertRaises(sa.orm.exc.UnmappedInstanceError,
+                              callable_, *args, **kw)
+
+        def raises_(method, *args, **kw):
+            x_raises_(create_session(), method, *args, **kw)
+
+        raises_('__contains__', user_arg)
+
+        raises_('add', user_arg)
+
+        raises_('add_all', (user_arg,))
+
+        raises_('connection', instance=user_arg)
+
+        raises_('delete', user_arg)
+
+        raises_('execute', 'SELECT 1', instance=user_arg)
+
+        raises_('expire', user_arg)
+
+        raises_('expunge', user_arg)
+
+        # flush will no-op without something in the unit of work
+        def _():
+            class OK(object):
+                pass
+            self._map_it(OK)
+
+            s = create_session()
+            s.add(OK())
+            x_raises_(s, 'flush', (user_arg,))
+        _()
+
+        raises_('get_bind', instance=user_arg)
+
+        raises_('is_modified', user_arg)
+
+        raises_('merge', user_arg)
+
+        raises_('refresh', user_arg)
+
+        raises_('save', user_arg)
+
+        raises_('save_or_update', user_arg)
+
+        raises_('scalar', 'SELECT 1', instance=user_arg)
+
+        raises_('update', user_arg)
+
+        instance_methods = self._public_session_methods() - self._class_methods
+
+        eq_(watchdog, instance_methods,
+            watchdog.symmetric_difference(instance_methods))
+
+    def _test_class_guards(self, user_arg):
+        watchdog = set()
+
+        def raises_(method, *args, **kw):
+            watchdog.add(method)
+            callable_ = getattr(create_session(), method)
+            self.assertRaises(sa.orm.exc.UnmappedClassError,
+                              callable_, *args, **kw)
+
+        raises_('get', user_arg, 1)
+
+        raises_('load', user_arg, 1)
+
+        eq_(watchdog, self._class_methods,
+            watchdog.symmetric_difference(self._class_methods))
+
+    def test_unmapped_instance(self):
+        class Unmapped(object):
+            pass
+
+        self._test_instance_guards(Unmapped())
+        self._test_class_guards(Unmapped)
+
+    def test_unmapped_primitives(self):
+        for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')):
+            self._test_instance_guards(prim)
+            self._test_class_guards(prim)
+
+    def test_unmapped_class_for_instance(self):
+        class Unmapped(object):
+            pass
+
+        self._test_instance_guards(Unmapped)
+        self._test_class_guards(Unmapped)
+
+    def test_mapped_class_for_instance(self):
+        class Mapped(object):
+            pass
+        self._map_it(Mapped)
+
+        self._test_instance_guards(Mapped)
+        # no class guards- it would pass.
+
+    def test_missing_state(self):
+        class Mapped(object):
+            pass
+        early = Mapped()
+        self._map_it(Mapped)
+
+        self._test_instance_guards(early)
+        self._test_class_guards(early)
+
+
 class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest):
     def create_engine(self):
         return engines.testing_engine(options=dict(strategy='threadlocal'))