PASSIVE_NORESULT = object()
ATTR_WAS_SET = object()
-class InstrumentedAttribute(sql.Comparator):
+class InstrumentedAttribute(interfaces.PropComparator):
"""attribute access for instrumented classes."""
def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs):
class PropComparator(sql.Comparator):
"""defines comparison operations for MapperProperty objects"""
+
+ def contains_op(a, b):
+ return a.contains(b)
+ contains_op = staticmethod(contains_op)
def __init__(self, prop):
self.prop = prop
+ def contains(self, other):
+ """return true if this collection contains other"""
+ return self.operate(PropComparator.contains_op, other)
+
class StrategizedProperty(MapperProperty):
"""A MapperProperty which uses selectable strategies to affect
loading behavior.
def __eq__(self, other):
if other is None:
return ~sql.exists([1], self.prop.primaryjoin)
- else:
+ elif self.prop.uselist:
+ if not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+ else:
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ clauses = []
+ for o in other:
+ clauses.append(
+ sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(o))]))
+ )
+ return sql.and_(*clauses)
+ else:
return self.prop._optimized_compare(other)
+ def contains(self, other):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes")
+ clause = self.prop._optimized_compare(other)
+
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+
+ clause.negation_clause = ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+ return clause
+
def __ne__(self, other):
+ if self.prop.uselist and not hasattr(other, '__iter__'):
+ raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+
j = self.prop.primaryjoin
if self.prop.secondaryjoin:
j = j & self.prop.secondaryjoin
"""
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='AND', *clauses)
+ return ClauseList(operator='AND', negate='OR', *clauses)
def or_(*clauses):
"""Join a list of clauses together using the ``OR`` operator.
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='OR', *clauses)
+ return ClauseList(operator='OR', negate='AND', *clauses)
def not_(clause):
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
return self._negate()
def _negate(self):
- return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+ if hasattr(self, 'negation_clause'):
+ return self.negation_clause
+ else:
+ return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
class Comparator(object):
self.operator = kwargs.pop('operator', ',')
self.group = kwargs.pop('group', True)
self.group_contents = kwargs.pop('group_contents', True)
+ self.negate_operator = kwargs.pop('negate', None)
for c in clauses:
if c is None:
continue
def _copy_internals(self):
self.clauses = [clause._clone() for clause in self.clauses]
+ def _negate(self):
+ if hasattr(self, 'negation_clause'):
+ return self.negation_clause
+ elif self.negate_operator is None:
+ return super(ClauseList, self).negate()
+ else:
+ return ClauseList(operator=self.negate_operator, negate=self.operator, *(not_(c) for c in self.clauses))
+
def get_children(self, **kwargs):
return self.clauses
sess = create_session()
address = sess.query(Address).get(3)
- assert [User(id=8)] == sess.query(User).filter(User.addresses==address).all()
+ assert [User(id=8)] == sess.query(User).filter(User.addresses.contains(address)).all()
+
+ try:
+ sess.query(User).filter(User.addresses == address)
+ assert False
+ except exceptions.InvalidRequestError:
+ assert True
assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
- assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+ try:
+ assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+ assert False
+ except exceptions.InvalidRequestError:
+ assert True
+
+ #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
def test_contains_m2m(self):
sess = create_session()
item = sess.query(Item).get(3)
- assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items==item).all()
+ assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items.contains(item)).all()
- assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(Order.items!=item).all()
+ assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all()
def test_has(self):
"""test scalar comparison to an object instance"""