git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement python_impl to custom_op for basic ORM evaluator extensibility
author: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jan 2022 19:04:15 +0000 (14:04 -0500)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jan 2022 21:40:35 +0000 (16:40 -0500)
Added new parameter :paramref:`_sql.Operators.op.python_impl`, available
from :meth:`_sql.Operators.op` and also when using the
:class:`_sql.custom_op` constructor directly, which allows an
in-Python evaluation function to be provided along with the custom SQL
operator. This evaluation function becomes the implementation used when the
operator object is invoked with plain Python objects as operands on both
sides, and in particular is compatible with the
``synchronize_session='evaluate'`` option used with
:ref:`orm_expression_update_delete`.

Fixes: #3162
Change-Id: If46ba6a0e303e2180a177ba418a8cafe9b42608e

doc/build/changelog/unreleased_20/3162.rst [new file with mode: 0644]
doc/build/orm/session_basics.rst
lib/sqlalchemy/orm/evaluator.py
lib/sqlalchemy/sql/operators.py
test/ext/test_hybrid.py
test/orm/test_evaluator.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_20/3162.rst b/doc/build/changelog/unreleased_20/3162.rst
new file mode 100644 (file)
index 0000000..2cc6e30
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: feature, orm
+    :tickets: 3162
+
+    Added new parameter :paramref:`_sql.Operators.op.python_impl`, available
+    from :meth:`_sql.Operators.op` and also when using the
+    :class:`_sql.custom_op` constructor directly, which allows an
+    in-Python evaluation function to be provided along with the custom SQL
+    operator. This evaluation function becomes the implementation used when the
+    operator object is invoked with plain Python objects as operands on both
+    sides, and in particular is compatible with the
+    ``synchronize_session='evaluate'`` option used with
+    :ref:`orm_expression_update_delete`.
index 6f818a439b77458b76f4cfc91900b93a2fbca0f0..9734b6eccb3afc1202596f28a58d5fcc8f3813ca 100644 (file)
@@ -615,6 +615,16 @@ values for ``synchronize_session`` are supported:
   able to evaluate the expression in Python and will raise an error.  If
   this occurs, use the ``'fetch'`` strategy for the operation instead.
 
+  .. tip::
+
+    If a SQL expression makes use of custom operators using the
+    :meth:`_sql.Operators.op` or :class:`_sql.custom_op` feature, the
+    :paramref:`_sql.Operators.op.python_impl` parameter may be used to indicate
+    a Python function that will be used by the ``"evaluate"`` synchronization
+    strategy.
+
+    .. versionadded:: 2.0
+
   .. warning::
 
     The ``"evaluate"`` strategy should be avoided if an UPDATE operation is
index 19e0be9d07a0b691174bd3601de53993d9822940..d8d88b805c4c8fb815c65b2e4e0b595b8dec8b6e 100644 (file)
@@ -7,13 +7,14 @@
 
 import operator
 
+from .. import exc
 from .. import inspect
 from .. import util
 from ..sql import and_
 from ..sql import operators
 
 
-class UnevaluatableError(Exception):
+class UnevaluatableError(exc.InvalidRequestError):
     pass
 
 
@@ -27,59 +28,19 @@ class _NoObject(operators.ColumnOperators):
 
 _NO_OBJECT = _NoObject()
 
-_straight_ops = set(
-    getattr(operators, op)
-    for op in (
-        "add",
-        "mul",
-        "sub",
-        "mod",
-        "truediv",
-        "lt",
-        "le",
-        "ne",
-        "gt",
-        "ge",
-        "eq",
-    )
-)
-
-_extended_ops = {
-    operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
-    operators.not_in_op: (
-        lambda a, b: a not in b if a is not _NO_OBJECT else None
-    ),
-}
-
-_notimplemented_ops = set(
-    getattr(operators, op)
-    for op in (
-        "like_op",
-        "not_like_op",
-        "ilike_op",
-        "not_ilike_op",
-        "startswith_op",
-        "between_op",
-        "endswith_op",
-        "concat_op",
-    )
-)
-
 
 class EvaluatorCompiler:
     def __init__(self, target_cls=None):
         self.target_cls = target_cls
 
-    def process(self, *clauses):
-        if len(clauses) > 1:
-            clause = and_(*clauses)
-        elif clauses:
-            clause = clauses[0]
+    def process(self, clause, *clauses):
+        if clauses:
+            clause = and_(clause, *clauses)
 
-        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
+        meth = getattr(self, f"visit_{clause.__visit_name__}", None)
         if not meth:
             raise UnevaluatableError(
-                "Cannot evaluate %s" % type(clause).__name__
+                f"Cannot evaluate {type(clause).__name__}"
             )
         return meth(clause)
 
@@ -102,8 +63,8 @@ class EvaluatorCompiler:
                 self.target_cls, parentmapper.class_
             ):
                 raise UnevaluatableError(
-                    "Can't evaluate criteria against alternate class %s"
-                    % parentmapper.class_
+                    "Can't evaluate criteria against "
+                    f"alternate class {parentmapper.class_}"
                 )
             key = parentmapper._columntoproperty[clause].key
         else:
@@ -113,13 +74,13 @@ class EvaluatorCompiler:
                 and key in inspect(self.target_cls).column_attrs
             ):
                 util.warn(
-                    "Evaluating non-mapped column expression '%s' onto "
+                    f"Evaluating non-mapped column expression '{clause}' onto "
                     "ORM instances; this is a deprecated use case.  Please "
                     "make use of the actual mapped columns in ORM-evaluated "
-                    "UPDATE / DELETE expressions." % clause
+                    "UPDATE / DELETE expressions."
                 )
             else:
-                raise UnevaluatableError("Cannot evaluate column: %s" % clause)
+                raise UnevaluatableError(f"Cannot evaluate column: {clause}")
 
         get_corresponding_attr = operator.attrgetter(key)
         return (
@@ -132,89 +93,143 @@ class EvaluatorCompiler:
         return self.visit_clauselist(clause)
 
     def visit_clauselist(self, clause):
-        evaluators = list(map(self.process, clause.clauses))
-        if clause.operator is operators.or_:
+        evaluators = [self.process(clause) for clause in clause.clauses]
 
-            def evaluate(obj):
-                has_null = False
-                for sub_evaluate in evaluators:
-                    value = sub_evaluate(obj)
-                    if value:
-                        return True
-                    has_null = has_null or value is None
-                if has_null:
-                    return None
-                return False
+        dispatch = (
+            f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op"
+        )
+        meth = getattr(self, dispatch, None)
+        if meth:
+            return meth(clause.operator, evaluators, clause)
+        else:
+            raise UnevaluatableError(
+                f"Cannot evaluate clauselist with operator {clause.operator}"
+            )
 
-        elif clause.operator is operators.and_:
+    def visit_binary(self, clause):
+        eval_left = self.process(clause.left)
+        eval_right = self.process(clause.right)
 
-            def evaluate(obj):
-                for sub_evaluate in evaluators:
-                    value = sub_evaluate(obj)
-                    if not value:
-                        if value is None or value is _NO_OBJECT:
-                            return None
-                        return False
-                return True
+        dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
+        meth = getattr(self, dispatch, None)
+        if meth:
+            return meth(clause.operator, eval_left, eval_right)
+        else:
+            raise UnevaluatableError(
+                f"Cannot evaluate {type(clause).__name__} with "
+                f"operator {clause.operator}"
+            )
 
-        elif clause.operator is operators.comma_op:
+    def visit_or_clauselist_op(self, operator, evaluators, clause):
+        def evaluate(obj):
+            has_null = False
+            for sub_evaluate in evaluators:
+                value = sub_evaluate(obj)
+                if value:
+                    return True
+                has_null = has_null or value is None
+            if has_null:
+                return None
+            return False
 
-            def evaluate(obj):
-                values = []
-                for sub_evaluate in evaluators:
-                    value = sub_evaluate(obj)
+        return evaluate
+
+    def visit_and_clauselist_op(self, operator, evaluators, clause):
+        def evaluate(obj):
+            for sub_evaluate in evaluators:
+                value = sub_evaluate(obj)
+                if not value:
                     if value is None or value is _NO_OBJECT:
                         return None
-                    values.append(value)
-                return tuple(values)
+                    return False
+            return True
+
+        return evaluate
+
+    def visit_comma_op_clauselist_op(self, operator, evaluators, clause):
+        def evaluate(obj):
+            values = []
+            for sub_evaluate in evaluators:
+                value = sub_evaluate(obj)
+                if value is None or value is _NO_OBJECT:
+                    return None
+                values.append(value)
+            return tuple(values)
+
+        return evaluate
 
+    def visit_custom_op_binary_op(self, operator, eval_left, eval_right):
+        if operator.python_impl:
+            return self._straight_evaluate(operator, eval_left, eval_right)
         else:
             raise UnevaluatableError(
-                "Cannot evaluate clauselist with operator %s" % clause.operator
+                f"Custom operator {operator.opstring!r} can't be evaluated "
+                "in Python unless it specifies a callable using "
+                "`.python_impl`."
             )
 
+    def visit_is_binary_op(self, operator, eval_left, eval_right):
+        def evaluate(obj):
+            return eval_left(obj) == eval_right(obj)
+
         return evaluate
 
-    def visit_binary(self, clause):
-        eval_left, eval_right = list(
-            map(self.process, [clause.left, clause.right])
-        )
-        operator = clause.operator
-        if operator is operators.is_:
+    def visit_is_not_binary_op(self, operator, eval_left, eval_right):
+        def evaluate(obj):
+            return eval_left(obj) != eval_right(obj)
 
-            def evaluate(obj):
-                return eval_left(obj) == eval_right(obj)
+        return evaluate
 
-        elif operator is operators.is_not:
+    def _straight_evaluate(self, operator, eval_left, eval_right):
+        def evaluate(obj):
+            left_val = eval_left(obj)
+            right_val = eval_right(obj)
+            if left_val is None or right_val is None:
+                return None
+            return operator(eval_left(obj), eval_right(obj))
 
-            def evaluate(obj):
-                return eval_left(obj) != eval_right(obj)
-
-        elif operator in _extended_ops:
+        return evaluate
 
-            def evaluate(obj):
-                left_val = eval_left(obj)
-                right_val = eval_right(obj)
-                if left_val is None or right_val is None:
-                    return None
+    visit_add_binary_op = _straight_evaluate
+    visit_mul_binary_op = _straight_evaluate
+    visit_sub_binary_op = _straight_evaluate
+    visit_mod_binary_op = _straight_evaluate
+    visit_truediv_binary_op = _straight_evaluate
+    visit_lt_binary_op = _straight_evaluate
+    visit_le_binary_op = _straight_evaluate
+    visit_ne_binary_op = _straight_evaluate
+    visit_gt_binary_op = _straight_evaluate
+    visit_ge_binary_op = _straight_evaluate
+    visit_eq_binary_op = _straight_evaluate
+
+    def visit_in_op_binary_op(self, operator, eval_left, eval_right):
+        return self._straight_evaluate(
+            lambda a, b: a in b if a is not _NO_OBJECT else None,
+            eval_left,
+            eval_right,
+        )
 
-                return _extended_ops[operator](left_val, right_val)
+    def visit_not_in_op_binary_op(self, operator, eval_left, eval_right):
+        return self._straight_evaluate(
+            lambda a, b: a not in b if a is not _NO_OBJECT else None,
+            eval_left,
+            eval_right,
+        )
 
-        elif operator in _straight_ops:
+    def visit_concat_op_binary_op(self, operator, eval_left, eval_right):
+        return self._straight_evaluate(
+            lambda a, b: a + b, eval_left, eval_right
+        )
 
-            def evaluate(obj):
-                left_val = eval_left(obj)
-                right_val = eval_right(obj)
-                if left_val is None or right_val is None:
-                    return None
-                return operator(eval_left(obj), eval_right(obj))
+    def visit_startswith_op_binary_op(self, operator, eval_left, eval_right):
+        return self._straight_evaluate(
+            lambda a, b: a.startswith(b), eval_left, eval_right
+        )
 
-        else:
-            raise UnevaluatableError(
-                "Cannot evaluate %s with operator %s"
-                % (type(clause).__name__, clause.operator)
-            )
-        return evaluate
+    def visit_endswith_op_binary_op(self, operator, eval_left, eval_right):
+        return self._straight_evaluate(
+            lambda a, b: a.endswith(b), eval_left, eval_right
+        )
 
     def visit_unary(self, clause):
         eval_inner = self.process(clause.element)
@@ -228,8 +243,8 @@ class EvaluatorCompiler:
 
             return evaluate
         raise UnevaluatableError(
-            "Cannot evaluate %s with operator %s"
-            % (type(clause).__name__, clause.operator)
+            f"Cannot evaluate {type(clause).__name__} "
+            f"with operator {clause.operator}"
         )
 
     def visit_bindparam(self, clause):
index 74eb73e4606ac537d54279b4894851911860e6c5..8006d61453a91fedeefb7584de5492b0cb389176 100644 (file)
@@ -31,6 +31,7 @@ from operator import rshift
 from operator import sub
 from operator import truediv
 
+from .. import exc
 from .. import util
 
 
@@ -117,7 +118,12 @@ class Operators:
         return self.operate(inv)
 
     def op(
-        self, opstring, precedence=0, is_comparison=False, return_type=None
+        self,
+        opstring,
+        precedence=0,
+        is_comparison=False,
+        return_type=None,
+        python_impl=None,
     ):
         """Produce a generic operator function.
 
@@ -164,6 +170,26 @@ class Operators:
           :class:`.Boolean`, and those that do not will be of the same
           type as the left-hand operand.
 
+        :param python_impl: an optional Python function that can evaluate
+         two Python values in the same way as this operator works when
+         run on the database server.  Useful for in-Python SQL expression
+         evaluation functions, such as for ORM hybrid attributes, and the
+         ORM "evaluator" used to match objects in a session after a multi-row
+         update or delete.
+
+         e.g.::
+
+            >>> expr = column('x').op('+', python_impl=lambda a, b: a + b)('y')
+
+         The operator for the above expression will also work for non-SQL
+         left and right objects::
+
+            >>> expr.operator(5, 10)
+            15
+
+         .. versionadded:: 2.0
+
+
         .. seealso::
 
             :ref:`types_operators`
@@ -171,14 +197,20 @@ class Operators:
             :ref:`relationship_custom_operator`
 
         """
-        operator = custom_op(opstring, precedence, is_comparison, return_type)
+        operator = custom_op(
+            opstring,
+            precedence,
+            is_comparison,
+            return_type,
+            python_impl=python_impl,
+        )
 
         def against(other):
             return operator(self, other)
 
         return against
 
-    def bool_op(self, opstring, precedence=0):
+    def bool_op(self, opstring, precedence=0, python_impl=None):
         """Return a custom boolean operator.
 
         This method is shorthand for calling
@@ -191,7 +223,12 @@ class Operators:
             :meth:`.Operators.op`
 
         """
-        return self.op(opstring, precedence=precedence, is_comparison=True)
+        return self.op(
+            opstring,
+            precedence=precedence,
+            is_comparison=True,
+            python_impl=python_impl,
+        )
 
     def operate(self, op, *other, **kwargs):
         r"""Operate on an argument.
@@ -219,6 +256,8 @@ class Operators:
         """
         raise NotImplementedError(str(op))
 
+    __sa_operate__ = operate
+
     def reverse_operate(self, op, other, **kwargs):
         """Reverse operate on an argument.
 
@@ -256,6 +295,16 @@ class custom_op:
 
     __name__ = "custom_op"
 
+    __slots__ = (
+        "opstring",
+        "precedence",
+        "is_comparison",
+        "natural_self_precedent",
+        "eager_grouping",
+        "return_type",
+        "python_impl",
+    )
+
     def __init__(
         self,
         opstring,
@@ -264,6 +313,7 @@ class custom_op:
         return_type=None,
         natural_self_precedent=False,
         eager_grouping=False,
+        python_impl=None,
     ):
         self.opstring = opstring
         self.precedence = precedence
@@ -273,6 +323,7 @@ class custom_op:
         self.return_type = (
             return_type._to_instance(return_type) if return_type else None
         )
+        self.python_impl = python_impl
 
     def __eq__(self, other):
         return isinstance(other, custom_op) and other.opstring == self.opstring
@@ -281,7 +332,16 @@ class custom_op:
         return id(self)
 
     def __call__(self, left, right, **kw):
-        return left.operate(self, right, **kw)
+        if hasattr(left, "__sa_operate__"):
+            return left.operate(self, right, **kw)
+        elif self.python_impl:
+            return self.python_impl(left, right, **kw)
+        else:
+            raise exc.InvalidRequestError(
+                f"Custom operator {self.opstring!r} can't be used with "
+                "plain Python objects unless it includes the "
+                "'python_impl' parameter."
+            )
 
 
 class ColumnOperators(Operators):
index f3185909a9a6d31d26b3b41ef35c51224d8aaa24..de5f89b25274068fecf3bd0601d0a79bf54f01fe 100644 (file)
@@ -17,6 +17,7 @@ from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import synonym
+from sqlalchemy.sql import operators
 from sqlalchemy.sql import update
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -162,6 +163,27 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
             ["same_name", "id", "name"],
         )
 
+    def test_custom_op(self, registry):
+        """test #3162"""
+
+        my_op = operators.custom_op(
+            "my_op", python_impl=lambda a, b: a + "_foo_" + b
+        )
+
+        @registry.mapped
+        class SomeClass:
+            __tablename__ = "sc"
+            id = Column(Integer, primary_key=True)
+            data = Column(String)
+
+            @hybrid.hybrid_property
+            def foo_data(self):
+                return my_op(self.data, "bar")
+
+        eq_(SomeClass(data="data").foo_data, "data_foo_bar")
+
+        self.assert_compile(SomeClass.foo_data, "sc.data my_op :data_1")
+
 
 class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
index 62acca582701ac0f595988a195a5ffaad970de07..33692505cbd8f978a719d51c021ce06feb69e3c2 100644 (file)
@@ -8,7 +8,9 @@ from sqlalchemy import Integer
 from sqlalchemy import not_
 from sqlalchemy import or_
 from sqlalchemy import String
+from sqlalchemy import testing
 from sqlalchemy import tuple_
+from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import evaluator
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import relationship
@@ -17,6 +19,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -200,6 +203,23 @@ class EvaluateTest(fixtures.MappedTest):
             ],
         )
 
+    @testing.combinations(
+        lambda User: User.name + "_foo" == "named_foo",
+        lambda User: User.name.startswith("nam"),
+        lambda User: User.name.endswith("named"),
+    )
+    def test_string_ops(self, expr):
+        User = self.classes.User
+
+        test_expr = testing.resolve_lambda(expr, User=User)
+        eval_eq(
+            test_expr,
+            testcases=[
+                (User(name="named"), True),
+                (User(name="othername"), False),
+            ],
+        )
+
     def test_in(self):
         User = self.classes.User
 
@@ -225,6 +245,15 @@ class EvaluateTest(fixtures.MappedTest):
             ],
         )
 
+    def test_mulitple_expressions(self):
+        User = self.classes.User
+
+        evaluator = compiler.process(User.id > 5, User.name == "ed")
+
+        is_(evaluator(User(id=7, name="ed")), True)
+        is_(evaluator(User(id=7, name="noted")), False)
+        is_(evaluator(User(id=4, name="ed")), False)
+
     def test_in_tuples(self):
         User = self.classes.User
 
@@ -268,6 +297,52 @@ class EvaluateTest(fixtures.MappedTest):
             ],
         )
 
+    def test_hybrids(self, registry):
+        @registry.mapped
+        class SomeClass:
+            __tablename__ = "sc"
+            id = Column(Integer, primary_key=True)
+            data = Column(String)
+
+            @hybrid_property
+            def foo_data(self):
+                return self.data + "_foo"
+
+        eval_eq(
+            SomeClass.foo_data == "somedata_foo",
+            testcases=[
+                (SomeClass(data="somedata"), True),
+                (SomeClass(data="otherdata"), False),
+                (SomeClass(data=None), None),
+            ],
+        )
+
+    def test_custom_op_no_impl(self):
+        """test #3162"""
+
+        User = self.classes.User
+
+        with expect_raises_message(
+            evaluator.UnevaluatableError,
+            r"Custom operator '\^\^' can't be evaluated in "
+            "Python unless it specifies",
+        ):
+            compiler.process(User.name.op("^^")("bar"))
+
+    def test_custom_op(self):
+        """test #3162"""
+
+        User = self.classes.User
+
+        eval_eq(
+            User.name.op("^^", python_impl=lambda a, b: a + "_foo_" + b)("bar")
+            == "name_foo_bar",
+            testcases=[
+                (User(name="name"), True),
+                (User(name="notname"), False),
+            ],
+        )
+
 
 class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
     @classmethod
index 9e47f217f2ba72e0d4e06f7d94cf1d819d88bae0..2c77c39f3c0fbeb7dddfd31b0a1ee98e8cbb427b 100644 (file)
@@ -55,6 +55,7 @@ from sqlalchemy.sql.expression import union
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import combinations
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -3286,6 +3287,28 @@ class CustomOpTest(fixtures.TestBase):
         )
         is_(expr.type, some_return_type)
 
+    def test_python_impl(self):
+        """test #3162"""
+        c = column("x")
+        c2 = column("y")
+        op1 = c.op("$", python_impl=lambda a, b: a > b)(c2).operator
+
+        is_(op1(3, 5), False)
+        is_(op1(5, 3), True)
+
+    def test_python_impl_not_present(self):
+        """test #3162"""
+        c = column("x")
+        c2 = column("y")
+        op1 = c.op("$")(c2).operator
+
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"Custom operator '\$' can't be used with plain Python objects "
+            "unless it includes the 'python_impl' parameter.",
+        ):
+            op1(3, 5)
+
 
 class TupleTypingTest(fixtures.TestBase):
     def _assert_types(self, expr):