]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type awareness to evaluator
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2022 15:00:46 +0000 (11:00 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Sep 2022 12:54:58 +0000 (08:54 -0400)
Fixed regression where using ORM update() with synchronize_session='fetch'
would fail due to the use of evaluators that are now used to determine the
in-Python value for expressions in the the SET clause when refreshing
objects; if the evaluators make use of math operators against non-numeric
values such as PostgreSQL JSONB, the non-evaluable condition would fail to
be detected correctly. The evaluator now limits the use of math mutation
operators to numeric types only, with the exception of "+" that continues
to work for strings as well. SQLAlchemy 2.0 may alter this further by
fetching the SET values completely rather than using evaluation.

For 1.4 this also adds "concat_op" as evaluable; 2.0 already has
more string operator support

Fixes: #8507
Change-Id: Icf7120ccbf4266499df6bb3e05159c9f50971d69
(cherry picked from commit 4ab1bc641c7d5833cf20d8ab9b38f5bfba37cfdd)

doc/build/changelog/unreleased_14/8507.rst [new file with mode: 0644]
lib/sqlalchemy/orm/evaluator.py
test/orm/test_evaluator.py

diff --git a/doc/build/changelog/unreleased_14/8507.rst b/doc/build/changelog/unreleased_14/8507.rst
new file mode 100644 (file)
index 0000000..07944da
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 8507
+
+    Fixed regression where using ORM update() with synchronize_session='fetch'
+    would fail due to the use of evaluators that are now used to determine the
+    in-Python value for expressions in the the SET clause when refreshing
+    objects; if the evaluators make use of math operators against non-numeric
+    values such as PostgreSQL JSONB, the non-evaluable condition would fail to
+    be detected correctly. The evaluator now limits the use of math mutation
+    operators to numeric types only, with the exception of "+" that continues
+    to work for strings as well. SQLAlchemy 2.0 may alter this further by
+    fetching the SET values completely rather than using evaluation.
index dbbfba09f0118fd186394cdca4e53d69ea6ec7f6..f1d9ca5413d9cb2cd7b546203e4a1d65ed5f438d 100644 (file)
@@ -11,6 +11,8 @@ from .. import inspect
 from .. import util
 from ..sql import and_
 from ..sql import operators
+from ..sql.sqltypes import Integer
+from ..sql.sqltypes import Numeric
 
 
 class UnevaluatableError(Exception):
@@ -30,12 +32,6 @@ _NO_OBJECT = _NoObject()
 _straight_ops = set(
     getattr(operators, op)
     for op in (
-        "add",
-        "mul",
-        "sub",
-        "div",
-        "mod",
-        "truediv",
         "lt",
         "le",
         "ne",
@@ -45,6 +41,18 @@ _straight_ops = set(
     )
 )
 
+_math_only_straight_ops = set(
+    getattr(operators, op)
+    for op in (
+        "add",
+        "mul",
+        "sub",
+        "div",
+        "mod",
+        "truediv",
+    )
+)
+
 _extended_ops = {
     operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
     operators.not_in_op: (
@@ -62,7 +70,6 @@ _notimplemented_ops = set(
         "startswith_op",
         "between_op",
         "endswith_op",
-        "concat_op",
     )
 )
 
@@ -191,6 +198,11 @@ class EvaluatorCompiler(object):
             def evaluate(obj):
                 return eval_left(obj) != eval_right(obj)
 
+        elif operator is operators.concat_op:
+
+            def evaluate(obj):
+                return eval_left(obj) + eval_right(obj)
+
         elif operator in _extended_ops:
 
             def evaluate(obj):
@@ -201,6 +213,28 @@ class EvaluatorCompiler(object):
 
                 return _extended_ops[operator](left_val, right_val)
 
+        elif operator in _math_only_straight_ops:
+            if (
+                clause.left.type._type_affinity
+                not in (
+                    Numeric,
+                    Integer,
+                )
+                or clause.right.type._type_affinity not in (Numeric, Integer)
+            ):
+                raise UnevaluatableError(
+                    'Cannot evaluate math operator "%s" for '
+                    "datatypes %s, %s"
+                    % (operator.__name__, clause.left.type, clause.right.type)
+                )
+
+            def evaluate(obj):
+                left_val = eval_left(obj)
+                right_val = eval_right(obj)
+                if left_val is None or right_val is None:
+                    return None
+                return operator(eval_left(obj), eval_right(obj))
+
         elif operator in _straight_ops:
 
             def evaluate(obj):
index 62acca582701ac0f595988a195a5ffaad970de07..5902264e36ed1d57af6ffd56c314ae4cad3751fb 100644 (file)
@@ -5,15 +5,19 @@ from sqlalchemy import bindparam
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import JSON
 from sqlalchemy import not_
 from sqlalchemy import or_
 from sqlalchemy import String
+from sqlalchemy import testing
 from sqlalchemy import tuple_
 from sqlalchemy.orm import evaluator
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import relationship
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -50,6 +54,7 @@ class EvaluateTest(fixtures.MappedTest):
             Column("id", Integer, primary_key=True),
             Column("name", String(64)),
             Column("othername", String(64)),
+            Column("json", JSON),
         )
 
     @classmethod
@@ -200,6 +205,24 @@ class EvaluateTest(fixtures.MappedTest):
             ],
         )
 
+    @testing.combinations(
+        lambda User: User.name + "_foo" == "named_foo",
+        # not implemented in 1.4
+        # lambda User: User.name.startswith("nam"),
+        # lambda User: User.name.endswith("named"),
+    )
+    def test_string_ops(self, expr):
+        User = self.classes.User
+
+        test_expr = testing.resolve_lambda(expr, User=User)
+        eval_eq(
+            test_expr,
+            testcases=[
+                (User(name="named"), True),
+                (User(name="othername"), False),
+            ],
+        )
+
     def test_in(self):
         User = self.classes.User
 
@@ -268,6 +291,66 @@ class EvaluateTest(fixtures.MappedTest):
             ],
         )
 
+    @testing.combinations(
+        (lambda User: User.id + 5, "id", 10, 15, None),
+        (
+            lambda User: User.name + " name",
+            "name",
+            "some value",
+            "some value name",
+            None,
+        ),
+        (
+            lambda User: User.id + "name",
+            "id",
+            10,
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"add\" for "
+            r"datatypes INTEGER, VARCHAR",
+        ),
+        (
+            lambda User: User.json + 12,
+            "json",
+            {"foo": "bar"},
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"add\" for "
+            r"datatypes JSON, INTEGER",
+        ),
+        (
+            lambda User: User.json - 12,
+            "json",
+            {"foo": "bar"},
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"sub\" for "
+            r"datatypes JSON, INTEGER",
+        ),
+        (
+            lambda User: User.json - "foo",
+            "json",
+            {"foo": "bar"},
+            evaluator.UnevaluatableError,
+            r"Cannot evaluate math operator \"sub\" for "
+            r"datatypes JSON, VARCHAR",
+        ),
+    )
+    def test_math_op_type_exclusions(
+        self, expr, attrname, initial_value, expected, message
+    ):
+        """test #8507"""
+
+        User = self.classes.User
+
+        expr = testing.resolve_lambda(expr, User=User)
+
+        if expected is evaluator.UnevaluatableError:
+            with expect_raises_message(evaluator.UnevaluatableError, message):
+                compiler.process(expr)
+        else:
+            obj = User(**{attrname: initial_value})
+
+            new_value = compiler.process(expr)(obj)
+            eq_(new_value, expected)
+
 
 class M2OEvaluateTest(fixtures.DeclarativeMappedTest):
     @classmethod