From 77c308367ffec3e8af9b5463b1c3bdd89640e8ac Mon Sep 17 00:00:00 2001 From: Ants Aasma Date: Thu, 29 May 2008 02:11:49 +0000 Subject: [PATCH] Preliminary implementation for the evaluation framework --- lib/sqlalchemy/orm/evaluator.py | 96 +++++++++++++++++++++++++++++++++ test/orm/evaluator.py | 94 ++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 lib/sqlalchemy/orm/evaluator.py create mode 100644 test/orm/evaluator.py diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py new file mode 100644 index 0000000000..c4517e4947 --- /dev/null +++ b/lib/sqlalchemy/orm/evaluator.py @@ -0,0 +1,96 @@ +from sqlalchemy.sql import operators, functions +from sqlalchemy.sql import expression as sql +from sqlalchemy.util import Set +import operator + +class UnevaluatableError(Exception): + pass + +_straight_ops = Set([getattr(operators, op) for op in [ + 'add', 'mul', 'sub', 'div', 'mod', 'truediv', 'lt', 'le', 'ne', 'gt', 'ge', 'eq' +]]) + + +_notimplemented_ops = Set([getattr(operators, op) for op in [ + 'like_op', 'notlike_op', 'ilike_op', 'notilike_op', 'between_op', 'in_op', 'notin_op', + 'endswith_op', 'concat_op', +]]) + +class EvaluatorCompiler(object): + def process(self, clause): + meth = getattr(self, "visit_%s" % clause.__visit_name__, None) + if not meth: + raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_column(self, clause): + if 'parententity' in clause._annotations: + key = clause._annotations['parententity']._get_col_to_prop(clause).key + else: + key = clause.key + get_corresponding_attr = operator.attrgetter(key) + return lambda obj: get_corresponding_attr(obj) + + def visit_clauselist(self, clause): + evaluators = map(self.process, clause.clauses) + if clause.operator is operators.or_: + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + if clause.operator is operators.and_: + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if not value: + if value is None: + return None + return False + return True + + return evaluate + + def visit_binary(self, clause): + eval_left,eval_right = map(self.process, [clause.left, clause.right]) + operator = clause.operator + if operator is operators.is_: + def evaluate(obj): + return eval_left(obj) == eval_right(obj) + if operator is operators.isnot: + def evaluate(obj): + return eval_left(obj) != eval_right(obj) + elif operator in _straight_ops: + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is None or right_val is None: + return None + return operator(eval_left(obj), eval_right(obj)) + return evaluate + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + def evaluate(obj): + value = eval_inner(obj) + if value is None: + return None + return not value + return evaluate + raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator)) + + def visit_bindparam(self, clause): + val = clause.value + return lambda obj: val diff --git a/test/orm/evaluator.py b/test/orm/evaluator.py new file mode 100644 index 0000000000..5a1cec7673 --- /dev/null +++ b/test/orm/evaluator.py @@ -0,0 +1,94 @@ +"""Evluating SQL expressions on ORM objects""" +import testenv; testenv.configure_for_tests() +from testlib import sa, testing +from testlib.sa import Table, Column, String, Integer, select +from testlib.sa.orm import mapper, create_session +from testlib.testing import eq_ +from orm import _base + +from sqlalchemy import and_, or_, not_ +from sqlalchemy.orm import evaluator + +compiler = evaluator.EvaluatorCompiler() +def eval_eq(clause, testcases=None): + evaluator = compiler.process(clause) + def testeval(obj=None, expected_result=None): + assert evaluator(obj) == expected_result, "%s != %r for %s with %r" % (evaluator(obj), expected_result, clause, obj) + if testcases: + for an_obj,result in testcases: + testeval(an_obj, result) + return testeval + +class EvaluateTest(_base.MappedTest): + def define_tables(self, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String)) + + def setup_classes(self): + class User(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def setup_mappers(self): + mapper(User, users) + + @testing.resolve_artifact_names + def test_compare_to_value(self): + eval_eq(User.name == 'foo', testcases=[ + (User(name='foo'), True), + (User(name='bar'), False), + (User(name=None), None), + ]) + + eval_eq(User.id < 5, testcases=[ + (User(id=3), True), + (User(id=5), False), + (User(id=None), None), + ]) + + @testing.resolve_artifact_names + def test_compare_to_none(self): + eval_eq(User.name == None, testcases=[ + (User(name='foo'), False), + (User(name=None), True), + ]) + + @testing.resolve_artifact_names + def test_boolean_ops(self): + eval_eq(and_(User.name == 'foo', User.id == 1), testcases=[ + (User(id=1, name='foo'), True), + (User(id=2, name='foo'), False), + (User(id=1, name='bar'), False), + (User(id=2, name='bar'), False), + (User(id=1, name=None), None), + ]) + + eval_eq(or_(User.name == 'foo', User.id == 1), testcases=[ + (User(id=1, name='foo'), True), + (User(id=2, name='foo'), True), + (User(id=1, name='bar'), True), + (User(id=2, name='bar'), False), + (User(id=1, name=None), True), + (User(id=2, name=None), None), + ]) + + eval_eq(not_(User.id == 1), testcases=[ + (User(id=1), False), + (User(id=2), True), + (User(id=None), None), + ]) + + @testing.resolve_artifact_names + def test_null_propagation(self): + eval_eq((User.name == 'foo') == (User.id == 1), testcases=[ + (User(id=1, name='foo'), True), + (User(id=2, name='foo'), False), + (User(id=1, name='bar'), False), + (User(id=2, name='bar'), True), + (User(id=None, name='foo'), None), + (User(id=None, name=None), None), + ]) + +if __name__ == '__main__': + testenv.main() -- 2.47.3