from .. import util
from ..sql import and_
from ..sql import operators
+from ..sql.sqltypes import Integer
+from ..sql.sqltypes import Numeric
class UnevaluatableError(exc.InvalidRequestError):
dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
meth = getattr(self, dispatch, None)
if meth:
- return meth(clause.operator, eval_left, eval_right)
+ return meth(clause.operator, eval_left, eval_right, clause)
else:
raise UnevaluatableError(
f"Cannot evaluate {type(clause).__name__} with "
return evaluate
- def visit_custom_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_custom_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
if operator.python_impl:
- return self._straight_evaluate(operator, eval_left, eval_right)
+ return self._straight_evaluate(
+ operator, eval_left, eval_right, clause
+ )
else:
raise UnevaluatableError(
f"Custom operator {operator.opstring!r} can't be evaluated "
"`.python_impl`."
)
- def visit_is_binary_op(self, operator, eval_left, eval_right):
+ def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
return eval_left(obj) == eval_right(obj)
return evaluate
- def visit_is_not_binary_op(self, operator, eval_left, eval_right):
+ def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
return evaluate
- def _straight_evaluate(self, operator, eval_left, eval_right):
+ def _straight_evaluate(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
return evaluate
- visit_add_binary_op = _straight_evaluate
- visit_mul_binary_op = _straight_evaluate
- visit_sub_binary_op = _straight_evaluate
- visit_mod_binary_op = _straight_evaluate
- visit_truediv_binary_op = _straight_evaluate
+ def _straight_evaluate_numeric_only(
+ self, operator, eval_left, eval_right, clause
+ ):
+ if clause.left.type._type_affinity not in (
+ Numeric,
+ Integer,
+ ) or clause.right.type._type_affinity not in (Numeric, Integer):
+ raise UnevaluatableError(
+ f'Cannot evaluate math operator "{operator.__name__}" for '
+ f"datatypes {clause.left.type}, {clause.right.type}"
+ )
+
+ return self._straight_evaluate(operator, eval_left, eval_right, clause)
+
+ visit_add_binary_op = _straight_evaluate_numeric_only
+ visit_mul_binary_op = _straight_evaluate_numeric_only
+ visit_sub_binary_op = _straight_evaluate_numeric_only
+ visit_mod_binary_op = _straight_evaluate_numeric_only
+ visit_truediv_binary_op = _straight_evaluate_numeric_only
visit_lt_binary_op = _straight_evaluate
visit_le_binary_op = _straight_evaluate
visit_ne_binary_op = _straight_evaluate
visit_ge_binary_op = _straight_evaluate
visit_eq_binary_op = _straight_evaluate
- def visit_in_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
return self._straight_evaluate(
lambda a, b: a in b if a is not _NO_OBJECT else None,
eval_left,
eval_right,
+ clause,
)
- def visit_not_in_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_not_in_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
lambda a, b: a not in b if a is not _NO_OBJECT else None,
eval_left,
eval_right,
+ clause,
)
- def visit_concat_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_concat_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
- lambda a, b: a + b, eval_left, eval_right
+ lambda a, b: a + b, eval_left, eval_right, clause
)
- def visit_startswith_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_startswith_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
- lambda a, b: a.startswith(b), eval_left, eval_right
+ lambda a, b: a.startswith(b), eval_left, eval_right, clause
)
- def visit_endswith_op_binary_op(self, operator, eval_left, eval_right):
+ def visit_endswith_op_binary_op(
+ self, operator, eval_left, eval_right, clause
+ ):
return self._straight_evaluate(
- lambda a, b: a.endswith(b), eval_left, eval_right
+ lambda a, b: a.endswith(b), eval_left, eval_right, clause
)
def visit_unary(self, clause):
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import JSON
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import String
from sqlalchemy.orm import relationship
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
Column("id", Integer, primary_key=True),
Column("name", String(64)),
Column("othername", String(64)),
+ Column("json", JSON),
)
@classmethod
],
)
+ @testing.combinations(
+ (lambda User: User.id + 5, "id", 10, 15, None),
+ (
+ # note this one uses concat_op, not operator.add
+ lambda User: User.name + " name",
+ "name",
+ "some value",
+ "some value name",
+ None,
+ ),
+ (
+ lambda User: User.id + "name",
+ "id",
+ 10,
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"add\" for "
+ r"datatypes INTEGER, VARCHAR",
+ ),
+ (
+ lambda User: User.json + 12,
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"add\" for "
+ r"datatypes JSON, INTEGER",
+ ),
+ (
+ lambda User: User.json - 12,
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"sub\" for "
+ r"datatypes JSON, INTEGER",
+ ),
+ (
+ lambda User: User.json - "foo",
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"sub\" for "
+ r"datatypes JSON, VARCHAR",
+ ),
+ )
+ def test_math_op_type_exclusions(
+ self, expr, attrname, initial_value, expected, message
+ ):
+ """test #8507"""
+
+ User = self.classes.User
+
+ expr = testing.resolve_lambda(expr, User=User)
+
+ if expected is evaluator.UnevaluatableError:
+ with expect_raises_message(evaluator.UnevaluatableError, message):
+ compiler.process(expr)
+ else:
+ obj = User(**{attrname: initial_value})
+
+ new_value = compiler.process(expr)(obj)
+ eq_(new_value, expected)
+
class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
@classmethod