from .. import util
from ..sql import and_
from ..sql import operators
+from ..sql.sqltypes import Integer
+from ..sql.sqltypes import Numeric
class UnevaluatableError(Exception):
_straight_ops = set(
getattr(operators, op)
for op in (
- "add",
- "mul",
- "sub",
- "div",
- "mod",
- "truediv",
"lt",
"le",
"ne",
)
)
+_math_only_straight_ops = set(
+ getattr(operators, op)
+ for op in (
+ "add",
+ "mul",
+ "sub",
+ "div",
+ "mod",
+ "truediv",
+ )
+)
+
_extended_ops = {
operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
operators.not_in_op: (
"startswith_op",
"between_op",
"endswith_op",
- "concat_op",
)
)
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
+ elif operator is operators.concat_op:
+
+ def evaluate(obj):
+ return eval_left(obj) + eval_right(obj)
+
elif operator in _extended_ops:
def evaluate(obj):
return _extended_ops[operator](left_val, right_val)
+ elif operator in _math_only_straight_ops:
+ if (
+ clause.left.type._type_affinity
+ not in (
+ Numeric,
+ Integer,
+ )
+ or clause.right.type._type_affinity not in (Numeric, Integer)
+ ):
+ raise UnevaluatableError(
+ 'Cannot evaluate math operator "%s" for '
+ "datatypes %s, %s"
+ % (operator.__name__, clause.left.type, clause.right.type)
+ )
+
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+ return operator(eval_left(obj), eval_right(obj))
+
elif operator in _straight_ops:
def evaluate(obj):
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import JSON
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import String
+from sqlalchemy import testing
from sqlalchemy import tuple_
from sqlalchemy.orm import evaluator
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
Column("id", Integer, primary_key=True),
Column("name", String(64)),
Column("othername", String(64)),
+ Column("json", JSON),
)
@classmethod
],
)
+ @testing.combinations(
+ lambda User: User.name + "_foo" == "named_foo",
+ # not implemented in 1.4
+ # lambda User: User.name.startswith("nam"),
+ # lambda User: User.name.endswith("named"),
+ )
+ def test_string_ops(self, expr):
+ User = self.classes.User
+
+ test_expr = testing.resolve_lambda(expr, User=User)
+ eval_eq(
+ test_expr,
+ testcases=[
+ (User(name="named"), True),
+ (User(name="othername"), False),
+ ],
+ )
+
def test_in(self):
User = self.classes.User
],
)
+ @testing.combinations(
+ (lambda User: User.id + 5, "id", 10, 15, None),
+ (
+ lambda User: User.name + " name",
+ "name",
+ "some value",
+ "some value name",
+ None,
+ ),
+ (
+ lambda User: User.id + "name",
+ "id",
+ 10,
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"add\" for "
+ r"datatypes INTEGER, VARCHAR",
+ ),
+ (
+ lambda User: User.json + 12,
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"add\" for "
+ r"datatypes JSON, INTEGER",
+ ),
+ (
+ lambda User: User.json - 12,
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"sub\" for "
+ r"datatypes JSON, INTEGER",
+ ),
+ (
+ lambda User: User.json - "foo",
+ "json",
+ {"foo": "bar"},
+ evaluator.UnevaluatableError,
+ r"Cannot evaluate math operator \"sub\" for "
+ r"datatypes JSON, VARCHAR",
+ ),
+ )
+ def test_math_op_type_exclusions(
+ self, expr, attrname, initial_value, expected, message
+ ):
+ """test #8507"""
+
+ User = self.classes.User
+
+ expr = testing.resolve_lambda(expr, User=User)
+
+ if expected is evaluator.UnevaluatableError:
+ with expect_raises_message(evaluator.UnevaluatableError, message):
+ compiler.process(expr)
+ else:
+ obj = User(**{attrname: initial_value})
+
+ new_value = compiler.process(expr)(obj)
+ eq_(new_value, expected)
+
class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
@classmethod