]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type awareness to evaluator
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2022 15:00:46 +0000 (11:00 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2022 15:18:19 +0000 (11:18 -0400)
Fixed regression where using ORM update() with synchronize_session='fetch'
would fail due to the use of evaluators that are now used to determine the
in-Python value for expressions in the the SET clause when refreshing
objects; if the evaluators make use of math operators against non-numeric
values such as PostgreSQL JSONB, the non-evaluable condition would fail to
be detected correctly. The evaluator now limits the use of math mutation
operators to numeric types only, with the exception of "+" that continues
to work for strings as well. SQLAlchemy 2.0 may alter this further by
fetching the SET values completely rather than using evaluation.

Fixes: #8507
Change-Id: Icf7120ccbf4266499df6bb3e05159c9f50971d69

doc/build/changelog/unreleased_14/8507.rst [new file with mode: 0644]
lib/sqlalchemy/orm/evaluator.py
test/orm/test_evaluator.py

diff --git a/doc/build/changelog/unreleased_14/8507.rst b/doc/build/changelog/unreleased_14/8507.rst
new file mode 100644 (file)
index 0000000..07944da
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 8507
+
+    Fixed regression where using ORM update() with synchronize_session='fetch'
+    would fail due to the use of evaluators that are now used to determine the
+    in-Python value for expressions in the the SET clause when refreshing
+    objects; if the evaluators make use of math operators against non-numeric
+    values such as PostgreSQL JSONB, the non-evaluable condition would fail to
+    be detected correctly. The evaluator now limits the use of math mutation
+    operators to numeric types only, with the exception of "+" that continues
+    to work for strings as well. SQLAlchemy 2.0 may alter this further by
+    fetching the SET values completely rather than using evaluation.
index 72936d1ab11371b381532cf8bd9a1ed6513a7273..b3129afdd7a41317955940e1e9d79f0886d1d527 100644 (file)
@@ -16,6 +16,8 @@ from .. import inspect
 from .. import util
 from ..sql import and_
 from ..sql import operators
+from ..sql.sqltypes import Integer
+from ..sql.sqltypes import Numeric
 
 
 class UnevaluatableError(exc.InvalidRequestError):
@@ -120,7 +122,7 @@ class EvaluatorCompiler:
         dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op"
         meth = getattr(self, dispatch, None)
         if meth:
-            return meth(clause.operator, eval_left, eval_right)
+            return meth(clause.operator, eval_left, eval_right, clause)
         else:
             raise UnevaluatableError(
                 f"Cannot evaluate {type(clause).__name__} with "
@@ -165,9 +167,13 @@ class EvaluatorCompiler:
 
         return evaluate
 
-    def visit_custom_op_binary_op(self, operator, eval_left, eval_right):
+    def visit_custom_op_binary_op(
+        self, operator, eval_left, eval_right, clause
+    ):
         if operator.python_impl:
-            return self._straight_evaluate(operator, eval_left, eval_right)
+            return self._straight_evaluate(
+                operator, eval_left, eval_right, clause
+            )
         else:
             raise UnevaluatableError(
                 f"Custom operator {operator.opstring!r} can't be evaluated "
@@ -175,19 +181,19 @@ class EvaluatorCompiler:
                 "`.python_impl`."
             )
 
-    def visit_is_binary_op(self, operator, eval_left, eval_right):
+    def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
         def evaluate(obj):
             return eval_left(obj) == eval_right(obj)
 
         return evaluate
 
-    def visit_is_not_binary_op(self, operator, eval_left, eval_right):
+    def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
         def evaluate(obj):
             return eval_left(obj) != eval_right(obj)
 
         return evaluate
 
-    def _straight_evaluate(self, operator, eval_left, eval_right):
+    def _straight_evaluate(self, operator, eval_left, eval_right, clause):
         def evaluate(obj):
             left_val = eval_left(obj)
             right_val = eval_right(obj)
@@ -197,11 +203,25 @@ class EvaluatorCompiler:
 
         return evaluate
 
-    visit_add_binary_op = _straight_evaluate
-    visit_mul_binary_op = _straight_evaluate
-    visit_sub_binary_op = _straight_evaluate
-    visit_mod_binary_op = _straight_evaluate
-    visit_truediv_binary_op = _straight_evaluate
+    def _straight_evaluate_numeric_only(
+        self, operator, eval_left, eval_right, clause
+    ):
+        if clause.left.type._type_affinity not in (
+            Numeric,
+            Integer,
+        ) or clause.right.type._type_affinity not in (Numeric, Integer):
+            raise UnevaluatableError(
+                f'Cannot evaluate math operator "{operator.__name__}" for '
+                f"datatypes {clause.left.type}, {clause.right.type}"
+            )
+
+        return self._straight_evaluate(operator, eval_left, eval_right, clause)
+
+    visit_add_binary_op = _straight_evaluate_numeric_only
+    visit_mul_binary_op = _straight_evaluate_numeric_only
+    visit_sub_binary_op = _straight_evaluate_numeric_only
+    visit_mod_binary_op = _straight_evaluate_numeric_only
+    visit_truediv_binary_op = _straight_evaluate_numeric_only
     visit_lt_binary_op = _straight_evaluate
     visit_le_binary_op = _straight_evaluate
     visit_ne_binary_op = _straight_evaluate
@@ -209,33 +229,43 @@ class EvaluatorCompiler:
     visit_ge_binary_op = _straight_evaluate
     visit_eq_binary_op = _straight_evaluate
 
-    def visit_in_op_binary_op(self, operator, eval_left, eval_right):
+    def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause):
         return self._straight_evaluate(
             lambda a, b: a in b if a is not _NO_OBJECT else None,
             eval_left,
             eval_right,
+            clause,
         )
 
-    def visit_not_in_op_binary_op(self, operator, eval_left, eval_right):
+    def visit_not_in_op_binary_op(
+        self, operator, eval_left, eval_right, clause
+    ):
         return self._straight_evaluate(
             lambda a, b: a not in b if a is not _NO_OBJECT else None,
             eval_left,
             eval_right,
+            clause,
         )
 
-    def visit_concat_op_binary_op(self, operator, eval_left, eval_right):
+    def visit_concat_op_binary_op(
+        self, operator, eval_left, eval_right, clause
+    ):
         return self._straight_evaluate(
-            lambda a, b: a + b, eval_left, eval_right
+            lambda a, b: a + b, eval_left, eval_right, clause
         )
 
-    def visit_startswith_op_binary_op(self, operator, eval_left, eval_right):
+    def visit_startswith_op_binary_op(
+        self, operator, eval_left, eval_right, clause
+    ):
         return self._straight_evaluate(
-            lambda a, b: a.startswith(b), eval_left, eval_right
+            lambda a, b: a.startswith(b), eval_left, eval_right, clause
         )
 
-    def visit_endswith_op_binary_op(self, operator, eval_left, eval_right):
+    def visit_endswith_op_binary_op(
+        self, operator, eval_left, eval_right, clause
+    ):
         return self._straight_evaluate(
-            lambda a, b: a.endswith(b), eval_left, eval_right
+            lambda a, b: a.endswith(b), eval_left, eval_right, clause
         )
 
     def visit_unary(self, clause):
index 104e47ae8f5544c8d61f34228ca92fb52b44bef3..ff40cd20155a68b4295615da63475a5fe5636e5c 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import JSON
 from sqlalchemy import not_
 from sqlalchemy import or_
 from sqlalchemy import String
@@ -16,6 +17,7 @@ from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import relationship
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -53,6 +55,7 @@ class EvaluateTest(fixtures.MappedTest):
             Column("id", Integer, primary_key=True),
             Column("name", String(64)),
             Column("othername", String(64)),
+            Column("json", JSON),
         )
 
     @classmethod
@@ -343,6 +346,67 @@ class EvaluateTest(fixtures.MappedTest):
             ],
         )
 
+    @testing.combinations(
+        (lambda User: User.id + 5, "id", 10, 15, None),
+        (
+            # note this one uses concat_op, not operator.add
+            lambda User: User.name + " name",
+            "name",
+            "some value",
+            "some value name",
+            None,
+        ),
+        (
+            lambda User: User.id + "name",
+            "id",
+            10,
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"add\" for "
+            r"datatypes INTEGER, VARCHAR",
+        ),
+        (
+            lambda User: User.json + 12,
+            "json",
+            {"foo": "bar"},
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"add\" for "
+            r"datatypes JSON, INTEGER",
+        ),
+        (
+            lambda User: User.json - 12,
+            "json",
+            {"foo": "bar"},
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"sub\" for "
+            r"datatypes JSON, INTEGER",
+        ),
+        (
+            lambda User: User.json - "foo",
+            "json",
+            {"foo": "bar"},
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"sub\" for "
+            r"datatypes JSON, VARCHAR",
+        ),
+    )
+    def test_math_op_type_exclusions(
+        self, expr, attrname, initial_value, expected, message
+    ):
+        """test #8507"""
+
+        User = self.classes.User
+
+        expr = testing.resolve_lambda(expr, User=User)
+
+        if expected is evaluator.UnevaluatableError:
+            with expect_raises_message(evaluator.UnevaluatableError, message):
+                compiler.process(expr)
+        else:
+            obj = User(**{attrname: initial_value})
+
+            new_value = compiler.process(expr)(obj)
+            eq_(new_value, expected)
+
 
 class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
     @classmethod