]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implemented many-to-one comparisons to None generate <column> IS NULL, with column...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Dec 2007 17:56:52 +0000 (17:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Dec 2007 17:56:52 +0000 (17:56 +0000)
CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/query.py
test/testlib/fixtures.py

diff --git a/CHANGES b/CHANGES
index 38dec541b1fe0b49019fbb7abb579dabfda7f442..dcc6cbdb2fcd213b5f9392c6e3fc9bba5399301f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -79,6 +79,10 @@ CHANGES
      statements as well.  Filter criterion, order bys, eager load
      clauses will be "aliased" against the given statement.
 
+   - query.filter(SomeClass.somechild == None), when comparing
+     a many-to-one property to None, properly generates "id IS NULL"
+     including that the NULL is on the right side.
+     
    - eagerload(), lazyload(), eagerload_all() take an optional 
      second class-or-mapper argument, which will select the mapper
      to apply the option towards.  This can select among other
index 4d41556a07527de4174017c62818a7b35a055386..b6d6cef638a6ffa3262b7addf8c84c9da1f103d1 100644 (file)
@@ -232,7 +232,10 @@ class PropertyLoader(StrategizedProperty):
     class Comparator(PropComparator):
         def __eq__(self, other):
             if other is None:
-                return ~sql.exists([1], self.prop.primaryjoin)
+                if self.prop.uselist:
+                    return ~sql.exists([1], self.prop.primaryjoin)
+                else:
+                    return self.prop._optimized_compare(None)
             elif self.prop.uselist:
                 if not hasattr(other, '__iter__'):
                     raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.  Use contains().")
index d9390345e9a97b3c0325ee77966c7786ea0bd949..2ba9d6be1e2a83df56f7c7365caf964d2a6c124c 100644 (file)
@@ -8,7 +8,7 @@
 
 from sqlalchemy import sql, util, exceptions, logging
 from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql import visitors
+from sqlalchemy.sql import visitors, expression, operators
 from sqlalchemy.orm import mapper, attributes
 from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
 from sqlalchemy.orm import session as sessionlib
@@ -292,6 +292,9 @@ class LazyLoader(AbstractRelationLoader):
         self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i))
 
     def lazy_clause(self, instance, reverse_direction=False):
+        if instance is None:
+            return self.lazy_none_clause(reverse_direction)
+            
         if not reverse_direction:
             (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
         else:
@@ -305,6 +308,26 @@ class LazyLoader(AbstractRelationLoader):
                 bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
         return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
     
+    def lazy_none_clause(self, reverse_direction=False):
+        if not reverse_direction:
+            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+        else:
+            (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+        def visit_binary(binary):
+            mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
+            if isinstance(binary.left, expression._BindParamClause) and binary.left.key in bind_to_col:
+                # reverse order if the NULL is on the left side
+                binary.left = binary.right
+                binary.right = expression.null()
+                binary.operator = operators.is_
+            elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in bind_to_col:
+                binary.right = expression.null()
+                binary.operator = operators.is_
+        
+        return visitors.traverse(criterion, clone=True, visit_binary=visit_binary)
+        
     def setup_loader(self, instance, options=None, path=None):
         if not mapper.has_mapper(instance):
             return None
index 5f85151f0aab95ccd793720110087c4ced640fc7..efd890e0e83cbf70e8d1db8a4593d77a61c02300 100644 (file)
@@ -376,6 +376,11 @@ class FilterTest(QueryTest):
 
         assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all()
 
+        # generates an IS NULL
+        assert [] == sess.query(Address).filter(Address.user == None).all()
+
+        assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all()
+
 class AggregateTest(QueryTest):
     def test_sum(self):
         sess = create_session()
index 4394780bb20f3cecde9f93fa06fa6b6b3233ca3b..2a4b457acd6fe8fc3953acf579af94cf0823d1eb 100644 (file)
@@ -150,7 +150,7 @@ def install_fixture_data():
         dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4),
         dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1),
         dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4),
-        dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1)
+        dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=None)
     )
     items.insert().execute(
         dict(id=1, description='item 1'),