From: Mike Bayer Date: Thu, 9 Feb 2023 20:36:38 +0000 (-0500) Subject: apply self_group to all elements of multi-expression X-Git-Tag: rel_2_0_3~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6971ba97247928c9a79f532001278d0e1d5845fc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git apply self_group to all elements of multi-expression Fixed critical regression in SQL expression formulation in the 2.0 series due to :ticket:`7744` which improved support for SQL expressions that contained many elements against the same operator repeatedly; parenthesis grouping would be lost with expression elements beyond the first two elements. Fixes: #9271 Change-Id: Ib6ed5b71efe0f6816dab75bda622297fc89e3b49 --- diff --git a/doc/build/changelog/unreleased_20/9271.rst b/doc/build/changelog/unreleased_20/9271.rst new file mode 100644 index 0000000000..3efe0b156f --- /dev/null +++ b/doc/build/changelog/unreleased_20/9271.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 9271 + + Fixed critical regression in SQL expression formulation in the 2.0 series + due to :ticket:`7744` which improved support for SQL expressions that + contained many elements against the same operator repeatedly; parenthesis + grouping would be lost with expression elements beyond the first two + elements. + diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 37d53b30a4..4c2c7de3c4 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -547,7 +547,6 @@ class ClauseElement( optionaldict: Optional[Mapping[str, Any]], kwargs: Dict[str, Any], ) -> Self: - if optionaldict: kwargs.update(optionaldict) @@ -2780,7 +2779,6 @@ class OperatorExpression(ColumnElement[_T]): negate: Optional[OperatorType] = None, modifiers: Optional[Mapping[str, Any]] = None, ) -> OperatorExpression[_T]: - if operators.is_associative(op): assert ( negate is None @@ -2805,7 +2803,9 @@ class OperatorExpression(ColumnElement[_T]): if multi: return ExpressionClauseList._construct_for_list( - op, type_, *(left_flattened + right_flattened) + op, + type_, + *(left_flattened + right_flattened), ) return BinaryExpression( @@ -2886,7 +2886,12 @@ class ExpressionClauseList(OperatorExpression[_T]): ) -> ExpressionClauseList[_T]: self = cls.__new__(cls) self.group = group - self.clauses = clauses + if group: + self.clauses = tuple( + c.self_group(against=operator) for c in clauses + ) + else: + self.clauses = clauses self.operator = operator self.type = type_ return self @@ -2961,7 +2966,6 @@ class BooleanClauseList(ExpressionClauseList[bool]): *clauses: Any, **kw: Any, ) -> ColumnElement[Any]: - if initial_clause is _NoArg.NO_ARG: # no elements period. deprecated use case. return an empty # ClauseList construct that generates nothing unless it has @@ -3233,7 +3237,6 @@ class Case(ColumnElement[_T]): value: Optional[Any] = None, else_: Optional[Any] = None, ): - new_whens: Iterable[Any] = coercions._expression_collection_was_a_list( "whens", "case", whens ) @@ -4908,7 +4911,6 @@ class CollationClause(ColumnElement[str]): class _IdentifiedClause(Executable, ClauseElement): - __visit_name__ = "identified" def __init__(self, ident): @@ -5195,7 +5197,6 @@ class _anonymous_label(_truncated_label): enclosing_label: Optional[str] = None, sanitize_key: bool = False, ) -> _anonymous_label: - # need to escape chars that interfere with format # strings in any case, issue #8724 body = re.sub(r"[%\(\) \$]+", "_", body) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index fd0cf66549..e7e51aa635 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -63,6 +63,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not +from sqlalchemy.testing import resolve_lambda from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.types import ARRAY from sqlalchemy.types import Boolean @@ -176,7 +177,6 @@ class DefaultColumnComparatorTest( argnames="op", ) def test_nonsensical_negations(self, op): - opstring = compiler.OPERATORS[op] self.assert_compile( select(~op(column("x"), column("q"))), @@ -184,7 +184,6 @@ class DefaultColumnComparatorTest( ) def test_null_true_false_is_sanity_checks(self): - d = default.DefaultDialect() d.supports_native_boolean = True @@ -376,9 +375,84 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL): expr = expr1 + expr2 self.assert_compile( - select(expr), "SELECT i1 + i2 + d1 || d2 AS anon_1" + select(expr), "SELECT i1 + i2 + (d1 || d2) AS anon_1" ) + @testing.combinations( + operators.add, + operators.mul, + argnames="op", + ) + @testing.combinations(True, False, argnames="reverse") + @testing.combinations(True, False, argnames="negate") + def test_parenthesized_exprs(self, op, reverse, negate): + t1 = table("t", column("q"), column("p")) + + inner = lambda: t1.c.q - t1.c.p # noqa E371 + expr = op(inner(), inner()) + + if reverse: + for i in range(8): + expr = op(inner(), expr) + else: + for i in range(8): + expr = op(expr, inner()) + + opstring = compiler.OPERATORS[op] + exprs = opstring.join("(t.q - t.p)" for i in range(10)) + + if negate: + self.assert_compile( + select(~expr), f"SELECT NOT ({exprs}) AS anon_1 FROM t" + ) + else: + self.assert_compile( + select(expr), f"SELECT {exprs} AS anon_1 FROM t" + ) + + @testing.combinations( + ( + lambda p, q: (1 - p) * (2 - q) + 10 * (3 - p) * (4 - q), + "(:p_1 - t.p) * (:q_1 - t.q) + " + ":param_1 * (:p_2 - t.p) * (:q_2 - t.q)", + ), + ( + lambda p, q: (1 - p) * (2 - q) * (3 - p) * (4 - q), + "(:p_1 - t.p) * (:q_1 - t.q) * " "(:p_2 - t.p) * (:q_2 - t.q)", + ), + ( + lambda p, q: ( + (1 + p + 5) + * (p * (q - 5) * (p + 8)) + * (q + (p - 3) + (q - 5) + (p - 9)) + * (4 + q + 9) + ), + "(:p_1 + t.p + :param_1) * " + "t.p * (t.q - :q_1) * (t.p + :p_2) * " + "(t.q + (t.p - :p_3) + (t.q - :q_2) + (t.p - :p_4)) * " + "(:q_3 + t.q + :param_2)", + ), + ( + lambda p, q: (1 // p) - (2 // q) - (3 // p) - (4 // q), + "((:p_1 / t.p - :q_1 / t.q) - :p_2 / t.p) - :q_2 / t.q", + ), + ( + lambda p, q: (1 + p) - (2 + q) - (3 + p) - (4 + q), + "(((:p_1 + t.p) - (:q_1 + t.q)) - (:p_2 + t.p)) - (:q_2 + t.q)", + ), + ( + lambda p, q: (1 + p) * 3 * (2 + q) * 4 * (3 + p) - (4 + q), + "(:p_1 + t.p) * :param_1 * (:q_1 + t.q) * " + ":param_2 * (:p_2 + t.p) - (:q_2 + t.q)", + ), + argnames="expr, expected", + ) + def test_other_exprs(self, expr, expected): + t = table("t", column("q", Integer), column("p", Integer)) + expr = resolve_lambda(expr, p=t.c.p, q=t.c.q) + + self.assert_compile(expr, expected) + @testing.combinations( operators.add, operators.and_, @@ -495,7 +569,6 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL): else f"SELECT {str_expr} AS anon_1 FROM t", ) else: - if reverse: str_expr = ( f"d0{opstring}(d1{opstring}(d2{opstring}" @@ -937,13 +1010,11 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(col[("a", "b", "c")].type._type_affinity, JSON) def test_getindex_literal_integer(self): - col = Column("x", self.MyType()) self.assert_compile(col[5], "x -> :x_1", checkparams={"x_1": 5}) def test_getindex_literal_string(self): - col = Column("x", self.MyType()) self.assert_compile( @@ -951,7 +1022,6 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_path_getindex_literal(self): - col = Column("x", self.MyType()) self.assert_compile( @@ -961,14 +1031,12 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getindex_sqlexpr(self): - col = Column("x", self.MyType()) col2 = Column("y", Integer()) self.assert_compile(col[col2], "x -> y", checkparams={}) def test_getindex_sqlexpr_right_grouping(self): - col = Column("x", self.MyType()) col2 = Column("y", Integer()) @@ -977,13 +1045,11 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getindex_sqlexpr_left_grouping(self): - col = Column("x", self.MyType()) self.assert_compile(col[8] != None, "(x -> :x_1) IS NOT NULL") # noqa def test_getindex_sqlexpr_both_grouping(self): - col = Column("x", self.MyType()) col2 = Column("y", Integer()) @@ -1101,7 +1167,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(col[5][6][7].type._type_affinity, Integer) def test_getindex_literal(self): - col = Column("x", self.MyType()) self.assert_compile(col[5], "x[:x_1]", checkparams={"x_1": 5}) @@ -1116,7 +1181,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getindex_sqlexpr(self): - col = Column("x", self.MyType()) col2 = Column("y", Integer()) @@ -1127,7 +1191,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getslice_literal(self): - col = Column("x", self.MyType()) self.assert_compile( @@ -1135,7 +1198,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getslice_sqlexpr(self): - col = Column("x", self.MyType()) col2 = Column("y", Integer()) @@ -1144,13 +1206,11 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getindex_literal_zeroind(self): - col = Column("x", self.MyType(zero_indexes=True)) self.assert_compile(col[5], "x[:x_1]", checkparams={"x_1": 6}) def test_getindex_sqlexpr_zeroind(self): - col = Column("x", self.MyType(zero_indexes=True)) col2 = Column("y", Integer()) @@ -1163,7 +1223,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getslice_literal_zeroind(self): - col = Column("x", self.MyType(zero_indexes=True)) self.assert_compile( @@ -1171,7 +1230,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_getslice_sqlexpr_zeroind(self): - col = Column("x", self.MyType(zero_indexes=True)) col2 = Column("y", Integer()) @@ -2466,7 +2524,7 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): id_="iaa", ) def test_math_op(self, py_op, sql_op): - for (lhs, rhs, res) in ( + for lhs, rhs, res in ( (5, self.table1.c.myid, ":myid_1 %s mytable.myid"), (5, literal(5), ":param_1 %s :param_2"), (self.table1.c.myid, "b", "mytable.myid %s :myid_1"), @@ -2567,7 +2625,7 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_comparison_op(self, py_op, fwd_op, rev_op): dt = datetime.datetime(2012, 5, 10, 15, 27, 18) - for (lhs, rhs, l_sql, r_sql) in ( + for lhs, rhs, l_sql, r_sql in ( ("a", self.table1.c.myid, ":myid_1", "mytable.myid"), ("a", literal("b"), ":param_2", ":param_1"), # note swap! (self.table1.c.myid, "b", "mytable.myid", ":myid_1"), @@ -2584,7 +2642,6 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): (dt, literal("b"), ":param_2", ":param_1"), (literal("b"), dt, ":param_1", ":param_2"), ): - # the compiled clause should match either (e.g.): # 'a' < 'b' -or- 'b' > 'a'. compiled = str(py_op(lhs, rhs)) @@ -2663,7 +2720,7 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): table1 = table("mytable", column("myid", Integer), column("name", String)) def test_negate_operators_1(self): - for (py_op, op) in ((operator.neg, "-"), (operator.inv, "NOT ")): + for py_op, op in ((operator.neg, "-"), (operator.inv, "NOT ")): for expr, expected in ( (self.table1.c.myid, "mytable.myid"), (literal("foo"), ":param_1"), @@ -4400,7 +4457,6 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) @testing.combinations("int", "array", argnames="datatype") def test_any_generic_null(self, datatype, expr, t_fixture): - col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name)