--- /dev/null
+from sqlalchemy.sql import operators, functions
+from sqlalchemy.sql import expression as sql
+from sqlalchemy.util import Set
+import operator
+
+class UnevaluatableError(Exception):
+ pass
+
+_straight_ops = Set([getattr(operators, op) for op in [
+ 'add', 'mul', 'sub', 'div', 'mod', 'truediv', 'lt', 'le', 'ne', 'gt', 'ge', 'eq'
+]])
+
+
+_notimplemented_ops = Set([getattr(operators, op) for op in [
+ 'like_op', 'notlike_op', 'ilike_op', 'notilike_op', 'between_op', 'in_op', 'notin_op',
+ 'endswith_op', 'concat_op',
+]])
+
+class EvaluatorCompiler(object):
+ def process(self, clause):
+ meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
+ if not meth:
+ raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
+ return meth(clause)
+
+ def visit_grouping(self, clause):
+ return self.process(clause.element)
+
+ def visit_null(self, clause):
+ return lambda obj: None
+
+ def visit_column(self, clause):
+ if 'parententity' in clause._annotations:
+ key = clause._annotations['parententity']._get_col_to_prop(clause).key
+ else:
+ key = clause.key
+ get_corresponding_attr = operator.attrgetter(key)
+ return lambda obj: get_corresponding_attr(obj)
+
+ def visit_clauselist(self, clause):
+ evaluators = map(self.process, clause.clauses)
+ if clause.operator is operators.or_:
+ def evaluate(obj):
+ has_null = False
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value:
+ return True
+ has_null = has_null or value is None
+ if has_null:
+ return None
+ return False
+ if clause.operator is operators.and_:
+ def evaluate(obj):
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if not value:
+ if value is None:
+ return None
+ return False
+ return True
+
+ return evaluate
+
+ def visit_binary(self, clause):
+ eval_left,eval_right = map(self.process, [clause.left, clause.right])
+ operator = clause.operator
+ if operator is operators.is_:
+ def evaluate(obj):
+ return eval_left(obj) == eval_right(obj)
+ if operator is operators.isnot:
+ def evaluate(obj):
+ return eval_left(obj) != eval_right(obj)
+ elif operator in _straight_ops:
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+ return operator(eval_left(obj), eval_right(obj))
+ return evaluate
+
+ def visit_unary(self, clause):
+ eval_inner = self.process(clause.element)
+ if clause.operator is operators.inv:
+ def evaluate(obj):
+ value = eval_inner(obj)
+ if value is None:
+ return None
+ return not value
+ return evaluate
+ raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
+
+ def visit_bindparam(self, clause):
+ val = clause.value
+ return lambda obj: val
--- /dev/null
+"""Evluating SQL expressions on ORM objects"""
+import testenv; testenv.configure_for_tests()
+from testlib import sa, testing
+from testlib.sa import Table, Column, String, Integer, select
+from testlib.sa.orm import mapper, create_session
+from testlib.testing import eq_
+from orm import _base
+
+from sqlalchemy import and_, or_, not_
+from sqlalchemy.orm import evaluator
+
+compiler = evaluator.EvaluatorCompiler()
+def eval_eq(clause, testcases=None):
+ evaluator = compiler.process(clause)
+ def testeval(obj=None, expected_result=None):
+ assert evaluator(obj) == expected_result, "%s != %r for %s with %r" % (evaluator(obj), expected_result, clause, obj)
+ if testcases:
+ for an_obj,result in testcases:
+ testeval(an_obj, result)
+ return testeval
+
+class EvaluateTest(_base.MappedTest):
+ def define_tables(self, metadata):
+ Table('users', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String))
+
+ def setup_classes(self):
+ class User(_base.ComparableEntity):
+ pass
+
+ @testing.resolve_artifact_names
+ def setup_mappers(self):
+ mapper(User, users)
+
+ @testing.resolve_artifact_names
+ def test_compare_to_value(self):
+ eval_eq(User.name == 'foo', testcases=[
+ (User(name='foo'), True),
+ (User(name='bar'), False),
+ (User(name=None), None),
+ ])
+
+ eval_eq(User.id < 5, testcases=[
+ (User(id=3), True),
+ (User(id=5), False),
+ (User(id=None), None),
+ ])
+
+ @testing.resolve_artifact_names
+ def test_compare_to_none(self):
+ eval_eq(User.name == None, testcases=[
+ (User(name='foo'), False),
+ (User(name=None), True),
+ ])
+
+ @testing.resolve_artifact_names
+ def test_boolean_ops(self):
+ eval_eq(and_(User.name == 'foo', User.id == 1), testcases=[
+ (User(id=1, name='foo'), True),
+ (User(id=2, name='foo'), False),
+ (User(id=1, name='bar'), False),
+ (User(id=2, name='bar'), False),
+ (User(id=1, name=None), None),
+ ])
+
+ eval_eq(or_(User.name == 'foo', User.id == 1), testcases=[
+ (User(id=1, name='foo'), True),
+ (User(id=2, name='foo'), True),
+ (User(id=1, name='bar'), True),
+ (User(id=2, name='bar'), False),
+ (User(id=1, name=None), True),
+ (User(id=2, name=None), None),
+ ])
+
+ eval_eq(not_(User.id == 1), testcases=[
+ (User(id=1), False),
+ (User(id=2), True),
+ (User(id=None), None),
+ ])
+
+ @testing.resolve_artifact_names
+ def test_null_propagation(self):
+ eval_eq((User.name == 'foo') == (User.id == 1), testcases=[
+ (User(id=1, name='foo'), True),
+ (User(id=2, name='foo'), False),
+ (User(id=1, name='bar'), False),
+ (User(id=2, name='bar'), True),
+ (User(id=None, name='foo'), None),
+ (User(id=None, name=None), None),
+ ])
+
+if __name__ == '__main__':
+ testenv.main()