]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- at long last have gotten the "proxy_property" keyword
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 6 Aug 2010 22:53:22 +0000 (18:53 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 6 Aug 2010 22:53:22 +0000 (18:53 -0400)
arg of register_descriptor to not be needed.   synonym, comparable, concreteinherited
props now supply a descriptor directly in the class dict, whose
__get__(None, cls) supplies a QueryableAttribute.   The basic idea is that
the hybrid prop can be used for this.   Payoff here is arguable, except that
hybrid can be at the base of future synonym/comparable operations.

doc/build/mappers.rst
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py
test/orm/test_mapper.py

index 97677aa080ebd5b7d593627f1f2d06169b6e8aa9..98dd3997a034e1a11c1827a75d4d555cb34986ff 100644 (file)
@@ -295,28 +295,39 @@ through the ``email`` descriptor and into the ``_email``
 mapped attribute, the class level ``EmailAddress.email``
 attribute does not have the usual expression semantics
 usable with :class:`.Query`. To provide
-these, we instead use the :func:`.synonym`
-function as follows::
+these, we instead use the :func:`.hybrid.property`
+decorator as follows::
 
-    mapper(EmailAddress, addresses_table, properties={
-        'email': synonym('_email', map_column=True)
-    })
+    from sqlalchemy.ext import hybrid
+
+    class EmailAddress(object):
+        
+        @hybrid.property
+        def email(self):
+            return self._email
+            
+        @email.setter
+        def email(self, email):
+            self._email = email
+    
+        @email.expression
+        def email(cls):
+            return cls._email
 
 The ``email`` attribute is now usable in the same way as any
 other mapped attribute, including filter expressions,
 get/set operations, etc.::
 
-    address = session.query(EmailAddress).filter(EmailAddress.email == 'some address').one()
+    address = session.query(EmailAddress).\\
+                filter(EmailAddress.email == 'some address').\\
+                one()
 
     address.email = 'some other address'
-    session.flush()
-
-    q = session.query(EmailAddress).filter_by(email='some other address')
+    session.commit()
 
-If the mapped class does not provide a property, the :func:`.synonym` construct will create a default getter/setter object automatically.
+    q = session.query(EmailAddress).\
+                filter_by(email='some other address')
 
-To use synonyms with :mod:`~sqlalchemy.ext.declarative`, see the section 
-:ref:`declarative_synonyms`.
 
 .. _custom_comparators:
 
index 2432ec9d749afb8ca15952d1ff943e346f4d9e43..3134db12deb884d364680a22a2ae427ea5cecef6 100644 (file)
@@ -62,11 +62,12 @@ or as the class itself::
     ### TODO ADD EXAMPLES HERE AND STUFF THIS ISN'T FINISHED ###
     
 """
+from sqlalchemy import util
 
 class method(object):
     def __init__(self, func, expr=None):
         self.func = func
-        self.expr = expr or fund
+        self.expr = expr or func
         
     def __get__(self, instance, owner):
         if instance is None:
@@ -78,16 +79,17 @@ class method(object):
         self.expr = expr
         return self
 
-class property(object):
+class property_(object):
     def __init__(self, fget, fset=None, fdel=None, expr=None):
         self.fget = fget
         self.fset = fset
         self.fdel = fdel
         self.expr = expr or fget
-        
+        util.update_wrapper(self, fget)
+
     def __get__(self, instance, owner):
         if instance is None:
-            return self.expr(owner)
+            return util.update_wrapper(self.expr(owner), self)
         else:
             return self.fget(instance)
             
@@ -109,3 +111,4 @@ class property(object):
         self.expr = expr
         return self
     
+
index f91ff51f2f7b489e34020dd87f02d6e0dcdc1b43..c21585cea631f0a8643fc527373026dc2504f0e2 100644 (file)
@@ -176,23 +176,30 @@ def proxied_attribute_factory(descriptor):
     behavior and getattr() to the given descriptor.
     """
 
-    class Proxy(InstrumentedAttribute):
+    class Proxy(QueryableAttribute):
         """A combination of InsturmentedAttribute and a regular descriptor."""
 
-        def __init__(self, key, descriptor, comparator, parententity):
+        def __init__(self, key, descriptor, comparator, adapter=None):
             self.key = key
             # maintain ProxiedAttribute.user_prop compatability.
             self.descriptor = self.user_prop = descriptor
             self._comparator = comparator
-            self._parententity = parententity
             self.impl = _ProxyImpl(key)
-
+            self.adapter = adapter
+            
         @util.memoized_property
         def comparator(self):
             if util.callable(self._comparator):
                 self._comparator = self._comparator()
+            if self.adapter:
+                self._comparator = self._comparator.adapted(self.adapter)
             return self._comparator
-
+        
+        def adapted(self, adapter):
+            return self.__class__(self.key, self.descriptor,
+                                       self._comparator,
+                                       adapter)
+        
         def __get__(self, instance, owner):
             """Delegate __get__ to the original descriptor."""
             if instance is None:
@@ -216,7 +223,7 @@ def proxied_attribute_factory(descriptor):
                 return getattr(descriptor, attribute)
             except AttributeError:
                 try:
-                    return getattr(self._comparator, attribute)
+                    return getattr(self.comparator, attribute)
                 except AttributeError:
                     raise AttributeError(
                     'Neither %r object nor %r object has an attribute %r' % (
@@ -1462,6 +1469,7 @@ def register_descriptor(class_, key, proxy_property=None, comparator=None,
     manager = manager_of_class(class_)
 
     if proxy_property:
+        raise NotImplementedError()
         proxy_type = proxied_attribute_factory(proxy_property)
         descriptor = proxy_type(key, proxy_property, comparator, parententity)
     else:
index 54674a7cc12aba2a0de17fba7f1d9d60c2336166..4e7702698f053e37bc268baa53f713a253d16e19 100644 (file)
@@ -1074,10 +1074,16 @@ class Mapper(object):
         return result
 
     def _is_userland_descriptor(self, obj):
+#        return not isinstance(obj, 
+#                    (MapperProperty, attributes.InstrumentedAttribute)) and \
+#                    hasattr(obj, '__get__')
+                    
         return not isinstance(obj, 
                     (MapperProperty, attributes.InstrumentedAttribute)) and \
-                    hasattr(obj, '__get__')
-
+                    hasattr(obj, '__get__') and not \
+                     isinstance(obj.__get__(None, obj),
+                                    attributes.QueryableAttribute)
+        
     def _should_exclude(self, name, assigned_name, local):
         """determine whether a particular property should be implicitly
         present on the class.
index 9cc03833a25dee63d01de0fea8d0699053ad00dc..50d95b77714eebffc907f59ac627fa31fe08c8a9 100644 (file)
@@ -237,13 +237,87 @@ class CompositeProperty(ColumnProperty):
     def __str__(self):
         return str(self.parent.class_.__name__) + "." + self.key
 
-class ConcreteInheritedProperty(MapperProperty):
+
+class DescriptorProperty(MapperProperty):
+    """:class:`MapperProperty` which proxies access to a 
+        user-defined descriptor."""
+
+    def set_parent(self, parent, init):
+        if self.descriptor is None:
+            desc = getattr(parent.class_, self.key, None)
+            if parent._is_userland_descriptor(desc):
+                self.descriptor = desc
+        self.parent = parent
+    
+    def instrument_class(self, mapper):
+        class_ = self.parent.class_
+        
+        from sqlalchemy.ext import hybrid
+
+        # hackety hack hack
+        class _ProxyImpl(object):
+            accepts_scalar_loader = False
+            expire_missing = True
+
+            def __init__(self, key):
+                self.key = key
+
+        if self.descriptor is None:
+            def fset(obj, value):
+                setattr(obj, self.name, value)
+            def fdel(obj):
+                delattr(obj, self.name)
+            def fget(obj):
+                return getattr(obj, self.name)
+            fget.__doc__ = self.doc
+
+            descriptor = hybrid.property_(
+                fget=fget,
+                fset=fset,
+                fdel=fdel,
+            )
+        elif isinstance(self.descriptor, property):
+            descriptor = hybrid.property_(
+                fget=self.descriptor.fget,
+                fset=self.descriptor.fset,
+                fdel=self.descriptor.fdel,
+            )
+        else:
+            descriptor = hybrid.property_(
+                fget=self.descriptor.__get__,
+                fset=self.descriptor.__set__,
+                fdel=self.descriptor.__delete__,
+            )
+
+        proxy_attr = attributes.\
+                        proxied_attribute_factory(self.descriptor 
+                                                    or descriptor)\
+                    (self.key, self.descriptor or descriptor,
+                        lambda: self._comparator_factory(mapper))
+        def get_comparator(owner):
+            return proxy_attr
+        descriptor.expr = get_comparator
+        
+        descriptor.impl = _ProxyImpl(self.key)
+        mapper.class_manager.instrument_attribute(self.key, descriptor)
+
+    def setup(self, context, entity, path, adapter, **kwargs):
+        pass
+
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        return (None, None)
+
+    def merge(self, session, source_state, source_dict, 
+                dest_state, dest_dict, load, _recursive):
+        pass
+    
+class ConcreteInheritedProperty(DescriptorProperty):
     """A 'do nothing' :class:`MapperProperty` that disables 
     an attribute on a concrete subclass that is only present
     on the inherited mapper, not the concrete classes' mapper.
-    
+
     Cases where this occurs include:
-    
+
     * When the superclass mapper is mapped against a 
       "polymorphic union", which includes all attributes from 
       all subclasses.
@@ -251,22 +325,22 @@ class ConcreteInheritedProperty(MapperProperty):
       but not on the subclass mapper.  Concrete mappers require
       that relationship() is configured explicitly on each 
       subclass. 
-    
-    """
-    
-    extension = None
 
-    def setup(self, context, entity, path, adapter, **kwargs):
-        pass
-
-    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
-        return (None, None)
+    """
 
-    def merge(self, session, source_state, source_dict, dest_state,
-                dest_dict, load, _recursive):
-        pass
+    extension = None
+    
+    def _comparator_factory(self, mapper):
+        comparator_callable = None
         
-    def instrument_class(self, mapper):
+        for m in self.parent.iterate_to_root():
+            p = m._props[self.key]
+            if not isinstance(p, ConcreteInheritedProperty):
+                comparator_callable = p.comparator_factory
+                break
+        return comparator_callable
+    
+    def __init__(self):
         def warn():
             raise AttributeError("Concrete %s does not implement "
                 "attribute %r at the instance level.  Add this "
@@ -279,27 +353,13 @@ class ConcreteInheritedProperty(MapperProperty):
             def __delete__(s, obj):
                 warn()
             def __get__(s, obj, owner):
+                if obj is None:
+                    return self.descriptor
                 warn()
-
-        comparator_callable = None
-        # TODO: put this process into a deferred callable?
-        for m in self.parent.iterate_to_root():
-            p = m._props[self.key]
-            if not isinstance(p, ConcreteInheritedProperty):
-                comparator_callable = p.comparator_factory
-                break
-
-        attributes.register_descriptor(
-            mapper.class_, 
-            self.key, 
-            comparator=comparator_callable(self, mapper), 
-            parententity=mapper,
-            property_=self,
-            proxy_property=NoninheritedConcreteProp()
-            )
-
-
-class SynonymProperty(MapperProperty):
+        self.descriptor = NoninheritedConcreteProp()
+        
+        
+class SynonymProperty(DescriptorProperty):
 
     extension = None
 
@@ -313,12 +373,16 @@ class SynonymProperty(MapperProperty):
         self.doc = doc or (descriptor and descriptor.__doc__) or None
         util.set_creation_order(self)
 
-    def setup(self, context, entity, path, adapter, **kwargs):
-        pass
+    def _comparator_factory(self, mapper):
+        class_ = self.parent.class_
+        prop = getattr(class_, self.name).property
+
+        if self.comparator_factory:
+            comp = self.comparator_factory(prop, mapper)
+        else:
+            comp = prop.comparator_factory(prop, mapper)
+        return comp
 
-    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
-        return (None, None)
-    
     def set_parent(self, parent, init):
         if self.descriptor is None:
             desc = getattr(parent.class_, self.key, None)
@@ -349,50 +413,8 @@ class SynonymProperty(MapperProperty):
             p._mapped_by_synonym = self.key
     
         self.parent = parent
-    
-    def instrument_class(self, mapper):
-        class_ = self.parent.class_
-
-        if self.descriptor is None:
-            class SynonymProp(object):
-                def __set__(s, obj, value):
-                    setattr(obj, self.name, value)
-                def __delete__(s, obj):
-                    delattr(obj, self.name)
-                def __get__(s, obj, owner):
-                    if obj is None:
-                        return s
-                    return getattr(obj, self.name)
-
-            self.descriptor = SynonymProp()
-
-        def comparator_callable(prop, mapper):
-            def comparator():
-                prop = getattr(self.parent.class_,
-                                        self.name).property
-                if self.comparator_factory:
-                    return self.comparator_factory(prop, mapper)
-                else:
-                    return prop.comparator_factory(prop, mapper)
-            return comparator
-
-        attributes.register_descriptor(
-            mapper.class_, 
-            self.key, 
-            comparator=comparator_callable(self, mapper), 
-            parententity=mapper,
-            property_=self,
-            proxy_property=self.descriptor,
-            doc=self.doc
-            )
-
-    def merge(self, session, source_state, source_dict, dest_state,
-                dest_dict, load, _recursive):
-        pass
         
-log.class_logger(SynonymProperty)
-
-class ComparableProperty(MapperProperty):
+class ComparableProperty(DescriptorProperty):
     """Instruments a Python property for use in query expressions."""
 
     extension = None
@@ -410,28 +432,10 @@ class ComparableProperty(MapperProperty):
                 self.descriptor = desc
         self.parent = parent
 
-    def instrument_class(self, mapper):
-        """Set up a proxy to the unmanaged descriptor."""
+    def _comparator_factory(self, mapper):
+        return self.comparator_factory(self, mapper)
 
-        attributes.register_descriptor(
-            mapper.class_, 
-            self.key, 
-            comparator=self.comparator_factory(self, mapper), 
-            parententity=mapper,
-            property_=self,
-            proxy_property=self.descriptor,
-            doc=self.doc,
-            )
-
-    def setup(self, context, entity, path, adapter, **kwargs):
-        pass
-
-    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
-        return (None, None)
 
-    def merge(self, session, source_state, source_dict, 
-                dest_state, dest_dict, load, _recursive):
-        pass
 
 
 class RelationshipProperty(StrategizedProperty):
index 8d468babe32168803eed23fa579d0bccfc247903..c05dfe1e4423a2bee901893a5a54160ea33354da 100644 (file)
@@ -864,7 +864,7 @@ class Query(object):
         """apply the given filtering criterion to the query and return 
         the newly resulting ``Query``."""
 
-        clauses = [_entity_descriptor(self._joinpoint_zero(), key)[0] == value
+        clauses = [_entity_descriptor(self._joinpoint_zero(), key) == value
             for key, value in kwargs.iteritems()]
 
         return self.filter(sql.and_(*clauses))
@@ -1158,7 +1158,7 @@ class Query(object):
             if isinstance(onclause, basestring):
                 left_entity = self._joinpoint_zero()
 
-                descriptor, prop = _entity_descriptor(left_entity, onclause)
+                descriptor = _entity_descriptor(left_entity, onclause)
                 onclause = descriptor
             
             # check for q.join(Class.propname, from_joinpoint=True)
@@ -1171,7 +1171,7 @@ class Query(object):
                                     _entity_info(self._joinpoint_zero())
                 if left_mapper is left_entity:
                     left_entity = self._joinpoint_zero()
-                    descriptor, prop = _entity_descriptor(left_entity,
+                    descriptor = _entity_descriptor(left_entity,
                                                             onclause.key)
                     onclause = descriptor
 
@@ -2579,7 +2579,10 @@ class _ColumnEntity(_QueryEntity):
         if isinstance(column, basestring):
             column = sql.literal_column(column)
             self._label_name = column.name
-        elif isinstance(column, attributes.QueryableAttribute):
+        elif isinstance(column, (
+                                    attributes.QueryableAttribute,
+                                    interfaces.PropComparator
+                                )):
             self._label_name = column.key
             column = column.__clause_element__()
         else:
index c9004990ab7a07814420d99c4e0ab9bb3c8f3073..49f572572c04e23a435ca71f6f401d59096d1136 100644 (file)
@@ -359,15 +359,17 @@ class AliasedClass(object):
         
         if isinstance(attr, attributes.QueryableAttribute):
             return self.__adapt_prop(attr.property)
-            
-        if hasattr(attr, 'func_code'):
+        elif hasattr(attr, 'func_code'):
             is_method = getattr(self.__target, key, None)
             if is_method and is_method.im_self is not None:
                 return util.types.MethodType(attr.im_func, self, self)
             else:
                 return None
         elif hasattr(attr, '__get__'):
-            return attr.__get__(None, self)
+            ret = attr.__get__(None, self)
+            if isinstance(ret, PropComparator):
+                return ret.adapted(self.__adapt_element)
+            return ret
         else:
             return attr
 
@@ -536,40 +538,22 @@ def _entity_info(entity, compile=True):
     return mapper, mapper._with_polymorphic_selectable, False
 
 def _entity_descriptor(entity, key):
-    """Return attribute/property information given an entity and string name.
-
-    Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.
+    """Return a class attribute given an entity and string name.
+    
+    May return :class:`.InstrumentedAttribute` or user-defined
+    attribute.
 
     """
-    if isinstance(entity, AliasedClass):
-        try:
-            desc = getattr(entity, key)
-            return desc, desc.property
-        except AttributeError:
-            raise sa_exc.InvalidRequestError(
-                        "Entity '%s' has no property '%s'" % 
-                        (entity, key)
-                    )
-            
-    elif isinstance(entity, type):
-        try:
-            desc = attributes.manager_of_class(entity)[key]
-            return desc, desc.property
-        except KeyError:
-            raise sa_exc.InvalidRequestError(
-                        "Entity '%s' has no property '%s'" % 
-                        (entity, key)
-                    )
-            
-    else:
-        try:
-            desc = entity.class_manager[key]
-            return desc, desc.property
-        except KeyError:
-            raise sa_exc.InvalidRequestError(
-                        "Entity '%s' has no property '%s'" % 
-                        (entity, key)
-                    )
+    if not isinstance(entity, (AliasedClass, type)):
+        entity = entity.class_
+        
+    try:
+        return getattr(entity, key)
+    except AttributeError:
+        raise sa_exc.InvalidRequestError(
+                    "Entity '%s' has no property '%s'" % 
+                    (entity, key)
+                )
 
 def _orm_columns(entity):
     mapper, selectable, is_aliased_class = _entity_info(entity)
index e301b8d07ec47edfbc18928a2317bd7a7573253d..c9b86e8ee7b933b765ef596766efd17d79871ff3 100644 (file)
@@ -899,9 +899,10 @@ class MapperTest(_fixtures.FixtureTest):
                 args = (UCComparator, User.uc_name)
             else:
                 args = (UCComparator,)
-
             mapper(User, users, properties=dict(
                     uc_name = sa.orm.comparable_property(*args)))
+#            import pdb
+#            pdb.set_trace()
             return User
 
         for User in (map_(True), map_(False)):
@@ -1169,7 +1170,7 @@ class DocumentTest(testing.TestBase):
                                     backref=backref('foo',doc='foo relationship')
                                 ),
             'foober':column_property(t1.c.col3, doc='alternate data col'),
-            'hoho':synonym(t1.c.col4, doc="syn of col4")
+            'hoho':synonym("col4", doc="syn of col4")
         })
         mapper(Bar, t2)
         compile_mappers()
@@ -1554,7 +1555,7 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL):
                     User.name == 'ed', 
                     "foobar(users.name) = foobar(:foobar_1)",
                     dialect=default.DefaultDialect())
-
+        
         self.assert_compile(
                     aliased(User).name == 'ed', 
                     "foobar(users_1.name) = foobar(:foobar_1)",