--- /dev/null
+.. change::
+ :tags: feature, orm
+ :tickets: 3162
+
+ Added new parameter :paramref:`_sql.Operators.op.python_impl`, available
+ from :meth:`_sql.Operators.op` and also when using the
+ :class:`_sql.Operators.custom_op` constructor directly, which allows an
+ in-Python evaluation function to be provided along with the custom SQL
+ operator. This evaluation function becomes the implementation used when the
+ operator object is used given plain Python objects as operands on both
+ sides, and in particular is compatible with the
+ ``synchronize_session='evaluate'`` option used with
+ :ref:`orm_expression_update_delete`.
able to evaluate the expression in Python and will raise an error. If
this occurs, use the ``'fetch'`` strategy for the operation instead.
+ .. tip::
+
+ If a SQL expression makes use of custom operators using the
+ :meth:`_sql.Operators.op` or :class:`_sql.custom_op` feature, the
+ :paramref:`_sql.Operators.op.python_impl` parameter may be used to indicate
+ a Python function that will be used by the ``"evaluate"`` synchronization
+ strategy.
+
+ .. versionadded:: 2.0
+
.. warning::
The ``"evaluate"`` strategy should be avoided if an UPDATE operation is
import operator
+from .. import exc
from .. import inspect
from .. import util
from ..sql import and_
from ..sql import operators
-class UnevaluatableError(Exception):
+class UnevaluatableError(exc.InvalidRequestError):
pass
_NO_OBJECT = _NoObject()
-_straight_ops = set(
- getattr(operators, op)
- for op in (
- "add",
- "mul",
- "sub",
- "mod",
- "truediv",
- "lt",
- "le",
- "ne",
- "gt",
- "ge",
- "eq",
- )
-)
-
-_extended_ops = {
- operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
- operators.not_in_op: (
- lambda a, b: a not in b if a is not _NO_OBJECT else None
- ),
-}
-
-_notimplemented_ops = set(
- getattr(operators, op)
- for op in (
- "like_op",
- "not_like_op",
- "ilike_op",
- "not_ilike_op",
- "startswith_op",
- "between_op",
- "endswith_op",
- "concat_op",
- )
-)
-
class EvaluatorCompiler:
def __init__(self, target_cls=None):
self.target_cls = target_cls
- def process(self, *clauses):
- if len(clauses) > 1:
- clause = and_(*clauses)
- elif clauses:
- clause = clauses[0]
+ def process(self, clause, *clauses):
+ if clauses:
+ clause = and_(clause, *clauses)
- meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
+ meth = getattr(self, f"visit_{clause.__visit_name__}", None)
if not meth:
raise UnevaluatableError(
- "Cannot evaluate %s" % type(clause).__name__
+ f"Cannot evaluate {type(clause).__name__}"
)
return meth(clause)
self.target_cls, parentmapper.class_
):
raise UnevaluatableError(
- "Can't evaluate criteria against alternate class %s"
- % parentmapper.class_
+ "Can't evaluate criteria against "
+ f"alternate class {parentmapper.class_}"
)
key = parentmapper._columntoproperty[clause].key
else:
and key in inspect(self.target_cls).column_attrs
):
util.warn(
- "Evaluating non-mapped column expression '%s' onto "
+ f"Evaluating non-mapped column expression '{clause}' onto "
"ORM instances; this is a deprecated use case. Please "
"make use of the actual mapped columns in ORM-evaluated "
- "UPDATE / DELETE expressions." % clause
+ "UPDATE / DELETE expressions."
)
else:
- raise UnevaluatableError("Cannot evaluate column: %s" % clause)
+ raise UnevaluatableError(f"Cannot evaluate column: {clause}")
get_corresponding_attr = operator.attrgetter(key)
return (
return self.visit_clauselist(clause)
def visit_clauselist(self, clause):
- evaluators = list(map(self.process, clause.clauses))
- if clause.operator is operators.or_:
+ evaluators = [self.process(clause) for clause in clause.clauses]
- def evaluate(obj):
- has_null = False
- for sub_evaluate in evaluators:
- value = sub_evaluate(obj)
- if value:
- return True
- has_null = has_null or value is None
- if has_null:
- return None
- return False
+ dispatch = (
+ f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
+ )
+ meth = getattr(self, dispatch, None)
+ if meth:
+ return meth(clause.operator, evaluators, clause)
+ else:
+ raise UnevaluatableError(
+ f"Cannot evaluate clauselist with operator {clause.operator}"
+ )
- elif clause.operator is operators.and_:
+ def visit_binary(self, clause):
+ eval_left = self.process(clause.left)
+ eval_right = self.process(clause.right)
- def evaluate(obj):
- for sub_evaluate in evaluators:
- value = sub_evaluate(obj)
- if not value:
- if value is None or value is _NO_OBJECT:
- return None
- return False
- return True
+ dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
+ meth = getattr(self, dispatch, None)
+ if meth:
+ return meth(clause.operator, eval_left, eval_right)
+ else:
+ raise UnevaluatableError(
+ f"Cannot evaluate {type(clause).__name__} with "
+ f"operator {clause.operator}"
+ )
- elif clause.operator is operators.comma_op:
+ def visit_or_clauselist_op(self, operator, evaluators, clause):
+ def evaluate(obj):
+ has_null = False
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value:
+ return True
+ has_null = has_null or value is None
+ if has_null:
+ return None
+ return False
- def evaluate(obj):
- values = []
- for sub_evaluate in evaluators:
- value = sub_evaluate(obj)
+ return evaluate
+
+ def visit_and_clauselist_op(self, operator, evaluators, clause):
+ def evaluate(obj):
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if not value:
if value is None or value is _NO_OBJECT:
return None
- values.append(value)
- return tuple(values)
+ return False
+ return True
+
+ return evaluate
+
+ def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
+ def evaluate(obj):
+ values = []
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value is None or value is _NO_OBJECT:
+ return None
+ values.append(value)
+ return tuple(values)
+
+ return evaluate
+ def visit_custom_op_binary_op(self, operator, eval_left, eval_right):
+ if operator.python_impl:
+ return self._straight_evaluate(operator, eval_left, eval_right)
else:
raise UnevaluatableError(
- "Cannot evaluate clauselist with operator %s" % clause.operator
+ f"Custom operator {operator.opstring!r} can't be evaluated "
+ "in Python unless it specifies a callable using "
+ "`.python_impl`."
)
+ def visit_is_binary_op(self, operator, eval_left, eval_right):
+ def evaluate(obj):
+ return eval_left(obj) == eval_right(obj)
+
return evaluate
- def visit_binary(self, clause):
- eval_left, eval_right = list(
- map(self.process, [clause.left, clause.right])
- )
- operator = clause.operator
- if operator is operators.is_:
+ def visit_is_not_binary_op(self, operator, eval_left, eval_right):
+ def evaluate(obj):
+ return eval_left(obj) != eval_right(obj)
- def evaluate(obj):
- return eval_left(obj) == eval_right(obj)
+ return evaluate
- elif operator is operators.is_not:
+ def _straight_evaluate(self, operator, eval_left, eval_right):
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+ return operator(eval_left(obj), eval_right(obj))
- def evaluate(obj):
- return eval_left(obj) != eval_right(obj)
-
- elif operator in _extended_ops:
+ return evaluate
- def evaluate(obj):
- left_val = eval_left(obj)
- right_val = eval_right(obj)
- if left_val is None or right_val is None:
- return None
+ visit_add_binary_op = _straight_evaluate
+ visit_mul_binary_op = _straight_evaluate
+ visit_sub_binary_op = _straight_evaluate
+ visit_mod_binary_op = _straight_evaluate
+ visit_truediv_binary_op = _straight_evaluate
+ visit_lt_binary_op = _straight_evaluate
+ visit_le_binary_op = _straight_evaluate
+ visit_ne_binary_op = _straight_evaluate
+ visit_gt_binary_op = _straight_evaluate
+ visit_ge_binary_op = _straight_evaluate
+ visit_eq_binary_op = _straight_evaluate
+
+ def visit_in_op_binary_op(self, operator, eval_left, eval_right):
+ return self._straight_evaluate(
+ lambda a, b: a in b if a is not _NO_OBJECT else None,
+ eval_left,
+ eval_right,
+ )
- return _extended_ops[operator](left_val, right_val)
+ def visit_not_in_op_binary_op(self, operator, eval_left, eval_right):
+ return self._straight_evaluate(
+ lambda a, b: a not in b if a is not _NO_OBJECT else None,
+ eval_left,
+ eval_right,
+ )
- elif operator in _straight_ops:
+ def visit_concat_op_binary_op(self, operator, eval_left, eval_right):
+ return self._straight_evaluate(
+ lambda a, b: a + b, eval_left, eval_right
+ )
- def evaluate(obj):
- left_val = eval_left(obj)
- right_val = eval_right(obj)
- if left_val is None or right_val is None:
- return None
- return operator(eval_left(obj), eval_right(obj))
+ def visit_startswith_op_binary_op(self, operator, eval_left, eval_right):
+ return self._straight_evaluate(
+ lambda a, b: a.startswith(b), eval_left, eval_right
+ )
- else:
- raise UnevaluatableError(
- "Cannot evaluate %s with operator %s"
- % (type(clause).__name__, clause.operator)
- )
- return evaluate
+ def visit_endswith_op_binary_op(self, operator, eval_left, eval_right):
+ return self._straight_evaluate(
+ lambda a, b: a.endswith(b), eval_left, eval_right
+ )
def visit_unary(self, clause):
eval_inner = self.process(clause.element)
return evaluate
raise UnevaluatableError(
- "Cannot evaluate %s with operator %s"
- % (type(clause).__name__, clause.operator)
+ f"Cannot evaluate {type(clause).__name__} "
+ f"with operator {clause.operator}"
)
def visit_bindparam(self, clause):
from operator import sub
from operator import truediv
+from .. import exc
from .. import util
return self.operate(inv)
def op(
- self, opstring, precedence=0, is_comparison=False, return_type=None
+ self,
+ opstring,
+ precedence=0,
+ is_comparison=False,
+ return_type=None,
+ python_impl=None,
):
"""Produce a generic operator function.
:class:`.Boolean`, and those that do not will be of the same
type as the left-hand operand.
+ :param python_impl: an optional Python function that can evaluate
+ two Python values in the same way as this operator works when
+ run on the database server. Useful for in-Python SQL expression
+ evaluation functions, such as for ORM hybrid attributes, and the
+ ORM "evaluator" used to match objects in a session after a multi-row
+ update or delete.
+
+ e.g.::
+
+ >>> expr = column('x').op('+', python_impl=lambda a, b: a + b)('y')
+
+ The operator for the above expression will also work for non-SQL
+ left and right objects::
+
+ >>> expr.operator(5, 10)
+ 15
+
+ .. versionadded:: 2.0
+
+
.. seealso::
:ref:`types_operators`
:ref:`relationship_custom_operator`
"""
- operator = custom_op(opstring, precedence, is_comparison, return_type)
+ operator = custom_op(
+ opstring,
+ precedence,
+ is_comparison,
+ return_type,
+ python_impl=python_impl,
+ )
def against(other):
return operator(self, other)
return against
- def bool_op(self, opstring, precedence=0):
+ def bool_op(self, opstring, precedence=0, python_impl=None):
"""Return a custom boolean operator.
This method is shorthand for calling
:meth:`.Operators.op`
"""
- return self.op(opstring, precedence=precedence, is_comparison=True)
+ return self.op(
+ opstring,
+ precedence=precedence,
+ is_comparison=True,
+ python_impl=python_impl,
+ )
def operate(self, op, *other, **kwargs):
r"""Operate on an argument.
"""
raise NotImplementedError(str(op))
+ __sa_operate__ = operate
+
def reverse_operate(self, op, other, **kwargs):
"""Reverse operate on an argument.
__name__ = "custom_op"
+ __slots__ = (
+ "opstring",
+ "precedence",
+ "is_comparison",
+ "natural_self_precedent",
+ "eager_grouping",
+ "return_type",
+ "python_impl",
+ )
+
def __init__(
self,
opstring,
return_type=None,
natural_self_precedent=False,
eager_grouping=False,
+ python_impl=None,
):
self.opstring = opstring
self.precedence = precedence
self.return_type = (
return_type._to_instance(return_type) if return_type else None
)
+ self.python_impl = python_impl
def __eq__(self, other):
return isinstance(other, custom_op) and other.opstring == self.opstring
return id(self)
def __call__(self, left, right, **kw):
- return left.operate(self, right, **kw)
+ if hasattr(left, "__sa_operate__"):
+ return left.operate(self, right, **kw)
+ elif self.python_impl:
+ return self.python_impl(left, right, **kw)
+ else:
+ raise exc.InvalidRequestError(
+ f"Custom operator {self.opstring!r} can't be used with "
+ "plain Python objects unless it includes the "
+ "'python_impl' parameter."
+ )
class ColumnOperators(Operators):
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm import synonym
+from sqlalchemy.sql import operators
from sqlalchemy.sql import update
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
["same_name", "id", "name"],
)
+ def test_custom_op(self, registry):
+ """test #3162"""
+
+ my_op = operators.custom_op(
+ "my_op", python_impl=lambda a, b: a + "_foo_" + b
+ )
+
+ @registry.mapped
+ class SomeClass:
+ __tablename__ = "sc"
+ id = Column(Integer, primary_key=True)
+ data = Column(String)
+
+ @hybrid.hybrid_property
+ def foo_data(self):
+ return my_op(self.data, "bar")
+
+ eq_(SomeClass(data="data").foo_data, "data_foo_bar")
+
+ self.assert_compile(SomeClass.foo_data, "sc.data my_op :data_1")
+
class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import String
+from sqlalchemy import testing
from sqlalchemy import tuple_
+from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import evaluator
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
],
)
+ @testing.combinations(
+ lambda User: User.name + "_foo" == "named_foo",
+ lambda User: User.name.startswith("nam"),
+ lambda User: User.name.endswith("named"),
+ )
+ def test_string_ops(self, expr):
+ User = self.classes.User
+
+ test_expr = testing.resolve_lambda(expr, User=User)
+ eval_eq(
+ test_expr,
+ testcases=[
+ (User(name="named"), True),
+ (User(name="othername"), False),
+ ],
+ )
+
def test_in(self):
User = self.classes.User
],
)
+ def test_mulitple_expressions(self):
+ User = self.classes.User
+
+ evaluator = compiler.process(User.id > 5, User.name == "ed")
+
+ is_(evaluator(User(id=7, name="ed")), True)
+ is_(evaluator(User(id=7, name="noted")), False)
+ is_(evaluator(User(id=4, name="ed")), False)
+
def test_in_tuples(self):
User = self.classes.User
],
)
+ def test_hybrids(self, registry):
+ @registry.mapped
+ class SomeClass:
+ __tablename__ = "sc"
+ id = Column(Integer, primary_key=True)
+ data = Column(String)
+
+ @hybrid_property
+ def foo_data(self):
+ return self.data + "_foo"
+
+ eval_eq(
+ SomeClass.foo_data == "somedata_foo",
+ testcases=[
+ (SomeClass(data="somedata"), True),
+ (SomeClass(data="otherdata"), False),
+ (SomeClass(data=None), None),
+ ],
+ )
+
+ def test_custom_op_no_impl(self):
+ """test #3162"""
+
+ User = self.classes.User
+
+ with expect_raises_message(
+ evaluator.UnevaluatableError,
+ r"Custom operator '\^\^' can't be evaluated in "
+ "Python unless it specifies",
+ ):
+ compiler.process(User.name.op("^^")("bar"))
+
+ def test_custom_op(self):
+ """test #3162"""
+
+ User = self.classes.User
+
+ eval_eq(
+ User.name.op("^^", python_impl=lambda a, b: a + "_foo_" + b)("bar")
+ == "name_foo_bar",
+ testcases=[
+ (User(name="name"), True),
+ (User(name="notname"), False),
+ ],
+ )
+
class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
@classmethod
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import combinations
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
)
is_(expr.type, some_return_type)
+ def test_python_impl(self):
+ """test #3162"""
+ c = column("x")
+ c2 = column("y")
+ op1 = c.op("$", python_impl=lambda a, b: a > b)(c2).operator
+
+ is_(op1(3, 5), False)
+ is_(op1(5, 3), True)
+
+ def test_python_impl_not_present(self):
+ """test #3162"""
+ c = column("x")
+ c2 = column("y")
+ op1 = c.op("$")(c2).operator
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"Custom operator '\$' can't be used with plain Python objects "
+ "unless it includes the 'python_impl' parameter.",
+ ):
+ op1(3, 5)
+
class TupleTypingTest(fixtures.TestBase):
def _assert_types(self, expr):