]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- 78 chars
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Jul 2010 18:13:56 +0000 (14:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Jul 2010 18:13:56 +0000 (14:13 -0400)
- Added "column_descriptions" accessor to Query,
returns a list of dictionaries containing
naming/typing information about the entities
the Query will return.  Can be helpful for
building GUIs on top of ORM queries.

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index fb9603e524b91f41744e7b42fe99887e873118e0..83ed69118b6bedce1688773a40c2f7faa79abd7d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -20,6 +20,12 @@ CHANGES
     cls._sa_class_manager.mapper now raise 
     UnmappedClassError().  [ticket:1142]
     
+  - Added "column_descriptions" accessor to Query,
+    returns a list of dictionaries containing
+    naming/typing information about the entities
+    the Query will return.  Can be helpful for 
+    building GUIs on top of ORM queries.
+    
 0.6.2
 =====
 - orm
index 2405d91c0b33305a63e6547a699386c443ed77ad..ab31736ed19522be3f83e91ad779f00195f5271e 100644 (file)
@@ -96,7 +96,8 @@ class QueryableAttribute(interfaces.PropComparator):
         """Construct an InstrumentedAttribute.
 
           comparator
-            a sql.Comparator to which class-level compare/math events will be sent
+            a sql.Comparator to which class-level compare/math events will be
+            sent
         """
         self.key = key
         self.impl = impl
@@ -104,7 +105,8 @@ class QueryableAttribute(interfaces.PropComparator):
         self.parententity = parententity
 
     def get_history(self, instance, **kwargs):
-        return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs)
+        return self.impl.get_history(instance_state(instance),
+                                        instance_dict(instance), **kwargs)
 
     def __selectable__(self):
         # TODO: conditionally attach this method based on clause_element ?
@@ -148,7 +150,8 @@ class InstrumentedAttribute(QueryableAttribute):
     """Public-facing descriptor, placed in the mapped class dictionary."""
 
     def __set__(self, instance, value):
-        self.impl.set(instance_state(instance), instance_dict(instance), value, None)
+        self.impl.set(instance_state(instance), 
+                        instance_dict(instance), value, None)
 
     def __delete__(self, instance):
         self.impl.delete(instance_state(instance), instance_dict(instance))
@@ -156,7 +159,8 @@ class InstrumentedAttribute(QueryableAttribute):
     def __get__(self, instance, owner):
         if instance is None:
             return self
-        return self.impl.get(instance_state(instance), instance_dict(instance))
+        return self.impl.get(instance_state(instance),
+                                instance_dict(instance))
 
 class _ProxyImpl(object):
     accepts_scalar_loader = False
@@ -205,7 +209,8 @@ def proxied_attribute_factory(descriptor):
             return descriptor.__delete__(instance)
 
         def __getattr__(self, attribute):
-            """Delegate __getattr__ to the original descriptor and/or comparator."""
+            """Delegate __getattr__ to the original descriptor and/or
+            comparator."""
             
             try:
                 return getattr(descriptor, attribute)
@@ -214,10 +219,10 @@ def proxied_attribute_factory(descriptor):
                     return getattr(self._comparator, attribute)
                 except AttributeError:
                     raise AttributeError(
-                            'Neither %r object nor %r object has an attribute %r' % (
-                            type(descriptor).__name__, 
-                            type(self._comparator).__name__, 
-                            attribute)
+                    'Neither %r object nor %r object has an attribute %r' % (
+                    type(descriptor).__name__, 
+                    type(self._comparator).__name__, 
+                    attribute)
                     )
 
     Proxy.__name__ = type(descriptor).__name__ + 'Proxy'
@@ -450,7 +455,8 @@ class ScalarAttributeImpl(AttributeImpl):
             old = dict_.get(self.key, NO_VALUE)
 
         if self.extensions:
-            value = self.fire_replace_event(state, dict_, value, old, initiator)
+            value = self.fire_replace_event(state, dict_, 
+                                                value, old, initiator)
         state.modified_event(dict_, self, False, old)
         dict_[self.key] = value
 
@@ -469,8 +475,9 @@ class ScalarAttributeImpl(AttributeImpl):
 
 
 class MutableScalarAttributeImpl(ScalarAttributeImpl):
-    """represents a scalar value-holding InstrumentedAttribute, which can detect
-    changes within the value itself.
+    """represents a scalar value-holding InstrumentedAttribute, which can
+    detect changes within the value itself.
+
     """
 
     uses_objects = False
@@ -522,7 +529,8 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
 
         if self.extensions:
             old = self.get(state, dict_)
-            value = self.fire_replace_event(state, dict_, value, old, initiator)
+            value = self.fire_replace_event(state, dict_, 
+                                            value, old, initiator)
 
         state.modified_event(dict_, self, True, NEVER_SET)
         dict_[self.key] = value
@@ -544,13 +552,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
                     trackparent=False, extension=None, copy_function=None,
                     compare_function=None, **kwargs):
         super(ScalarObjectAttributeImpl, self).__init__(
-                                                class_, 
-                                                key,
-                                                callable_, 
-                                                trackparent=trackparent, 
-                                                extension=extension,
-                                                compare_function=compare_function, 
-                                                **kwargs)
+                                            class_, 
+                                            key,
+                                            callable_, 
+                                            trackparent=trackparent, 
+                                            extension=extension,
+                                            compare_function=compare_function, 
+                                            **kwargs)
         if compare_function is None:
             self.is_equal = identity_equal
 
@@ -623,8 +631,8 @@ class CollectionAttributeImpl(AttributeImpl):
 
     InstrumentedCollectionAttribute holds an arbitrary, user-specified
     container object (defaulting to a list) and brokers access to the
-    CollectionAdapter, a "view" onto that object that presents consistent
-    bag semantics to the orm layer independent of the user data implementation.
+    CollectionAdapter, a "view" onto that object that presents consistent bag
+    semantics to the orm layer independent of the user data implementation.
 
     """
     accepts_scalar_loader = False
@@ -634,13 +642,13 @@ class CollectionAttributeImpl(AttributeImpl):
                     typecallable=None, trackparent=False, extension=None,
                     copy_function=None, compare_function=None, **kwargs):
         super(CollectionAttributeImpl, self).__init__(
-                                                class_, 
-                                                key, 
-                                                callable_, 
-                                                trackparent=trackparent,
-                                                extension=extension,
-                                                compare_function=compare_function, 
-                                                **kwargs)
+                                            class_, 
+                                            key, 
+                                            callable_, 
+                                            trackparent=trackparent,
+                                            extension=extension,
+                                            compare_function=compare_function, 
+                                            **kwargs)
 
         if copy_function is None:
             copy_function = self.__copy
@@ -661,7 +669,8 @@ class CollectionAttributeImpl(AttributeImpl):
         for ext in self.extensions:
             value = ext.append(state, value, initiator or self)
 
-        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+        state.modified_event(dict_, self, True, 
+                                NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), True)
@@ -669,7 +678,8 @@ class CollectionAttributeImpl(AttributeImpl):
         return value
 
     def fire_pre_remove_event(self, state, dict_, initiator):
-        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+        state.modified_event(dict_, self, True, 
+                                NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
     def fire_remove_event(self, state, dict_, value, initiator):
         if self.trackparent and value is not None:
@@ -678,7 +688,8 @@ class CollectionAttributeImpl(AttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+        state.modified_event(dict_, self, True, 
+                                NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
     def delete(self, state, dict_):
         if self.key not in dict_:
@@ -709,7 +720,8 @@ class CollectionAttributeImpl(AttributeImpl):
         collection = self.get_collection(state, dict_, passive=passive)
         if collection is PASSIVE_NO_RESULT:
             value = self.fire_append_event(state, dict_, value, initiator)
-            assert self.key not in dict_, "Collection was loaded during event handling."
+            assert self.key not in dict_, \
+                    "Collection was loaded during event handling."
             state.get_pending(self.key).append(value)
         else:
             collection.append_with_event(value, initiator)
@@ -721,7 +733,8 @@ class CollectionAttributeImpl(AttributeImpl):
         collection = self.get_collection(state, state.dict, passive=passive)
         if collection is PASSIVE_NO_RESULT:
             self.fire_remove_event(state, dict_, value, initiator)
-            assert self.key not in dict_, "Collection was loaded during event handling."
+            assert self.key not in dict_, \
+                    "Collection was loaded during event handling."
             state.get_pending(self.key).remove(value)
         else:
             collection.remove_with_event(value, initiator)
@@ -806,7 +819,8 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return user_data
 
-    def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF):
+    def get_collection(self, state, dict_, 
+                            user_data=None, passive=PASSIVE_OFF):
         """Retrieve the CollectionAdapter associated with the given state.
 
         Creates a new CollectionAdapter if one does not exist.
@@ -840,7 +854,8 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
         if oldchild is not None and oldchild is not PASSIVE_NO_RESULT:
             # With lazy=None, there's no guarantee that the full collection is
             # present when updating via a backref.
-            old_state, old_dict = instance_state(oldchild), instance_dict(oldchild)
+            old_state, old_dict = instance_state(oldchild),\
+                                    instance_dict(oldchild)
             impl = old_state.get_impl(self.key)
             try:
                 impl.remove(old_state, 
@@ -851,31 +866,37 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
                 pass
                 
         if child is not None:
-            child_state, child_dict = instance_state(child), instance_dict(child)
+            child_state, child_dict = instance_state(child),\
+                                        instance_dict(child)
             child_state.get_impl(self.key).append(
                                             child_state, 
                                             child_dict, 
                                             state.obj(), 
-                                            initiator, passive=PASSIVE_NO_FETCH)
+                                            initiator, 
+                                            passive=PASSIVE_NO_FETCH)
         return child
 
     def append(self, state, child, initiator):
-        child_state, child_dict = instance_state(child), instance_dict(child)
+        child_state, child_dict = instance_state(child), \
+                                    instance_dict(child)
         child_state.get_impl(self.key).append(
                                             child_state, 
                                             child_dict, 
                                             state.obj(), 
-                                            initiator, passive=PASSIVE_NO_FETCH)
+                                            initiator, 
+                                            passive=PASSIVE_NO_FETCH)
         return child
 
     def remove(self, state, child, initiator):
         if child is not None:
-            child_state, child_dict = instance_state(child), instance_dict(child)
+            child_state, child_dict = instance_state(child),\
+                                        instance_dict(child)
             child_state.get_impl(self.key).remove(
                                             child_state, 
                                             child_dict, 
                                             state.obj(), 
-                                            initiator, passive=PASSIVE_NO_FETCH)
+                                            initiator,
+                                            passive=PASSIVE_NO_FETCH)
 
 
 class Events(object):
@@ -922,7 +943,7 @@ class ClassManager(dict):
         self.mutable_attributes = set()
         self.local_attrs = {}
         self.originals = {}
-        for base in class_.__mro__[-2:0:-1]:   # reverse, skipping 1st and last
+        for base in class_.__mro__[-2:0:-1]:  # reverse, skipping 1st and last
             if not isinstance(base, type):
                 continue
             cls_state = manager_of_class(base)
@@ -962,7 +983,8 @@ class ClassManager(dict):
             self.deferred_scalar_loader = deferred_scalar_loader
     
     def _subclass_manager(self, cls):
-        """Create a new ClassManager for a subclass of this ClassManager's class.
+        """Create a new ClassManager for a subclass of this ClassManager's
+        class.
         
         This is called automatically when attributes are instrumented so that
         the attributes can be propagated to subclasses against their own
@@ -1057,7 +1079,8 @@ class ClassManager(dict):
     def install_descriptor(self, key, inst):
         if key in (self.STATE_ATTR, self.MANAGER_ATTR):
             raise KeyError("%r: requested attribute name conflicts with "
-                           "instrumentation attribute of the same name." % key)
+                           "instrumentation attribute of the same name." %
+                           key)
         setattr(self.class_, key, inst)
 
     def uninstall_descriptor(self, key):
@@ -1066,7 +1089,8 @@ class ClassManager(dict):
     def install_member(self, key, implementation):
         if key in (self.STATE_ATTR, self.MANAGER_ATTR):
             raise KeyError("%r: requested attribute name conflicts with "
-                           "instrumentation attribute of the same name." % key)
+                           "instrumentation attribute of the same name." %
+                           key)
         self.originals.setdefault(key, getattr(self.class_, key, None))
         setattr(self.class_, key, implementation)
 
@@ -1101,11 +1125,13 @@ class ClassManager(dict):
 
     def new_instance(self, state=None):
         instance = self.class_.__new__(self.class_)
-        setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
+        setattr(instance, self.STATE_ATTR, 
+                    state or self._create_instance_state(instance))
         return instance
 
     def setup_instance(self, instance, state=None):
-        setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
+        setattr(instance, self.STATE_ATTR, 
+                    state or self._create_instance_state(instance))
     
     def teardown_instance(self, instance):
         delattr(instance, self.STATE_ATTR)
@@ -1208,7 +1234,8 @@ class _ClassInstrumentationAdapter(ClassManager):
         if delegate:
             return delegate(key, state, factory)
         else:
-            return ClassManager.initialize_collection(self, key, state, factory)
+            return ClassManager.initialize_collection(self, key, 
+                                                        state, factory)
 
     def new_instance(self, state=None):
         instance = self.class_.__new__(self.class_)
@@ -1391,7 +1418,8 @@ def register_attribute(class_, key, **kw):
     comparator = kw.pop('comparator', None)
     parententity = kw.pop('parententity', None)
     doc = kw.pop('doc', None)
-    register_descriptor(class_, key, proxy_property, comparator, parententity, doc=doc)
+    register_descriptor(class_, key, proxy_property, 
+                            comparator, parententity, doc=doc)
     if not proxy_property:
         register_attribute_impl(class_, key, **kw)
     
@@ -1433,7 +1461,8 @@ def register_descriptor(class_, key, proxy_property=None, comparator=None,
         proxy_type = proxied_attribute_factory(proxy_property)
         descriptor = proxy_type(key, proxy_property, comparator, parententity)
     else:
-        descriptor = InstrumentedAttribute(key, comparator=comparator, parententity=parententity)
+        descriptor = InstrumentedAttribute(key, comparator=comparator,
+                                            parententity=parententity)
     
     descriptor.__doc__ = doc
         
@@ -1452,7 +1481,8 @@ def init_collection(obj, key):
         for elem in values:
             collection_adapter.append_without_event(elem)
     
-    For an easier way to do the above, see :func:`~sqlalchemy.orm.attributes.set_committed_value`.
+    For an easier way to do the above, see
+     :func:`~sqlalchemy.orm.attributes.set_committed_value`.
     
     obj is an instrumented object instance.  An InstanceState
     is accepted directly for backwards compatibility but 
@@ -1528,14 +1558,15 @@ def del_attribute(instance, key):
     state.get_impl(key).delete(state, dict_)
 
 def is_instrumented(instance, key):
-    """Return True if the given attribute on the given instance is instrumented
-    by the attributes package.
+    """Return True if the given attribute on the given instance is
+    instrumented by the attributes package.
     
     This function may be used regardless of instrumentation
     applied directly to the class, i.e. no descriptors are required.
     
     """
-    return manager_of_class(instance.__class__).is_instrumented(key, search=True)
+    return manager_of_class(instance.__class__).\
+                        is_instrumented(key, search=True)
 
 class InstrumentationRegistry(object):
     """Private instrumentation registration singleton.
@@ -1590,11 +1621,12 @@ class InstrumentationRegistry(object):
         return manager
 
     def _collect_management_factories_for(self, cls):
-        """Return a collection of factories in play or specified for a hierarchy.
+        """Return a collection of factories in play or specified for a
+        hierarchy.
 
-        Traverses the entire inheritance graph of a cls and returns a collection
-        of instrumentation factories for those classes.  Factories are extracted
-        from active ClassManagers, if available, otherwise
+        Traverses the entire inheritance graph of a cls and returns a
+        collection of instrumentation factories for those classes. Factories
+        are extracted from active ClassManagers, if available, otherwise
         instrumentation_finders is consulted.
 
         """
@@ -1616,7 +1648,8 @@ class InstrumentationRegistry(object):
         return factories
 
     def manager_of_class(self, cls):
-        # this is only called when alternate instrumentation has been established
+        # this is only called when alternate instrumentation 
+        # has been established
         if cls is None:
             return None
         try:
@@ -1627,22 +1660,26 @@ class InstrumentationRegistry(object):
             return finder(cls)
 
     def state_of(self, instance):
-        # this is only called when alternate instrumentation has been established
+        # this is only called when alternate instrumentation 
+        # has been established
         if instance is None:
             raise AttributeError("None has no persistent state.")
         try:
             return self._state_finders[instance.__class__](instance)
         except KeyError:
-            raise AttributeError("%r is not instrumented" % instance.__class__)
+            raise AttributeError("%r is not instrumented" %
+                                    instance.__class__)
 
     def dict_of(self, instance):
-        # this is only called when alternate instrumentation has been established
+        # this is only called when alternate instrumentation 
+        # has been established
         if instance is None:
             raise AttributeError("None has no persistent state.")
         try:
             return self._dict_finders[instance.__class__](instance)
         except KeyError:
-            raise AttributeError("%r is not instrumented" % instance.__class__)
+            raise AttributeError("%r is not instrumented" %
+                                    instance.__class__)
         
     def unregister(self, class_):
         if class_ in self._manager_finders:
index b3588ae59da8c48db535b5cb8a7c7e4ec893b1f5..0afe622c14dec006ea67a35c1701148c0262f358 100644 (file)
@@ -112,10 +112,12 @@ class Query(object):
         for ent in entities:
             for entity in ent.entities:
                 if entity not in d:
-                    mapper, selectable, is_aliased_class = _entity_info(entity)
+                    mapper, selectable, is_aliased_class = \
+                                            _entity_info(entity)
                     if not is_aliased_class and mapper.with_polymorphic:
                         with_polymorphic = mapper._with_polymorphic_mappers
-                        if mapper.mapped_table not in self._polymorphic_adapters:
+                        if mapper.mapped_table not in \
+                                            self._polymorphic_adapters:
                             self.__mapper_loads_polymorphically_with(mapper, 
                                 sql_util.ColumnAdapter(
                                             selectable, 
@@ -1115,9 +1117,9 @@ class Query(object):
                 isinstance(keys[1], expression.ClauseElement) and \
                 not isinstance(keys[1], expression.FromClause):
             raise sa_exc.ArgumentError(
-                        "You appear to be passing a clause expression as the second "
-                        "argument to query.join().   Did you mean to use the form "
-                        "query.join((target, onclause))?  Note the tuple.")
+                "You appear to be passing a clause expression as the second "
+                "argument to query.join().   Did you mean to use the form "
+                "query.join((target, onclause))?  Note the tuple.")
             
         for arg1 in util.to_list(keys):
             if isinstance(arg1, tuple):
@@ -1567,7 +1569,49 @@ class Query(object):
                         querycontext.statement, params=self._params,
                         mapper=self._mapper_zero_or_none())
         return self.instances(result, querycontext)
-
+    
+    @property
+    def column_descriptions(self):
+        """Return metadata about the columns which would be 
+        returned by this :class:`Query`.
+        
+        Format is a list of dictionaries::
+            
+            user_alias = aliased(User, name='user2')
+            q = sess.query(User, User.id, user_alias)
+            
+            # this expression:
+            q.columns
+            
+            # would return:
+            [
+                {
+                    'name':'User',
+                    'type':User,
+                    'aliased':False,
+                },
+                {
+                    'name':'id',
+                    'type':Integer(),
+                    'aliased':False
+                },
+                {
+                    'name':'user2',
+                    'type':User,
+                    'aliased':True
+                }
+            ]
+            
+        """
+        return [
+            {
+                'name':ent._label_name,
+                'type':ent.type,
+                'aliased':getattr(ent, 'is_aliased_class', False),
+            }
+            for ent in self._entities
+        ]
+        
     def instances(self, cursor, __context=None):
         """Given a ResultProxy cursor as returned by connection.execute(),
         return an ORM result as an iterator.
@@ -2090,7 +2134,8 @@ class Query(object):
                 value_evaluators = {}
                 for key,value in values.iteritems():
                     key = expression._column_as_key(key)
-                    value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value))
+                    value_evaluators[key] = evaluator_compiler.process(
+                                        expression._literal_as_binds(value))
             except evaluator.UnevaluatableError:
                 raise sa_exc.InvalidRequestError(
                         "Could not evaluate current criteria in Python. "
@@ -2372,10 +2417,12 @@ class _MapperEntity(_QueryEntity):
         self.is_aliased_class = is_aliased_class
         if is_aliased_class:
             self.path_entity = self.entity = self.entity_zero = entity
+            self._label_name = self.entity._sa_label_name
         else:
             self.path_entity = mapper
             self.entity = self.entity_zero = mapper
-
+            self._label_name = self.mapper.class_.__name__
+            
     def set_with_polymorphic(self, query, cls_or_mappers, 
                                 selectable, discriminator):
         if cls_or_mappers is None:
@@ -2393,6 +2440,10 @@ class _MapperEntity(_QueryEntity):
             self.selectable = from_obj
             self.adapter = query._get_polymorphic_adapter(self, from_obj)
 
+    @property
+    def type(self):
+        return self.mapper.class_
+
     def corresponds_to(self, entity):
         if _is_aliased_class(entity) or self.is_aliased_class:
             return entity is self.path_entity
@@ -2456,13 +2507,8 @@ class _MapperEntity(_QueryEntity):
                                 polymorphic_discriminator=
                                     self._polymorphic_discriminator)
 
-        if self.is_aliased_class:
-            entname = self.entity._sa_label_name
-        else:
-            entname = self.mapper.class_.__name__
-        
-        return _instance, entname
-
+        return _instance, self._label_name
+    
     def setup_context(self, query, context):
         adapter = self._get_entity_clauses(query, context)
 
@@ -2509,12 +2555,12 @@ class _ColumnEntity(_QueryEntity):
     def __init__(self, query, column):
         if isinstance(column, basestring):
             column = sql.literal_column(column)
-            self._result_label = column.name
+            self._label_name = column.name
         elif isinstance(column, attributes.QueryableAttribute):
-            self._result_label = column.key
+            self._label_name = column.key
             column = column.__clause_element__()
         else:
-            self._result_label = getattr(column, 'key', None)
+            self._label_name = getattr(column, 'key', None)
 
         if not isinstance(column, expression.ColumnElement) and \
                             hasattr(column, '_select_iterable'):
@@ -2565,6 +2611,10 @@ class _ColumnEntity(_QueryEntity):
         else:
             self.entity_zero = None
     
+    @property
+    def type(self):
+        return self.column.type
+        
     def adapt_to_selectable(self, query, sel):
         _ColumnEntity(query, sel.corresponding_column(self.column))
         
@@ -2595,7 +2645,7 @@ class _ColumnEntity(_QueryEntity):
         def proc(row, result):
             return row[column]
 
-        return (proc, self._result_label)
+        return proc, self._label_name
 
     def setup_context(self, query, context):
         column = self._resolve_expr_against_query_aliases(
index a2e3c54331bf4bced68ea898e84caa9782522304..ef5413724b22a3b68ad19e0372b33132e6d4cc6d 100644 (file)
@@ -1,5 +1,6 @@
 # mapper/util.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
+# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer
+# mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -7,8 +8,9 @@
 import sqlalchemy.exceptions as sa_exc
 from sqlalchemy import sql, util
 from sqlalchemy.sql import expression, util as sql_util, operators
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, \
-                                        MapperProperty, AttributeExtension
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\
+                                PropComparator, MapperProperty,\
+                                AttributeExtension
 from sqlalchemy.orm import attributes, exc
 
 mapperlib = None
@@ -95,7 +97,8 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
     for key in table_map.keys():
         table = table_map[key]
 
-        # mysql doesnt like selecting from a select; make it an alias of the select
+        # mysql doesnt like selecting from a select; 
+        # make it an alias of the select
         if isinstance(table, sql.Select):
             table = table.alias()
             table_map[key] = table
@@ -116,10 +119,10 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
     result = []
     for type, table in table_map.iteritems():
         if typecolname is not None:
-            result.append(sql.select([col(name, table) for name in colnames] +
-                             [sql.literal_column(
-                                sql_util._quote_ddl_expr(type)).label(typecolname)
-                              ],
+            result.append(
+                    sql.select([col(name, table) for name in colnames] +
+                    [sql.literal_column(sql_util._quote_ddl_expr(type)).
+                            label(typecolname)],
                              from_obj=[table]))
         else:
             result.append(sql.select([col(name, table) for name in colnames],
@@ -261,13 +264,16 @@ class ORMAdapter(sql_util.ColumnAdapter):
     and the AliasedClass if any is referenced.
 
     """
-    def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
+    def __init__(self, entity, equivalents=None, 
+                            chain_to=None, adapt_required=False):
         self.mapper, selectable, is_aliased_class = _entity_info(entity)
         if is_aliased_class:
             self.aliased_class = entity
         else:
             self.aliased_class = None
-        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required)
+        sql_util.ColumnAdapter.__init__(self, selectable, 
+                                        equivalents, chain_to,
+                                        adapt_required=adapt_required)
 
     def replace(self, elem):
         entity = elem._annotations.get('parentmapper', None)
@@ -298,7 +304,8 @@ class AliasedClass(object):
         self.__target = self.__mapper.class_
         if alias is None:
             alias = self.__mapper._with_polymorphic_selectable.alias()
-        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
+        self.__adapter = sql_util.ClauseAdapter(alias,
+                                equivalents=self.__mapper._equivalent_columns)
         self.__alias = alias
         # used to assign a name to the RowTuple object
         # returned by Query.
@@ -306,20 +313,29 @@ class AliasedClass(object):
         self.__name__ = 'AliasedClass_' + str(self.__target)
 
     def __getstate__(self):
-        return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}
+        return {
+            'mapper':self.__mapper, 
+            'alias':self.__alias, 
+            'name':self._sa_label_name
+        }
 
     def __setstate__(self, state):
         self.__mapper = state['mapper']
         self.__target = self.__mapper.class_
         alias = state['alias']
-        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
+        self.__adapter = sql_util.ClauseAdapter(alias,
+                                equivalents=self.__mapper._equivalent_columns)
         self.__alias = alias
         name = state['name']
         self._sa_label_name = name
         self.__name__ = 'AliasedClass_' + str(self.__target)
 
     def __adapt_element(self, elem):
-        return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})
+        return self.__adapter.traverse(elem).\
+                    _annotate({
+                        'parententity': self, 
+                        'parentmapper':self.__mapper}
+                    )
 
     def __adapt_prop(self, prop):
         existing = getattr(self.__target, prop.key)
@@ -361,7 +377,8 @@ class AliasedClass(object):
             id(self), self.__target.__name__)
 
 def _orm_annotate(element, exclude=None):
-    """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.
+    """Deep copy the given ClauseElement, annotating each element with the
+    "_orm_adapt" flag.
 
     Elements within the exclude collection will be cloned but not annotated.
 
@@ -375,7 +392,8 @@ class _ORMJoin(expression.Join):
 
     __visit_name__ = expression.Join.__visit_name__
 
-    def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
+    def __init__(self, left, right, onclause=None, 
+                            isouter=False, join_to_left=True):
         adapt_from = None
 
         if hasattr(left, '_orm_mappers'):
@@ -408,7 +426,8 @@ class _ORMJoin(expression.Join):
                 prop = None
 
             if prop:
-                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
+                pj, sj, source, dest, \
+                secondary, target_adapter = prop._create_joins(
                                 source_selectable=adapt_from,
                                 dest_selectable=adapt_to,
                                 source_polymorphic=True,
@@ -451,10 +470,10 @@ def outerjoin(left, right, onclause=None, join_to_left=True):
     """Produce a left outer join between left and right clauses.
 
     In addition to the interface provided by
-    :func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be mapped
-    classes or AliasedClass instances. The onclause may be a
-    string name of a relationship(), or a class-bound descriptor
-    representing a relationship.
+    :func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be
+    mapped classes or AliasedClass instances. The onclause may be a string
+    name of a relationship(), or a class-bound descriptor representing a
+    relationship.
 
     """
     return _ORMJoin(left, right, onclause, True, join_to_left)
@@ -462,16 +481,15 @@ def outerjoin(left, right, onclause=None, join_to_left=True):
 def with_parent(instance, prop):
     """Return criterion which selects instances with a given parent.
 
-    instance
-      a parent instance, which should be persistent or detached.
+    :param instance: a parent instance, which should be persistent 
+      or detached.
 
-    property
-      a class-attached descriptor, MapperProperty or string property name
+    :param property: a class-attached descriptor, MapperProperty or 
+      string property name
       attached to the parent instance.
 
-    \**kwargs
-      all extra keyword arguments are propagated to the constructor of
-      Query.
+    :param \**kwargs: all extra keyword arguments are propagated 
+      to the constructor of Query.
 
     """
     if isinstance(prop, basestring):
@@ -529,21 +547,30 @@ def _entity_descriptor(entity, key):
             desc = getattr(entity, key)
             return desc, desc.property
         except AttributeError:
-            raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
+            raise sa_exc.InvalidRequestError(
+                        "Entity '%s' has no property '%s'" % 
+                        (entity, key)
+                    )
             
     elif isinstance(entity, type):
         try:
             desc = attributes.manager_of_class(entity)[key]
             return desc, desc.property
         except KeyError:
-            raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
+            raise sa_exc.InvalidRequestError(
+                        "Entity '%s' has no property '%s'" % 
+                        (entity, key)
+                    )
             
     else:
         try:
             desc = entity.class_manager[key]
             return desc, desc.property
         except KeyError:
-            raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
+            raise sa_exc.InvalidRequestError(
+                        "Entity '%s' has no property '%s'" % 
+                        (entity, key)
+                    )
 
 def _orm_columns(entity):
     mapper, selectable, is_aliased_class = _entity_info(entity)
@@ -563,7 +590,8 @@ def _state_mapper(state):
     return state.manager.mapper
 
 def object_mapper(instance):
-    """Given an object, return the primary Mapper associated with the object instance.
+    """Given an object, return the primary Mapper associated with the object
+    instance.
 
     Raises UnmappedInstanceError if no mapping is configured.
 
index e8289e08c834cf7afa95b0f46265d5b469dbb32d..2933d1bc4aaf779fd9ac60a7f99d93b8f69ea32e 100644 (file)
@@ -13,7 +13,8 @@ import sqlalchemy as sa
 from sqlalchemy.test import testing, AssertsCompiledSQL, Column, engines
 
 from test.orm import _fixtures
-from test.orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \
+from test.orm._fixtures import keywords, addresses, Base, \
+            Keyword, FixtureTest, \
            Dingaling, item_keywords, dingalings, User, items,\
            orders, Address, users, nodes, \
             order_items, Item, Order, Node, \
@@ -66,10 +67,60 @@ class RowTupleTest(QueryTest):
             'uname':users.c.name
         })
         
-        row  = create_session().query(User.id, User.uname).filter(User.id==7).first()
+        row  = create_session().\
+                    query(User.id, User.uname).\
+                    filter(User.id==7).first()
         assert row.id == 7
         assert row.uname == 'jack'
 
+    def test_column_metadata(self):
+        mapper(User, users)
+        mapper(Address, addresses)
+        sess = create_session()
+        user_alias = aliased(User)
+        address_alias = aliased(Address, name='aalias')
+        fn = func.count(User.id)
+        
+        for q, asserted in [
+            (
+                sess.query(User),
+                [{'name':'User', 'type':User, 'aliased':False}]
+            ),
+            (
+                sess.query(User.id, User),
+                [
+                    {'name':'id', 'type':users.c.id.type, 'aliased':False},
+                    {'name':'User', 'type':User, 'aliased':False}
+                ]
+            ),
+            (
+                sess.query(User.id, user_alias),
+                [
+                    {'name':'id', 'type':users.c.id.type, 'aliased':False},
+                    {'name':None, 'type':User, 'aliased':True}
+                ]
+            ),
+            (
+                sess.query(address_alias),
+                [
+                    {'name':'aalias', 'type':Address, 'aliased':True}
+                ]
+            ),
+            (
+                sess.query(User.name.label('uname'), fn),
+                [
+                    {'name':'uname', 'type':users.c.name.type,
+                                            'aliased':False},
+                    {'name':None, 'type':fn.type, 'aliased':False},
+                ]
+            )
+        ]:
+            eq_(
+                q.column_descriptions,
+                asserted
+            )
+        
+        
 class GetTest(QueryTest):
     def test_get(self):
         s = create_session()