]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
partial progress on adding prop.compare(), new behavior for prop ==
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Jul 2007 20:07:25 +0000 (20:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Jul 2007 20:07:25 +0000 (20:07 +0000)
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql.py
test/orm/query.py

index 0c727760c483790d38c31f22afe45b5ccee5fd1d..dfc025bf273570f672222883279eeebd1b24e42b 100644 (file)
@@ -14,7 +14,7 @@ import weakref
 PASSIVE_NORESULT = object()
 ATTR_WAS_SET = object()
 
-class InstrumentedAttribute(sql.Comparator):
+class InstrumentedAttribute(interfaces.PropComparator):
     """attribute access for instrumented classes."""
     
     def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs):
index c9c2a45b519f344b9599f1b45b1fb48679def639..f353575d90088a050784b6ebb020c6b8405213e6 100644 (file)
@@ -339,10 +339,18 @@ class MapperProperty(object):
 
 class PropComparator(sql.Comparator):
     """defines comparison operations for MapperProperty objects"""
+
+    def contains_op(a, b):
+        return a.contains(b)
+    contains_op = staticmethod(contains_op)
     
     def __init__(self, prop):
         self.prop = prop
 
+    def contains(self, other):
+        """return true if this collection contains other"""
+        return self.operate(PropComparator.contains_op, other)
+        
 class StrategizedProperty(MapperProperty):
     """A MapperProperty which uses selectable strategies to affect
     loading behavior.
index 9ab3cce2296196c79970f6586721e19507b4358e..7a3da1fdd15f68d5db76b206fe4c1741950337b4 100644 (file)
@@ -165,10 +165,38 @@ class PropertyLoader(StrategizedProperty):
         def __eq__(self, other):
             if other is None:
                 return ~sql.exists([1], self.prop.primaryjoin)
-            else:
+            elif self.prop.uselist:
+                if not hasattr(other, '__iter__'):
+                    raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+                else:
+                    j = self.prop.primaryjoin
+                    if self.prop.secondaryjoin:
+                        j = j & self.prop.secondaryjoin
+                    clauses = []
+                    for o in other:
+                        clauses.append(
+                            sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))]))
+                        )
+                    return sql.and_(*clauses)
+            else:  
                 return self.prop._optimized_compare(other)
         
+        def contains(self, other):
+            if not self.prop.uselist:
+                raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes")
+            clause = self.prop._optimized_compare(other)
+
+            j = self.prop.primaryjoin
+            if self.prop.secondaryjoin:
+                j = j & self.prop.secondaryjoin
+
+            clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+            return clause
+
         def __ne__(self, other):
+            if self.prop.uselist and not hasattr(other, '__iter__'):
+                raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+                
             j = self.prop.primaryjoin
             if self.prop.secondaryjoin:
                 j = j & self.prop.secondaryjoin
index 1129629ca63098c1124c9799a85fda40c4331f60..8780ec522296dd547092f56da9d599708540ea5f 100644 (file)
@@ -368,7 +368,7 @@ def and_(*clauses):
     """
     if len(clauses) == 1:
         return clauses[0]
-    return ClauseList(operator='AND', *clauses)
+    return ClauseList(operator='AND', negate='OR', *clauses)
 
 def or_(*clauses):
     """Join a list of clauses together using the ``OR`` operator.
@@ -379,7 +379,7 @@ def or_(*clauses):
 
     if len(clauses) == 1:
         return clauses[0]
-    return ClauseList(operator='OR', *clauses)
+    return ClauseList(operator='OR', negate='AND', *clauses)
 
 def not_(clause):
     """Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -1131,7 +1131,10 @@ class ClauseElement(object):
         return self._negate()
 
     def _negate(self):
-        return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+        if hasattr(self, 'negation_clause'):
+            return self.negation_clause
+        else:
+            return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
 
 
 class Comparator(object):
@@ -1907,6 +1910,7 @@ class ClauseList(ClauseElement):
         self.operator = kwargs.pop('operator', ',')
         self.group = kwargs.pop('group', True)
         self.group_contents = kwargs.pop('group_contents', True)
+        self.negate_operator = kwargs.pop('negate', None)
         for c in clauses:
             if c is None: 
                 continue
@@ -1928,6 +1932,14 @@ class ClauseList(ClauseElement):
     def _copy_internals(self):
         self.clauses = [clause._clone() for clause in self.clauses]
 
+    def _negate(self):
+        if hasattr(self, 'negation_clause'):
+            return self.negation_clause
+        elif self.negate_operator is None:
+            return super(ClauseList, self).negate()
+        else:
+            return ClauseList(operator=self.negate_operator, negate=self.operator, *(not_(c) for c in self.clauses))
+
     def get_children(self, **kwargs):
         return self.clauses
 
index 6956685c1f8be33f5dede45093bde62174b1bc76..ba2a768291fa52e19d283edee73feee33f1551f4 100644 (file)
@@ -204,18 +204,30 @@ class FilterTest(QueryTest):
         
         sess = create_session()
         address = sess.query(Address).get(3)
-        assert [User(id=8)] == sess.query(User).filter(User.addresses==address).all()
+        assert [User(id=8)] == sess.query(User).filter(User.addresses.contains(address)).all()
+
+        try:
+            sess.query(User).filter(User.addresses == address)
+            assert False
+        except exceptions.InvalidRequestError:
+            assert True
 
         assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
 
-        assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+        try:
+            assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+            assert False
+        except exceptions.InvalidRequestError:
+            assert True
+            
+        #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
         
     def test_contains_m2m(self):
         sess = create_session()
         item = sess.query(Item).get(3)
-        assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items==item).all()
+        assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).all()
 
-        assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(Order.items!=item).all()
+        assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all()
 
     def test_has(self):
         """test scalar comparison to an object instance"""