From: Mike Bayer Date: Tue, 4 Mar 2008 19:26:29 +0000 (+0000) Subject: fixed negated self-referential m2m contains(), [ticket:987] X-Git-Tag: rel_0_4_4~39 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cef292c0429b292466a486bd80723b6728273467;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fixed negated self-referential m2m contains(), [ticket:987] --- diff --git a/CHANGES b/CHANGES index e53812101d..93dfd62b29 100644 --- a/CHANGES +++ b/CHANGES @@ -24,10 +24,10 @@ CHANGES [ticket:986] - orm - - any(), has(), contains(), attribute level == and != now - work properly with self-referential relations - the clause - inside the EXISTS is aliased on the "remote" side to - distinguish it from the parent table. + - any(), has(), contains(), ~contains(), attribute level == + and != now work properly with self-referential relations - + the clause inside the EXISTS is aliased on the "remote" + side to distinguish it from the parent table. - repaired behavior of == and != operators at the relation() level when compared against NULL for one-to-one diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ee35b236b9..ba5ef7d364 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -334,12 +334,15 @@ class PropertyLoader(StrategizedProperty): clause = self.prop._optimized_compare(other) if self.prop.secondaryjoin: - j = self.prop.primaryjoin - j = j & self.prop.secondaryjoin - clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + clause.negation_clause = self._negated_contains_or_equals(other) return clause + def _negated_contains_or_equals(self, other): + criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) + j, criterion, from_obj = self._join_and_criterion(criterion) + return ~sql.exists([1], j & criterion, from_obj=from_obj) + def __ne__(self, other): if other is None: if self.prop.direction == sync.MANYTOONE: @@ -351,11 +354,8 @@ class PropertyLoader(StrategizedProperty): if self.prop.uselist and not hasattr(other, '__iter__'): raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") - - criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) - j, criterion, from_obj = self._join_and_criterion(criterion) - return ~sql.exists([1], j & criterion, from_obj=from_obj) + return self._negated_contains_or_equals(other) def compare(self, op, value, value_is_parent=False): if op == operators.eq: diff --git a/test/orm/query.py b/test/orm/query.py index 835e0c9763..2915c0f53d 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1232,6 +1232,64 @@ class SelfReferentialTest(ORMTest): self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) + +class SelfReferentialM2MTest(ORMTest): + keep_mappers = True + keep_data = True + + def define_tables(self, metadata): + global nodes, node_to_nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + + node_to_nodes =Table('node_to_nodes', metadata, + Column('left_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), + Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), + ) + + def insert_data(self): + global Node + + class Node(Base): + pass + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=True, secondary=node_to_nodes, + primaryjoin=nodes.c.id==node_to_nodes.c.left_node_id, + secondaryjoin=nodes.c.id==node_to_nodes.c.right_node_id, + ) + }) + sess = create_session() + n1 = Node(data='n1') + n2 = Node(data='n2') + n3 = Node(data='n3') + n4 = Node(data='n4') + n5 = Node(data='n5') + n6 = Node(data='n6') + n7 = Node(data='n7') + + n1.children = [n2, n3, n4] + n2.children = [n3, n6, n7] + n3.children = [n5, n4] + + sess.save(n1) + sess.save(n2) + sess.save(n3) + sess.save(n4) + sess.flush() + sess.close() + + def test_any(self): + sess = create_session() + self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) + + def test_contains(self): + sess = create_session() + n4 = sess.query(Node).filter_by(data='n4').one() + + self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) + self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) class ExternalColumnsTest(QueryTest): keep_mappers = False