]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Custom comparator classes used in conjunction with
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Jan 2009 18:22:50 +0000 (18:22 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Jan 2009 18:22:50 +0000 (18:22 +0000)
column_property(), relation() etc. can define
new comparison methods on the Comparator, which will
become available via __getattr__() on the
InstrumentedAttribute.   In the case of synonym()
or comparable_property(), attributes are resolved first
on the user-defined descriptor, then on the user-defined
comparator.

CHANGES
examples/postgis/postgis.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/properties.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 76a12e05fa6f2f7bb838fb04a5aa8c7aef91e06a..f2e007d0aee4de3470603f181526451980098b91 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -71,6 +71,15 @@ CHANGES
       next compile() call.  This issue occurs frequently
       when using declarative.
 
+    - Custom comparator classes used in conjunction with 
+      column_property(), relation() etc. can define 
+      new comparison methods on the Comparator, which will
+      become available via __getattr__() on the 
+      InstrumentedAttribute.   In the case of synonym()
+      or comparable_property(), attributes are resolved first
+      on the user-defined descriptor, then on the user-defined
+      comparator.
+      
     - Added ScopedSession.is_active accessor. [ticket:976]
     
     - Can pass mapped attributes and column objects as keys
index 841bce31cddd8b86e311fdb4d11fda5c9650bb27..c463cca26eb07937da95a82972e8570a6945d116 100644 (file)
@@ -123,21 +123,14 @@ class GisComparator(ColumnProperty.ColumnComparator):
     """Intercepts standard Column operators on mapped class attributes
     and overrides their behavior.
     
-    The PropComparator API currently does not allow "custom"
-    operators to be added, so only those operators which
-    already exist on Column can be overridden here.  Additional
-    GIS-specific operators can be implemented as standalone 
-    functions.
     
     """
     
     def __eq__(self, other):
         return self.__clause_element__().op('~=')(_to_postgis(other))
 
-def intersects(x, y):
-    """An example standalone GIS-specific comparison operator."""
-    
-    return _to_postgis(x).op('&&')(_to_postgis(y))
+    def intersects(self, other):
+        return self.__clause_element__().op('&&')(_to_postgis(other))
     
 class gis_element(object):
     """Represents a geometry value.
@@ -219,7 +212,7 @@ if __name__ == '__main__':
     assert r1 is r2 is r3
 
     # illustrate the "intersects" operator
-    print session.query(Road).filter(intersects(Road.road_geom, r1.road_geom)).all()
+    print session.query(Road).filter(Road.road_geom.intersects(r1.road_geom)).all()
 
     # illustrate usage of the "wkt" accessor. this requires a DB
     # execution to call the AsText() function so we keep this explicit.
index df607adf877ad22c5f54962c9c41b454b19c0607..2b2760208ac2f20390971a9bd257925ecd1d166c 100644 (file)
@@ -131,7 +131,17 @@ class QueryableAttribute(interfaces.PropComparator):
 
     def hasparent(self, state, optimistic=False):
         return self.impl.hasparent(state, optimistic=optimistic)
-
+    
+    def __getattr__(self, key):
+        try:
+            return getattr(self.comparator, key)
+        except AttributeError:
+            raise AttributeError('Neither %r object nor %r object has an attribute %r' % (
+                    type(self).__name__, 
+                    type(self.comparator).__name__, 
+                    key)
+            )
+        
     def __str__(self):
         return repr(self.parententity) + "." + self.property.key
 
@@ -195,8 +205,19 @@ def proxied_attribute_factory(descriptor):
             return descriptor.__delete__(instance)
 
         def __getattr__(self, attribute):
-            """Delegate __getattr__ to the original descriptor."""
-            return getattr(descriptor, attribute)
+            """Delegate __getattr__ to the original descriptor and/or comparator."""
+            
+            try:
+                return getattr(descriptor, attribute)
+            except AttributeError:
+                try:
+                    return getattr(self._comparator, attribute)
+                except AttributeError:
+                    raise AttributeError('Neither %r object nor %r object has an attribute %r' % (
+                            type(descriptor).__name__, 
+                            type(self._comparator).__name__, 
+                            attribute)
+                    )
 
         def _property(self):
             return self._parententity.get_property(self.key, resolve_synonyms=True)
index bf9bda366f12f6591e891efb5fc9a3b3595b375d..675b505e78d2e0c99935f7f722f3a4307717571c 100644 (file)
@@ -42,6 +42,7 @@ class ColumnProperty(StrategizedProperty):
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
         self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator)
+        self.descriptor = kwargs.pop('descriptor', None)
         self.extension = kwargs.pop('extension', None)
         util.set_creation_order(self)
         if self.deferred:
@@ -206,6 +207,7 @@ class SynonymProperty(MapperProperty):
                     if obj is None:
                         return s
                     return getattr(obj, self.name)
+
             self.descriptor = SynonymProp()
 
         def comparator_callable(prop, mapper):
index 4e8412bd93697b963515c4ff6099428e1cd77cb9..5cab84175e13d04d33b6b8768fcc4d92c28503ae 100644 (file)
@@ -717,11 +717,22 @@ class MapperTest(_fixtures.FixtureTest):
     def test_comparable(self):
         class extendedproperty(property):
             attribute = 123
+            
+            def method1(self):
+                return "method1"
+            
             def __getitem__(self, key):
                 return 'value'
 
         class UCComparator(sa.orm.PropComparator):
             __hash__ = None
+            
+            def method1(self):
+                return "uccmethod1"
+                
+            def method2(self, other):
+                return "method2"
+                
             def __eq__(self, other):
                 cls = self.prop.parent.class_
                 col = getattr(cls, 'name')
@@ -754,6 +765,14 @@ class MapperTest(_fixtures.FixtureTest):
             assert hasattr(User, 'name')
             assert hasattr(User, 'uc_name')
 
+            eq_(User.uc_name.method1(), "method1")
+            eq_(User.uc_name.method2('x'), "method2")
+
+            self.assertRaisesMessage(
+                AttributeError, 
+                "Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'", 
+                getattr, User.uc_name, 'nonexistent')
+            
             # test compile
             assert not isinstance(User.uc_name == 'jack', bool)
             u = q.filter(User.uc_name=='JACK').one()
@@ -779,6 +798,30 @@ class MapperTest(_fixtures.FixtureTest):
             eq_(User.uc_name['key'], 'value')
             sess.rollback()
 
+    @testing.resolve_artifact_names
+    def test_comparable_column(self):
+        class MyComparator(sa.orm.properties.ColumnProperty.Comparator):
+            def __eq__(self, other):
+                # lower case comparison
+                return func.lower(self.__clause_element__()) == func.lower(other)
+                
+            def intersects(self, other):
+                # non-standard comparator
+                return self.__clause_element__().op('&=')(other)
+                
+        mapper(User, users, properties={
+            'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator)
+        })
+        
+        self.assertRaisesMessage(
+            AttributeError, 
+            "Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'", 
+            getattr, User.name, "nonexistent")
+
+        eq_(str((User.name == 'ed').compile(dialect=sa.engine.default.DefaultDialect())) , "lower(users.name) = lower(:lower_1)")
+        eq_(str((User.name.intersects('ed')).compile(dialect=sa.engine.default.DefaultDialect())), "users.name &= :name_1")
+        
+
     @testing.resolve_artifact_names
     def test_reconstructor(self):
         recon = []