]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added operator support to class-instrumented attributes. you can now
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Jul 2007 04:25:09 +0000 (04:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Jul 2007 04:25:09 +0000 (04:25 +0000)
filter() (or whatever) using <classname>.<attributename>==<whatever>.
for column based properties, all column operators work (i.e. ==, <, >,
like(), in_(), etc.).  For relation() and composite column properties,
==<instance>, !=<instance>, and ==<None> are implemented so far.
[ticket:643]

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql.py
test/orm/mapper.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index d777d45dfcc352f057ce28c3bd196327fb45a127..e2d6010bf3abb70aae47ca71296c9a73bca1c9bd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
           querying divergent criteria. ClauseElements at the front of
           filter_by() are removed (use filter()).
 
+    - added operator support to class-instrumented attributes. you can now
+      filter() (or whatever) using <classname>.<attributename>==<whatever>.
+      for column based properties, all column operators work (i.e. ==, <, >,
+      like(), in_(), etc.).  For relation() and composite column properties,
+      ==<instance>, !=<instance>, and ==<None> are implemented so far.
+      [ticket:643]
+      
     - added composite column properties. using the composite(cls, *columns)
       function inside of the "properties" dict, instances of cls will be
       created/mapped to a single attribute, comprised of the values
index 84d464e2047e66ad001212f7e6e883bdfa0b2ccf..ad9675f029be44d493fda6c4423746a2b9ab5cc4 100644 (file)
@@ -85,10 +85,10 @@ class InstrumentedAttribute(sql.Comparator):
         return self.comparator.compare_self()
         
     def operate(self, op, other):
-        return self.comparator.operate(op, other)
+        return op(self.comparator, other)
 
     def reverse_operate(self, op, other):
-        return self.comparator.reverse_operate(op, other)
+        return op(other, self.comparator)
         
     def hasparent(self, item, optimistic=False):
         """Return the boolean value of a `hasparent` flag attached to the given item.
index b0a2399f3b48adda49712154e26ec421a1cac5f4..c9c2a45b519f344b9599f1b45b1fb48679def639 100644 (file)
@@ -325,14 +325,17 @@ class MapperProperty(object):
 
         raise NotImplementedError()
 
-    def compare(self, value):
+    def compare(self, operator, value):
         """Return a compare operation for the columns represented by
         this ``MapperProperty`` to the given value, which may be a
-        column value or an instance.
+        column value or an instance.  'operator' is an operator from
+        the operators module, or from sql.Comparator.
+        
+        By default uses the PropComparator attached to this MapperProperty
+        under the attribute name "comparator".
         """
 
-        raise NotImplementedError()
-
+        return operator(self.comparator, value)
 
 class PropComparator(sql.Comparator):
     """defines comparison operations for MapperProperty objects"""
index a32354fcad96926b61d53b444766615b265d88ee..8a57b4a83ead56fa8ffc2e1acd3e39e9135695fe 100644 (file)
@@ -15,7 +15,7 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
-import sets, random
+import operator
 from sqlalchemy.orm.interfaces import *
 
 __all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef']
@@ -56,9 +56,6 @@ class ColumnProperty(StrategizedProperty):
     def merge(self, session, source, dest, _recursive):
         setattr(dest, self.key, getattr(source, self.key, None))
 
-    def compare(self, value, op='=='):
-        return self.comparator == value
-
     def get_col_value(self, column, value):
         return value
 
@@ -84,7 +81,7 @@ class CompositeProperty(ColumnProperty):
     def __init__(self, class_, *columns, **kwargs):
         super(CompositeProperty, self).__init__(*columns, **kwargs)
         self.composite_class = class_
-        self.comparator = None
+        self.comparator = CompositeProperty.Comparator(self)
         
     def copy(self):
         return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
@@ -101,19 +98,21 @@ class CompositeProperty(ColumnProperty):
             if a is column:
                 setattr(obj, b, value)
 
-    def compare(self, value, op='=='):
-        # TODO: build into operator framework
-        if op == '==':
-            return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())])
-        elif op == '!=':
-            return sql.or_([a!=b for a, b in zip(self.columns, value.__colset__())])
-
     def get_col_value(self, column, value):
         for a, b in zip(self.columns, value.__colset__()):
             if a is column:
                 return b
 
-        
+    class Comparator(PropComparator):
+        def __eq__(self, other):
+            if other is None:
+                return sql.and_(*[a==None for a in self.prop.columns])
+            else:
+                return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())])
+
+        def __ne__(self, other):
+            return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())])
+
 class PropertyLoader(StrategizedProperty):
     """Describes an object property that holds a single item or list
     of items that correspond to a related database table.
@@ -137,7 +136,7 @@ class PropertyLoader(StrategizedProperty):
         self.remote_side = util.to_set(remote_side)
         self.enable_typechecks = enable_typechecks
         self._parent_join_cache = {}
-        self.comparator = None
+        self.comparator = PropertyLoader.Comparator(self)
 
         if cascade is not None:
             self.cascade = mapperutil.CascadeOptions(cascade)
@@ -162,22 +161,40 @@ class PropertyLoader(StrategizedProperty):
             self.backref = backref
         self.is_backref = is_backref
 
-    def compare(self, value, value_is_parent=False, op='=='):
-        if op == '==':
-            # optimized operation for ==, uses a lazy clause.
-            (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
-            bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
-            class Visitor(sql.ClauseVisitor):
-                def visit_bindparam(s, bindparam):
-                    mapper = value_is_parent and self.parent or self.mapper
-                    bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
-            Visitor().traverse(criterion)
-            return criterion
+    class Comparator(PropComparator):
+        def __eq__(self, other):
+            if other is None:
+                return ~sql.exists([1], self.prop.primaryjoin)
+            else:
+                return self.prop._optimized_compare(other)
+        
+        def __ne__(self, other):
+            j = self.prop.primaryjoin
+            if self.prop.secondaryjoin:
+                j = j & self.prop.secondaryjoin
+            return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+            
+    def compare(self, op, value, value_is_parent=False):
+        if op == operator.eq:
+            if value is None:
+                return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
+            else:
+                return self._optimized_compare(value, value_is_parent=value_is_parent)
         else:
-            # TODO: build expressions like these into operator framework
-            return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
-
+            return op(self.comparator, value)
+    
+    def _optimized_compare(self, value, value_is_parent=False):
+        # optimized operation for ==, uses a lazy clause.
+        (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+        class Visitor(sql.ClauseVisitor):
+            def visit_bindparam(s, bindparam):
+                mapper = value_is_parent and self.parent or self.mapper
+                bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+        Visitor().traverse(criterion)
+        return criterion
+        
     private = property(lambda s:s.cascade.delete_orphan)
 
     def create_strategy(self):
index 3937149ee594a26eac97b65b15f054e7731cef37..0537ee258e3959b4b7f5d141418ff4f9d5b22c49 100644 (file)
@@ -7,6 +7,7 @@
 from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
 from sqlalchemy.orm import mapper, class_mapper, object_mapper
 from sqlalchemy.orm.interfaces import OperationContext
+import operator
 
 __all__ = ['Query', 'QueryContext', 'SelectionContext']
 
@@ -120,7 +121,7 @@ class Query(object):
         mapper = object_mapper(instance)
         prop = mapper.get_property(property, resolve_synonyms=True)
         target = prop.mapper
-        criterion = prop.compare(instance, value_is_parent=True)
+        criterion = prop.compare(operator.eq, instance, value_is_parent=True)
         return Query(target, **kwargs).filter(criterion)
     query_from_parent = classmethod(query_from_parent)
         
@@ -149,7 +150,7 @@ class Query(object):
                 raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
         else:
             prop = mapper.get_property(property, resolve_synonyms=True)
-        return self.filter(prop.compare(instance, value_is_parent=True))
+        return self.filter(prop.compare(operator.eq, instance, value_is_parent=True))
 
     def add_entity(self, entity):
         """add a mapped entity to the list of result columns to be returned.
@@ -265,7 +266,7 @@ class Query(object):
 
         for key, value in kwargs.iteritems():
             prop = joinpoint.get_property(key, resolve_synonyms=True)
-            c = prop.compare(value)
+            c = prop.compare(operator.eq, value)
 
             if alias is not None:
                 sql_util.ClauseAdapter(alias).traverse(c)
@@ -1011,9 +1012,9 @@ class Query(object):
         for key, value in params.iteritems():
             (keys, prop) = self._locate_prop(key, start=start)
             if isinstance(prop, properties.PropertyLoader):
-                c = prop.compare(value) & self.join_via(keys[:-1])
+                c = prop.compare(operator.eq, value) & self.join_via(keys[:-1])
             else:
-                c = prop.compare(value) & self.join_via(keys)
+                c = prop.compare(operator.eq, value) & self.join_via(keys)
             if clause is None:
                 clause =  c
             else:
index e177f4194a2f8f11c605c6189ce933119f8b96ae..b6a843685caed6fa8ac46622f17f1f19af845b15 100644 (file)
@@ -870,8 +870,8 @@ class ClauseVisitor(object):
     """A class that knows how to traverse and visit
     ``ClauseElements``.
     
-    Each ``ClauseElement``'s accept_visitor() method will call a
-    corresponding visit_XXXX() method here. Traversal of a
+    Calls visit_XXX() methods dynamically generated for each particualr
+    ``ClauseElement`` subclass encountered.  Traversal of a
     hierarchy of ``ClauseElements`` is achieved via the
     ``traverse()`` method, which is passed the lead
     ``ClauseElement``.
@@ -885,6 +885,11 @@ class ClauseVisitor(object):
     (column_collections=False) or to return Schema-level items
     (schema_visitor=True).
     
+    ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+    operation, which will produce a copy of a given ``ClauseElement``
+    structure while at the same time allowing ``ClauseVisitor`` subclasses
+    to modify the new structure in-place.
+    
     """
     __traverse_options__ = {}
     
index eb0d110a16118400a840bb40b3c2ee3f703b20d5..e6c03161c2cf2ad409b159671049d0b7a4c5dab1 100644 (file)
@@ -844,12 +844,15 @@ class CompositeTypesTest(ORMTest):
             
         edges = Table('edges', metadata, 
             Column('id', Integer, primary_key=True),
-            Column('graph_id', Integer, ForeignKey('graphs.id'), nullable=False),
+            Column('graph_id', Integer, nullable=False),
+            Column('graph_version_id', Integer, nullable=False),
             Column('x1', Integer),
             Column('y1', Integer),
             Column('x2', Integer),
-            Column('y2', Integer))
-        
+            Column('y2', Integer),
+            ForeignKeyConstraint(['graph_id', 'graph_version_id'], ['graphs.id', 'graphs.version_id'])
+            )
+
     def test_basic(self):
         class Point(object):
             def __init__(self, x, y):
@@ -914,6 +917,15 @@ class CompositeTypesTest(ORMTest):
                 assert e1.end == e2.end
         self.assert_sql_count(testbase.db, go, 1)
         
+        # test comparison of CompositeProperties to their object instances
+        g = sess.query(Graph).get([1, 1])
+        assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0]
+        
+        assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1]
+
+        assert sess.query(Edge).filter(Edge.start==None).all() == []
+        
+        
     def test_pk(self):
         """test using a composite type as a primary key"""
         
index c5c48f4bb463becbb14c059baf609d84a6aa0234..75885fb8ce09a8e6314c953beb2d2fa4791884af 100644 (file)
@@ -60,7 +60,7 @@ class QueryTest(testbase.ORMTest):
         
     def setup_mappers(self):
         mapper(User, users, properties={
-            'addresses':relation(Address),
+            'addresses':relation(Address, backref='user'),
             'orders':relation(Order, backref='user'), # o2m, m2o
         })
         mapper(Address, addresses)
@@ -196,7 +196,34 @@ class FilterTest(QueryTest):
     def test_onefilter(self):
         assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all()
 
+    def test_contains(self):
+        """test comparing a collection to an object instance."""
+        
+        sess = create_session()
+        address = sess.query(Address).get(3)
+        assert [User(id=8)] == sess.query(User).filter(User.addresses==address).all()
+
+        assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
+
+        assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+        
+    def test_contains_m2m(self):
+        sess = create_session()
+        item = sess.query(Item).get(3)
+        assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items==item).all()
 
+        assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(Order.items!=item).all()
+
+    def test_has(self):
+        """test scalar comparison to an object instance"""
+        
+        sess = create_session()
+        user = sess.query(User).get(8)
+        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all()
+
+        assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all()
+
+        
 class CountTest(QueryTest):
     def test_basic(self):
         assert 4 == create_session().query(User).count()