]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- any(), has(), contains(), attribute level == and != now
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Feb 2008 01:15:43 +0000 (01:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Feb 2008 01:15:43 +0000 (01:15 +0000)
work properly with self-referential relations - the clause
inside the EXISTS is aliased on the "remote" side to
distinguish it from the parent table.

CHANGES
lib/sqlalchemy/orm/properties.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 339becd777dab9c163493d54549f9f5e0515de26..56b3599fb2af9a0eeaac477446fffc5bc6c7ba82 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,7 +1,14 @@
 =======
 CHANGES
 =======
-
+0.4.4
+------
+- orm
+    - any(), has(), contains(), attribute level == and != now
+      work properly with self-referential relations - the clause
+      inside the EXISTS is aliased on the "remote" side to
+      distinguish it from the parent table.
+      
 0.4.4
 ------
 
index d08dd712471bdeffbb36c03c4e4550a6c4945218..6339ec5750ae28784c72e223e95b57115fd5f750 100644 (file)
@@ -15,7 +15,7 @@ from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause
 from sqlalchemy.sql import visitors, operators, ColumnElement
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
 from sqlalchemy.orm import session as sessionlib
-from sqlalchemy.orm.util import CascadeOptions
+from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
 from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
 from sqlalchemy.exceptions import ArgumentError
 import weakref
@@ -265,33 +265,44 @@ class PropertyLoader(StrategizedProperty):
                     return sql.and_(*clauses)
             else:
                 return self.prop._optimized_compare(other)
+        
+        def _join_and_criterion(self, criterion=None, **kwargs):
+            if self.prop._is_self_referential():
+                pac = PropertyAliasedClauses(self.prop,
+                        self.prop.primaryjoin,
+                        self.prop.secondaryjoin)
+                j = pac.primaryjoin
+                if pac.secondaryjoin:
+                    j = j & pac.secondaryjoin
+            else:
+                j = self.prop.primaryjoin
+                if self.prop.secondaryjoin:
+                    j = j & self.prop.secondaryjoin
 
-        def any(self, criterion=None, **kwargs):
-            if not self.prop.uselist:
-                raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
-            j = self.prop.primaryjoin
-            if self.prop.secondaryjoin:
-                j = j & self.prop.secondaryjoin
             for k in kwargs:
                 crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
                 if criterion is None:
                     criterion = crit
                 else:
                     criterion = criterion & crit
+
+            if criterion and self.prop._is_self_referential():
+                criterion = pac.adapt_clause(criterion)
+            
+            return j, criterion
+            
+        def any(self, criterion=None, **kwargs):
+            if not self.prop.uselist:
+                raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+            j, criterion = self._join_and_criterion(criterion, **kwargs)
+
             return sql.exists([1], j & criterion)
 
         def has(self, criterion=None, **kwargs):
             if self.prop.uselist:
                 raise exceptions.InvalidRequestError("'has()' not implemented for collections.  Use any().")
-            j = self.prop.primaryjoin
-            if self.prop.secondaryjoin:
-                j = j & self.prop.secondaryjoin
-            for k in kwargs:
-                crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
-                if criterion is None:
-                    criterion = crit
-                else:
-                    criterion = criterion & crit
+            j, criterion = self._join_and_criterion(criterion, **kwargs)
+
             return sql.exists([1], j & criterion)
 
         def contains(self, other):
@@ -309,11 +320,11 @@ class PropertyLoader(StrategizedProperty):
         def __ne__(self, other):
             if self.prop.uselist and not hasattr(other, '__iter__'):
                 raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+            
+            criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
+            j, criterion = self._join_and_criterion(criterion)
 
-            j = self.prop.primaryjoin
-            if self.prop.secondaryjoin:
-                j = j & self.prop.secondaryjoin
-            return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+            return ~sql.exists([1], j & criterion)
 
     def compare(self, op, value, value_is_parent=False):
         if op == operators.eq:
index cfef709cfbba0537367314c0ee6e303b6292ee30..a2ac6cf4e09b3ead5f3480a70278761cdfdfa600 100644 (file)
@@ -1121,15 +1121,20 @@ class CustomJoinTest(QueryTest):
 
         assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all()
 
-class SelfReferentialJoinTest(ORMTest):
+class SelfReferentialTest(ORMTest):
+    keep_mappers = True
+    keep_data = True
+    
     def define_tables(self, metadata):
         global nodes
         nodes = Table('nodes', metadata,
             Column('id', Integer, primary_key=True),
             Column('parent_id', Integer, ForeignKey('nodes.id')),
             Column('data', String(30)))
-
-    def test_join(self):
+    
+    def insert_data(self):
+        global Node
+        
         class Node(Base):
             def append(self, node):
                 self.children.append(node)
@@ -1149,11 +1154,11 @@ class SelfReferentialJoinTest(ORMTest):
         n1.children[1].append(Node(data='n123'))
         sess.save(n1)
         sess.flush()
-        sess.clear()
+        sess.close()
+        
+    def test_join(self):
+        sess = create_session()
 
-        # TODO: the aliasing of the join in query._join_to has to limit the aliasing
-        # among local_side / remote_side (add local_side as an attribute on PropertyLoader)
-        # also implement this idea in EagerLoader
         node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
         assert node.data=='n12'
 
@@ -1164,6 +1169,37 @@ class SelfReferentialJoinTest(ORMTest):
             join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
         assert node.data == 'n122'
 
+    def test_any(self):
+        sess = create_session()
+        
+        self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
+        self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
+        self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+
+    def test_has(self):
+        sess = create_session()
+        
+        self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+        self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
+        self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
+    
+    def test_contains(self):
+        sess = create_session()
+        
+        n122 = sess.query(Node).filter(Node.data=='n122').one()
+        self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
+
+        n13 = sess.query(Node).filter(Node.data=='n13').one()
+        self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
+    
+    def test_eq_ne(self):
+        sess = create_session()
+        
+        n12 = sess.query(Node).filter(Node.data=='n12').one()
+        self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+        
+        self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
+        
 class ExternalColumnsTest(QueryTest):
     keep_mappers = False