From: Mike Bayer Date: Sun, 17 Feb 2008 01:15:43 +0000 (+0000) Subject: - any(), has(), contains(), attribute level == and != now X-Git-Tag: rel_0_4_4~69 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a3f67fecb27363c73f833cc72cefbff5e8754598;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - any(), has(), contains(), attribute level == and != now work properly with self-referential relations - the clause inside the EXISTS is aliased on the "remote" side to distinguish it from the parent table. --- diff --git a/CHANGES b/CHANGES index 339becd777..56b3599fb2 100644 --- a/CHANGES +++ b/CHANGES @@ -1,7 +1,14 @@ ======= CHANGES ======= - +0.4.4 +------ +- orm + - any(), has(), contains(), attribute level == and != now + work properly with self-referential relations - the clause + inside the EXISTS is aliased on the "remote" side to + distinguish it from the parent table. + 0.4.4 ------ diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d08dd71247..6339ec5750 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause from sqlalchemy.sql import visitors, operators, ColumnElement from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper from sqlalchemy.orm import session as sessionlib -from sqlalchemy.orm.util import CascadeOptions +from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty from sqlalchemy.exceptions import ArgumentError import weakref @@ -265,33 +265,44 @@ class PropertyLoader(StrategizedProperty): return sql.and_(*clauses) else: return self.prop._optimized_compare(other) + + def _join_and_criterion(self, criterion=None, **kwargs): + if self.prop._is_self_referential(): + pac = PropertyAliasedClauses(self.prop, + self.prop.primaryjoin, + self.prop.secondaryjoin) + j = pac.primaryjoin + if pac.secondaryjoin: + j = j & pac.secondaryjoin + else: + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin - def any(self, criterion=None, **kwargs): - if not self.prop.uselist: - raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin for k in kwargs: crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) if criterion is None: criterion = crit else: criterion = criterion & crit + + if criterion and self.prop._is_self_referential(): + criterion = pac.adapt_clause(criterion) + + return j, criterion + + def any(self, criterion=None, **kwargs): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") + j, criterion = self._join_and_criterion(criterion, **kwargs) + return sql.exists([1], j & criterion) def has(self, criterion=None, **kwargs): if self.prop.uselist: raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().") - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin - for k in kwargs: - crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) - if criterion is None: - criterion = crit - else: - criterion = criterion & crit + j, criterion = self._join_and_criterion(criterion, **kwargs) + return sql.exists([1], j & criterion) def contains(self, other): @@ -309,11 +320,11 @@ class PropertyLoader(StrategizedProperty): def __ne__(self, other): if self.prop.uselist and not hasattr(other, '__iter__'): raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + + criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) + j, criterion = self._join_and_criterion(criterion) - j = self.prop.primaryjoin - if self.prop.secondaryjoin: - j = j & self.prop.secondaryjoin - return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + return ~sql.exists([1], j & criterion) def compare(self, op, value, value_is_parent=False): if op == operators.eq: diff --git a/test/orm/query.py b/test/orm/query.py index cfef709cfb..a2ac6cf4e0 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1121,15 +1121,20 @@ class CustomJoinTest(QueryTest): assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all() -class SelfReferentialJoinTest(ORMTest): +class SelfReferentialTest(ORMTest): + keep_mappers = True + keep_data = True + def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) - - def test_join(self): + + def insert_data(self): + global Node + class Node(Base): def append(self, node): self.children.append(node) @@ -1149,11 +1154,11 @@ class SelfReferentialJoinTest(ORMTest): n1.children[1].append(Node(data='n123')) sess.save(n1) sess.flush() - sess.clear() + sess.close() + + def test_join(self): + sess = create_session() - # TODO: the aliasing of the join in query._join_to has to limit the aliasing - # among local_side / remote_side (add local_side as an attribute on PropertyLoader) - # also implement this idea in EagerLoader node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() assert node.data=='n12' @@ -1164,6 +1169,37 @@ class SelfReferentialJoinTest(ORMTest): join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first() assert node.data == 'n122' + def test_any(self): + sess = create_session() + + self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) + self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) + self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + + def test_has(self): + sess = create_session() + + self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) + self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) + + def test_contains(self): + sess = create_session() + + n122 = sess.query(Node).filter(Node.data=='n122').one() + self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) + + n13 = sess.query(Node).filter(Node.data=='n13').one() + self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) + + def test_eq_ne(self): + sess = create_session() + + n12 = sess.query(Node).filter(Node.data=='n12').one() + self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + + self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) + class ExternalColumnsTest(QueryTest): keep_mappers = False