From: Mike Bayer Date: Tue, 13 Sep 2022 15:00:46 +0000 (-0400) Subject: Add type awareness to evaluator X-Git-Tag: rel_2_0_0b1~64^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=14b634d7065446d146456eed006c4969a7972b1a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add type awareness to evaluator Fixed regression where using ORM update() with synchronize_session='fetch' would fail due to the use of evaluators that are now used to determine the in-Python value for expressions in the the SET clause when refreshing objects; if the evaluators make use of math operators against non-numeric values such as PostgreSQL JSONB, the non-evaluable condition would fail to be detected correctly. The evaluator now limits the use of math mutation operators to numeric types only, with the exception of "+" that continues to work for strings as well. SQLAlchemy 2.0 may alter this further by fetching the SET values completely rather than using evaluation. Fixes: #8507 Change-Id: Icf7120ccbf4266499df6bb3e05159c9f50971d69 --- diff --git a/doc/build/changelog/unreleased_14/8507.rst b/doc/build/changelog/unreleased_14/8507.rst new file mode 100644 index 0000000000..07944da75d --- /dev/null +++ b/doc/build/changelog/unreleased_14/8507.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 8507 + + Fixed regression where using ORM update() with synchronize_session='fetch' + would fail due to the use of evaluators that are now used to determine the + in-Python value for expressions in the the SET clause when refreshing + objects; if the evaluators make use of math operators against non-numeric + values such as PostgreSQL JSONB, the non-evaluable condition would fail to + be detected correctly. The evaluator now limits the use of math mutation + operators to numeric types only, with the exception of "+" that continues + to work for strings as well. SQLAlchemy 2.0 may alter this further by + fetching the SET values completely rather than using evaluation. diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 72936d1ab1..b3129afdd7 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -16,6 +16,8 @@ from .. import inspect from .. import util from ..sql import and_ from ..sql import operators +from ..sql.sqltypes import Integer +from ..sql.sqltypes import Numeric class UnevaluatableError(exc.InvalidRequestError): @@ -120,7 +122,7 @@ class EvaluatorCompiler: dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op" meth = getattr(self, dispatch, None) if meth: - return meth(clause.operator, eval_left, eval_right) + return meth(clause.operator, eval_left, eval_right, clause) else: raise UnevaluatableError( f"Cannot evaluate {type(clause).__name__} with " @@ -165,9 +167,13 @@ class EvaluatorCompiler: return evaluate - def visit_custom_op_binary_op(self, operator, eval_left, eval_right): + def visit_custom_op_binary_op( + self, operator, eval_left, eval_right, clause + ): if operator.python_impl: - return self._straight_evaluate(operator, eval_left, eval_right) + return self._straight_evaluate( + operator, eval_left, eval_right, clause + ) else: raise UnevaluatableError( f"Custom operator {operator.opstring!r} can't be evaluated " @@ -175,19 +181,19 @@ class EvaluatorCompiler: "`.python_impl`." ) - def visit_is_binary_op(self, operator, eval_left, eval_right): + def visit_is_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): return eval_left(obj) == eval_right(obj) return evaluate - def visit_is_not_binary_op(self, operator, eval_left, eval_right): + def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): return eval_left(obj) != eval_right(obj) return evaluate - def _straight_evaluate(self, operator, eval_left, eval_right): + def _straight_evaluate(self, operator, eval_left, eval_right, clause): def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) @@ -197,11 +203,25 @@ class EvaluatorCompiler: return evaluate - visit_add_binary_op = _straight_evaluate - visit_mul_binary_op = _straight_evaluate - visit_sub_binary_op = _straight_evaluate - visit_mod_binary_op = _straight_evaluate - visit_truediv_binary_op = _straight_evaluate + def _straight_evaluate_numeric_only( + self, operator, eval_left, eval_right, clause + ): + if clause.left.type._type_affinity not in ( + Numeric, + Integer, + ) or clause.right.type._type_affinity not in (Numeric, Integer): + raise UnevaluatableError( + f'Cannot evaluate math operator "{operator.__name__}" for ' + f"datatypes {clause.left.type}, {clause.right.type}" + ) + + return self._straight_evaluate(operator, eval_left, eval_right, clause) + + visit_add_binary_op = _straight_evaluate_numeric_only + visit_mul_binary_op = _straight_evaluate_numeric_only + visit_sub_binary_op = _straight_evaluate_numeric_only + visit_mod_binary_op = _straight_evaluate_numeric_only + visit_truediv_binary_op = _straight_evaluate_numeric_only visit_lt_binary_op = _straight_evaluate visit_le_binary_op = _straight_evaluate visit_ne_binary_op = _straight_evaluate @@ -209,33 +229,43 @@ class EvaluatorCompiler: visit_ge_binary_op = _straight_evaluate visit_eq_binary_op = _straight_evaluate - def visit_in_op_binary_op(self, operator, eval_left, eval_right): + def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause): return self._straight_evaluate( lambda a, b: a in b if a is not _NO_OBJECT else None, eval_left, eval_right, + clause, ) - def visit_not_in_op_binary_op(self, operator, eval_left, eval_right): + def visit_not_in_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( lambda a, b: a not in b if a is not _NO_OBJECT else None, eval_left, eval_right, + clause, ) - def visit_concat_op_binary_op(self, operator, eval_left, eval_right): + def visit_concat_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( - lambda a, b: a + b, eval_left, eval_right + lambda a, b: a + b, eval_left, eval_right, clause ) - def visit_startswith_op_binary_op(self, operator, eval_left, eval_right): + def visit_startswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( - lambda a, b: a.startswith(b), eval_left, eval_right + lambda a, b: a.startswith(b), eval_left, eval_right, clause ) - def visit_endswith_op_binary_op(self, operator, eval_left, eval_right): + def visit_endswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): return self._straight_evaluate( - lambda a, b: a.endswith(b), eval_left, eval_right + lambda a, b: a.endswith(b), eval_left, eval_right, clause ) def visit_unary(self, clause): diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py index 104e47ae8f..ff40cd2015 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/test_evaluator.py @@ -5,6 +5,7 @@ from sqlalchemy import bindparam from sqlalchemy import ForeignKey from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import not_ from sqlalchemy import or_ from sqlalchemy import String @@ -16,6 +17,7 @@ from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import relationship from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -53,6 +55,7 @@ class EvaluateTest(fixtures.MappedTest): Column("id", Integer, primary_key=True), Column("name", String(64)), Column("othername", String(64)), + Column("json", JSON), ) @classmethod @@ -343,6 +346,67 @@ class EvaluateTest(fixtures.MappedTest): ], ) + @testing.combinations( + (lambda User: User.id + 5, "id", 10, 15, None), + ( + # note this one uses concat_op, not operator.add + lambda User: User.name + " name", + "name", + "some value", + "some value name", + None, + ), + ( + lambda User: User.id + "name", + "id", + 10, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"add\" for " + r"datatypes INTEGER, VARCHAR", + ), + ( + lambda User: User.json + 12, + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"add\" for " + r"datatypes JSON, INTEGER", + ), + ( + lambda User: User.json - 12, + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"sub\" for " + r"datatypes JSON, INTEGER", + ), + ( + lambda User: User.json - "foo", + "json", + {"foo": "bar"}, + evaluator.UnevaluatableError, + r"Cannot evaluate math operator \"sub\" for " + r"datatypes JSON, VARCHAR", + ), + ) + def test_math_op_type_exclusions( + self, expr, attrname, initial_value, expected, message + ): + """test #8507""" + + User = self.classes.User + + expr = testing.resolve_lambda(expr, User=User) + + if expected is evaluator.UnevaluatableError: + with expect_raises_message(evaluator.UnevaluatableError, message): + compiler.process(expr) + else: + obj = User(**{attrname: initial_value}) + + new_value = compiler.process(expr)(obj) + eq_(new_value, expected) + class M2OEvaluateTest(fixtures.DeclarativeMappedTest): @classmethod