]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- repaired behavior of == and != operators at the relation()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Mar 2008 17:06:27 +0000 (17:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Mar 2008 17:06:27 +0000 (17:06 +0000)
level when compared against NULL for one-to-one and other
relations [ticket:985]

CHANGES
lib/sqlalchemy/orm/properties.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 4faa8ee8665f902280f5e94e7984501de730bb6f..de8c168eaaf69ddad28b65ce82a4b96ade41dca7 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -26,6 +26,10 @@ CHANGES
       inside the EXISTS is aliased on the "remote" side to
       distinguish it from the parent table.
     
+    - repaired behavior of == and != operators at the relation()
+      level when compared against NULL for one-to-one 
+      relations [ticket:985]
+      
     - fixed bug whereby session.expire() attributes were not
       loading on an polymorphically-mapped instance mapped 
       by a select_table mapper.
index 74d4c04ca45bf62d821f2069c736f8336ae1f970..ee35b236b9d72456c47cd7e69c8f8b89b7d94e3e 100644 (file)
@@ -255,7 +255,7 @@ class PropertyLoader(StrategizedProperty):
             
         def __eq__(self, other):
             if other is None:
-                if self.prop.uselist:
+                if self.prop.direction == sync.ONETOMANY:
                     return ~sql.exists([1], self.prop.primaryjoin)
                 else:
                     return self.prop._optimized_compare(None)
@@ -341,6 +341,14 @@ class PropertyLoader(StrategizedProperty):
             return clause
 
         def __ne__(self, other):
+            if other is None:
+                if self.prop.direction == sync.MANYTOONE:
+                    return sql.or_(*[x!=None for x in self.prop.foreign_keys])
+                elif self.prop.uselist:
+                    return self.any()
+                else:
+                    return self.has()
+
             if self.prop.uselist and not hasattr(other, '__iter__'):
                 raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
             
index 62bb99a3239a1496221edd05bd729d84b13d9887..835e0c976363f78e972ccaa20b8fe717b1f23267 100644 (file)
@@ -18,7 +18,10 @@ class QueryTest(FixtureTest):
             'addresses':relation(Address, backref='user'),
             'orders':relation(Order, backref='user'), # o2m, m2o
         })
-        mapper(Address, addresses)
+        mapper(Address, addresses, properties={
+            'dingaling':relation(Dingaling, uselist=False, backref="address")  #o2o
+        })
+        mapper(Dingaling, dingalings)
         mapper(Order, orders, properties={
             'items':relation(Item, secondary=order_items, order_by=items.c.id),  #m2m
             'address':relation(Address),  # m2o
@@ -372,6 +375,10 @@ class FilterTest(QueryTest):
 
         assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all()
 
+        # o2o
+        dingaling = sess.query(Dingaling).get(2)
+        assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all()
+
     def test_filter_by(self):
         sess = create_session()
         user = sess.query(User).get(8)
@@ -382,7 +389,22 @@ class FilterTest(QueryTest):
 
         # one to many generates WHERE NOT EXISTS
         assert [User(name='chuck')] == sess.query(User).filter_by(addresses = None).all()
-
+    
+    def test_none_comparison(self):
+        sess = create_session()
+        
+        # o2o
+        self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
+        self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
+        
+        # m2o
+        self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
+        self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).filter(Order.address!=None).all())
+        
+        # o2m
+        self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all())
+        self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).all())
+        
 class AggregateTest(QueryTest):
     def test_sum(self):
         sess = create_session()