]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed negated self-referential m2m contains(), [ticket:987]
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Mar 2008 19:26:29 +0000 (19:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Mar 2008 19:26:29 +0000 (19:26 +0000)
CHANGES
lib/sqlalchemy/orm/properties.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index e53812101dfbcc507a83de20ae5ff22fad54162f..93dfd62b29fafdc46ce690583e0dba8d16f99617 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -24,10 +24,10 @@ CHANGES
       [ticket:986]
       
 - orm
-    - any(), has(), contains(), attribute level == and != now
-      work properly with self-referential relations - the clause
-      inside the EXISTS is aliased on the "remote" side to
-      distinguish it from the parent table.
+    - any(), has(), contains(), ~contains(), attribute level ==
+      and != now work properly with self-referential relations -
+      the clause inside the EXISTS is aliased on the "remote"
+      side to distinguish it from the parent table.
     
     - repaired behavior of == and != operators at the relation()
       level when compared against NULL for one-to-one 
index ee35b236b9d72456c47cd7e69c8f8b89b7d94e3e..ba5ef7d3649572542cabc8c53d5edbb93c0448ba 100644 (file)
@@ -334,12 +334,15 @@ class PropertyLoader(StrategizedProperty):
             clause = self.prop._optimized_compare(other)
 
             if self.prop.secondaryjoin:
-                j = self.prop.primaryjoin
-                j = j & self.prop.secondaryjoin
-                clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+                clause.negation_clause = self._negated_contains_or_equals(other)
 
             return clause
 
+        def _negated_contains_or_equals(self, other):
+            criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
+            j, criterion, from_obj = self._join_and_criterion(criterion)
+            return ~sql.exists([1], j & criterion, from_obj=from_obj)
+            
         def __ne__(self, other):
             if other is None:
                 if self.prop.direction == sync.MANYTOONE:
@@ -351,11 +354,8 @@ class PropertyLoader(StrategizedProperty):
 
             if self.prop.uselist and not hasattr(other, '__iter__'):
                 raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
-            
-            criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
-            j, criterion, from_obj = self._join_and_criterion(criterion)
 
-            return ~sql.exists([1], j & criterion, from_obj=from_obj)
+            return self._negated_contains_or_equals(other)
 
     def compare(self, op, value, value_is_parent=False):
         if op == operators.eq:
index 835e0c976363f78e972ccaa20b8fe717b1f23267..2915c0f53da599b440c0c45c2b716b1ea966e82b 100644 (file)
@@ -1232,6 +1232,64 @@ class SelfReferentialTest(ORMTest):
         self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
         
         self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
+
+class SelfReferentialM2MTest(ORMTest):
+    keep_mappers = True
+    keep_data = True
+    
+    def define_tables(self, metadata):
+        global nodes, node_to_nodes
+        nodes = Table('nodes', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)))
+            
+        node_to_nodes =Table('node_to_nodes', metadata,
+            Column('left_node_id', Integer, ForeignKey('nodes.id'),primary_key=True),
+            Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True),
+            )
+    
+    def insert_data(self):
+        global Node
+        
+        class Node(Base):
+            pass
+
+        mapper(Node, nodes, properties={
+            'children':relation(Node, lazy=True, secondary=node_to_nodes,
+                primaryjoin=nodes.c.id==node_to_nodes.c.left_node_id,
+                secondaryjoin=nodes.c.id==node_to_nodes.c.right_node_id,
+            )
+        })
+        sess = create_session()
+        n1 = Node(data='n1')
+        n2 = Node(data='n2')
+        n3 = Node(data='n3')
+        n4 = Node(data='n4')
+        n5 = Node(data='n5')
+        n6 = Node(data='n6')
+        n7 = Node(data='n7')
+        
+        n1.children = [n2, n3, n4]
+        n2.children = [n3, n6, n7]
+        n3.children = [n5, n4]
+
+        sess.save(n1)
+        sess.save(n2)
+        sess.save(n3)
+        sess.save(n4)
+        sess.flush()
+        sess.close()
+
+    def test_any(self):
+        sess = create_session()
+        self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')])
+
+    def test_contains(self):
+        sess = create_session()
+        n4 = sess.query(Node).filter_by(data='n4').one()
+
+        self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')])
+        self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')])
         
 class ExternalColumnsTest(QueryTest):
     keep_mappers = False