From: Mike Bayer Date: Wed, 18 Jul 2007 20:07:25 +0000 (+0000) Subject: partial progress on adding prop.compare(), new behavior for prop == X-Git-Tag: rel_0_4_6~81 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2a492db506b29dd2dd989b83b2324b65004d5a4f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git partial progress on adding prop.compare(), new behavior for prop == --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 0c727760c4..dfc025bf27 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -14,7 +14,7 @@ import weakref PASSIVE_NORESULT = object() ATTR_WAS_SET = object() -class InstrumentedAttribute(sql.Comparator): +class InstrumentedAttribute(interfaces.PropComparator): """attribute access for instrumented classes.""" def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs): diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c9c2a45b51..f353575d90 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -339,10 +339,18 @@ class MapperProperty(object): class PropComparator(sql.Comparator): """defines comparison operations for MapperProperty objects""" + + def contains_op(a, b): + return a.contains(b) + contains_op = staticmethod(contains_op) def __init__(self, prop): self.prop = prop + def contains(self, other): + """return true if this collection contains other""" + return self.operate(PropComparator.contains_op, other) + class StrategizedProperty(MapperProperty): """A MapperProperty which uses selectable strategies to affect loading behavior. diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9ab3cce229..7a3da1fdd1 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -165,10 +165,38 @@ class PropertyLoader(StrategizedProperty): def __eq__(self, other): if other is None: return ~sql.exists([1], self.prop.primaryjoin) - else: + elif self.prop.uselist: + if not hasattr(other, '__iter__'): + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + else: + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + clauses = [] + for o in other: + clauses.append( + sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))])) + ) + return sql.and_(*clauses) + else: return self.prop._optimized_compare(other) + def contains(self, other): + if not self.prop.uselist: + raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes") + clause = self.prop._optimized_compare(other) + + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + + clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + return clause + def __ne__(self, other): + if self.prop.uselist and not hasattr(other, '__iter__'): + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + j = self.prop.primaryjoin if self.prop.secondaryjoin: j = j & self.prop.secondaryjoin diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 1129629ca6..8780ec5222 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -368,7 +368,7 @@ def and_(*clauses): """ if len(clauses) == 1: return clauses[0] - return ClauseList(operator='AND', *clauses) + return ClauseList(operator='AND', negate='OR', *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -379,7 +379,7 @@ def or_(*clauses): if len(clauses) == 1: return clauses[0] - return ClauseList(operator='OR', *clauses) + return ClauseList(operator='OR', negate='AND', *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -1131,7 +1131,10 @@ class ClauseElement(object): return self._negate() def _negate(self): - return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) + if hasattr(self, 'negation_clause'): + return self.negation_clause + else: + return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) class Comparator(object): @@ -1907,6 +1910,7 @@ class ClauseList(ClauseElement): self.operator = kwargs.pop('operator', ',') self.group = kwargs.pop('group', True) self.group_contents = kwargs.pop('group_contents', True) + self.negate_operator = kwargs.pop('negate', None) for c in clauses: if c is None: continue @@ -1928,6 +1932,14 @@ class ClauseList(ClauseElement): def _copy_internals(self): self.clauses = [clause._clone() for clause in self.clauses] + def _negate(self): + if hasattr(self, 'negation_clause'): + return self.negation_clause + elif self.negate_operator is None: + return super(ClauseList, self).negate() + else: + return ClauseList(operator=self.negate_operator, negate=self.operator, *(not_(c) for c in self.clauses)) + def get_children(self, **kwargs): return self.clauses diff --git a/test/orm/query.py b/test/orm/query.py index 6956685c1f..ba2a768291 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -204,18 +204,30 @@ class FilterTest(QueryTest): sess = create_session() address = sess.query(Address).get(3) - assert [User(id=8)] == sess.query(User).filter(User.addresses==address).all() + assert [User(id=8)] == sess.query(User).filter(User.addresses.contains(address)).all() + + try: + sess.query(User).filter(User.addresses == address) + assert False + except exceptions.InvalidRequestError: + assert True assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all() - assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() + try: + assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() + assert False + except exceptions.InvalidRequestError: + assert True + + #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() def test_contains_m2m(self): sess = create_session() item = sess.query(Item).get(3) - assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items==item).all() + assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).all() - assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(Order.items!=item).all() + assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all() def test_has(self): """test scalar comparison to an object instance"""