from sqlalchemy.sql import visitors, operators, ColumnElement
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
from sqlalchemy.orm import session as sessionlib
-from sqlalchemy.orm.util import CascadeOptions
+from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
from sqlalchemy.exceptions import ArgumentError
import weakref
return sql.and_(*clauses)
else:
return self.prop._optimized_compare(other)
+
+ def _join_and_criterion(self, criterion=None, **kwargs):
+ if self.prop._is_self_referential():
+ pac = PropertyAliasedClauses(self.prop,
+ self.prop.primaryjoin,
+ self.prop.secondaryjoin)
+ j = pac.primaryjoin
+ if pac.secondaryjoin:
+ j = j & pac.secondaryjoin
+ else:
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
- def any(self, criterion=None, **kwargs):
- if not self.prop.uselist:
- raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
for k in kwargs:
crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
if criterion is None:
criterion = crit
else:
criterion = criterion & crit
+
+ if criterion and self.prop._is_self_referential():
+ criterion = pac.adapt_clause(criterion)
+
+ return j, criterion
+
+ def any(self, criterion=None, **kwargs):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+ j, criterion = self._join_and_criterion(criterion, **kwargs)
+
return sql.exists([1], j & criterion)
def has(self, criterion=None, **kwargs):
if self.prop.uselist:
raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
- for k in kwargs:
- crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
- if criterion is None:
- criterion = crit
- else:
- criterion = criterion & crit
+ j, criterion = self._join_and_criterion(criterion, **kwargs)
+
return sql.exists([1], j & criterion)
def contains(self, other):
def __ne__(self, other):
if self.prop.uselist and not hasattr(other, '__iter__'):
raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+
+ criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
+ j, criterion = self._join_and_criterion(criterion)
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
- return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+ return ~sql.exists([1], j & criterion)
def compare(self, op, value, value_is_parent=False):
if op == operators.eq:
assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all()
-class SelfReferentialJoinTest(ORMTest):
+class SelfReferentialTest(ORMTest):
+ keep_mappers = True
+ keep_data = True
+
def define_tables(self, metadata):
global nodes
nodes = Table('nodes', metadata,
Column('id', Integer, primary_key=True),
Column('parent_id', Integer, ForeignKey('nodes.id')),
Column('data', String(30)))
-
- def test_join(self):
+
+ def insert_data(self):
+ global Node
+
class Node(Base):
def append(self, node):
self.children.append(node)
n1.children[1].append(Node(data='n123'))
sess.save(n1)
sess.flush()
- sess.clear()
+ sess.close()
+
+ def test_join(self):
+ sess = create_session()
- # TODO: the aliasing of the join in query._join_to has to limit the aliasing
- # among local_side / remote_side (add local_side as an attribute on PropertyLoader)
- # also implement this idea in EagerLoader
node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
assert node.data=='n12'
join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
assert node.data == 'n122'
+ def test_any(self):
+ sess = create_session()
+
+ self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
+ self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
+ self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+
+ def test_has(self):
+ sess = create_session()
+
+ self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+ self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
+ self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
+
+ def test_contains(self):
+ sess = create_session()
+
+ n122 = sess.query(Node).filter(Node.data=='n122').one()
+ self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
+
+ n13 = sess.query(Node).filter(Node.data=='n13').one()
+ self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
+
+ def test_eq_ne(self):
+ sess = create_session()
+
+ n12 = sess.query(Node).filter(Node.data=='n12').one()
+ self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+
+ self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
+
class ExternalColumnsTest(QueryTest):
keep_mappers = False