]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply self_group to all elements of multi-expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Feb 2023 20:36:38 +0000 (15:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Feb 2023 21:45:14 +0000 (16:45 -0500)
Fixed critical regression in SQL expression formulation in the 2.0 series
due to :ticket:`7744` which improved support for SQL expressions that
contained many elements against the same operator repeatedly; parenthesis
grouping would be lost with expression elements beyond the first two
elements.

Fixes: #9271
Change-Id: Ib6ed5b71efe0f6816dab75bda622297fc89e3b49

doc/build/changelog/unreleased_20/9271.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_20/9271.rst b/doc/build/changelog/unreleased_20/9271.rst
new file mode 100644 (file)
index 0000000..3efe0b1
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 9271
+
+    Fixed critical regression in SQL expression formulation in the 2.0 series
+    due to :ticket:`7744` which improved support for SQL expressions that
+    contained many elements against the same operator repeatedly; parenthesis
+    grouping would be lost with expression elements beyond the first two
+    elements.
+
index 37d53b30a429c8fc23658d80e769b6e141e876c2..4c2c7de3c4955f3c549755b6414e0eee868ebdf1 100644 (file)
@@ -547,7 +547,6 @@ class ClauseElement(
         optionaldict: Optional[Mapping[str, Any]],
         kwargs: Dict[str, Any],
     ) -> Self:
-
         if optionaldict:
             kwargs.update(optionaldict)
 
@@ -2780,7 +2779,6 @@ class OperatorExpression(ColumnElement[_T]):
         negate: Optional[OperatorType] = None,
         modifiers: Optional[Mapping[str, Any]] = None,
     ) -> OperatorExpression[_T]:
-
         if operators.is_associative(op):
             assert (
                 negate is None
@@ -2805,7 +2803,9 @@ class OperatorExpression(ColumnElement[_T]):
 
             if multi:
                 return ExpressionClauseList._construct_for_list(
-                    op, type_, *(left_flattened + right_flattened)
+                    op,
+                    type_,
+                    *(left_flattened + right_flattened),
                 )
 
         return BinaryExpression(
@@ -2886,7 +2886,12 @@ class ExpressionClauseList(OperatorExpression[_T]):
     ) -> ExpressionClauseList[_T]:
         self = cls.__new__(cls)
         self.group = group
-        self.clauses = clauses
+        if group:
+            self.clauses = tuple(
+                c.self_group(against=operator) for c in clauses
+            )
+        else:
+            self.clauses = clauses
         self.operator = operator
         self.type = type_
         return self
@@ -2961,7 +2966,6 @@ class BooleanClauseList(ExpressionClauseList[bool]):
         *clauses: Any,
         **kw: Any,
     ) -> ColumnElement[Any]:
-
         if initial_clause is _NoArg.NO_ARG:
             # no elements period.  deprecated use case.  return an empty
             # ClauseList construct that generates nothing unless it has
@@ -3233,7 +3237,6 @@ class Case(ColumnElement[_T]):
         value: Optional[Any] = None,
         else_: Optional[Any] = None,
     ):
-
         new_whens: Iterable[Any] = coercions._expression_collection_was_a_list(
             "whens", "case", whens
         )
@@ -4908,7 +4911,6 @@ class CollationClause(ColumnElement[str]):
 
 
 class _IdentifiedClause(Executable, ClauseElement):
-
     __visit_name__ = "identified"
 
     def __init__(self, ident):
@@ -5195,7 +5197,6 @@ class _anonymous_label(_truncated_label):
         enclosing_label: Optional[str] = None,
         sanitize_key: bool = False,
     ) -> _anonymous_label:
-
         # need to escape chars that interfere with format
         # strings in any case, issue #8724
         body = re.sub(r"[%\(\) \$]+", "_", body)
index fd0cf66549c6c92f84494143013df6e798b8ad57..e7e51aa635dec4bcf5220d71ab80366325264550 100644 (file)
@@ -63,6 +63,7 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing import resolve_lambda
 from sqlalchemy.testing.assertions import expect_deprecated
 from sqlalchemy.types import ARRAY
 from sqlalchemy.types import Boolean
@@ -176,7 +177,6 @@ class DefaultColumnComparatorTest(
         argnames="op",
     )
     def test_nonsensical_negations(self, op):
-
         opstring = compiler.OPERATORS[op]
         self.assert_compile(
             select(~op(column("x"), column("q"))),
@@ -184,7 +184,6 @@ class DefaultColumnComparatorTest(
         )
 
     def test_null_true_false_is_sanity_checks(self):
-
         d = default.DefaultDialect()
         d.supports_native_boolean = True
 
@@ -376,9 +375,84 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             expr = expr1 + expr2
 
             self.assert_compile(
-                select(expr), "SELECT i1 + i2 + d1 || d2 AS anon_1"
+                select(expr), "SELECT i1 + i2 + (d1 || d2) AS anon_1"
             )
 
+    @testing.combinations(
+        operators.add,
+        operators.mul,
+        argnames="op",
+    )
+    @testing.combinations(True, False, argnames="reverse")
+    @testing.combinations(True, False, argnames="negate")
+    def test_parenthesized_exprs(self, op, reverse, negate):
+        t1 = table("t", column("q"), column("p"))
+
+        inner = lambda: t1.c.q - t1.c.p  # noqa E371
+        expr = op(inner(), inner())
+
+        if reverse:
+            for i in range(8):
+                expr = op(inner(), expr)
+        else:
+            for i in range(8):
+                expr = op(expr, inner())
+
+        opstring = compiler.OPERATORS[op]
+        exprs = opstring.join("(t.q - t.p)" for i in range(10))
+
+        if negate:
+            self.assert_compile(
+                select(~expr), f"SELECT NOT ({exprs}) AS anon_1 FROM t"
+            )
+        else:
+            self.assert_compile(
+                select(expr), f"SELECT {exprs} AS anon_1 FROM t"
+            )
+
+    @testing.combinations(
+        (
+            lambda p, q: (1 - p) * (2 - q) + 10 * (3 - p) * (4 - q),
+            "(:p_1 - t.p) * (:q_1 - t.q) + "
+            ":param_1 * (:p_2 - t.p) * (:q_2 - t.q)",
+        ),
+        (
+            lambda p, q: (1 - p) * (2 - q) * (3 - p) * (4 - q),
+            "(:p_1 - t.p) * (:q_1 - t.q) * " "(:p_2 - t.p) * (:q_2 - t.q)",
+        ),
+        (
+            lambda p, q: (
+                (1 + p + 5)
+                * (p * (q - 5) * (p + 8))
+                * (q + (p - 3) + (q - 5) + (p - 9))
+                * (4 + q + 9)
+            ),
+            "(:p_1 + t.p + :param_1) * "
+            "t.p * (t.q - :q_1) * (t.p + :p_2) * "
+            "(t.q + (t.p - :p_3) + (t.q - :q_2) + (t.p - :p_4)) * "
+            "(:q_3 + t.q + :param_2)",
+        ),
+        (
+            lambda p, q: (1 // p) - (2 // q) - (3 // p) - (4 // q),
+            "((:p_1 / t.p - :q_1 / t.q) - :p_2 / t.p) - :q_2 / t.q",
+        ),
+        (
+            lambda p, q: (1 + p) - (2 + q) - (3 + p) - (4 + q),
+            "(((:p_1 + t.p) - (:q_1 + t.q)) - (:p_2 + t.p)) - (:q_2 + t.q)",
+        ),
+        (
+            lambda p, q: (1 + p) * 3 * (2 + q) * 4 * (3 + p) - (4 + q),
+            "(:p_1 + t.p) * :param_1 * (:q_1 + t.q) * "
+            ":param_2 * (:p_2 + t.p) - (:q_2 + t.q)",
+        ),
+        argnames="expr, expected",
+    )
+    def test_other_exprs(self, expr, expected):
+        t = table("t", column("q", Integer), column("p", Integer))
+        expr = resolve_lambda(expr, p=t.c.p, q=t.c.q)
+
+        self.assert_compile(expr, expected)
+
     @testing.combinations(
         operators.add,
         operators.and_,
@@ -495,7 +569,6 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 else f"SELECT {str_expr} AS anon_1 FROM t",
             )
         else:
-
             if reverse:
                 str_expr = (
                     f"d0{opstring}(d1{opstring}(d2{opstring}"
@@ -937,13 +1010,11 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_(col[("a", "b", "c")].type._type_affinity, JSON)
 
     def test_getindex_literal_integer(self):
-
         col = Column("x", self.MyType())
 
         self.assert_compile(col[5], "x -> :x_1", checkparams={"x_1": 5})
 
     def test_getindex_literal_string(self):
-
         col = Column("x", self.MyType())
 
         self.assert_compile(
@@ -951,7 +1022,6 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_path_getindex_literal(self):
-
         col = Column("x", self.MyType())
 
         self.assert_compile(
@@ -961,14 +1031,12 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getindex_sqlexpr(self):
-
         col = Column("x", self.MyType())
         col2 = Column("y", Integer())
 
         self.assert_compile(col[col2], "x -> y", checkparams={})
 
     def test_getindex_sqlexpr_right_grouping(self):
-
         col = Column("x", self.MyType())
         col2 = Column("y", Integer())
 
@@ -977,13 +1045,11 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getindex_sqlexpr_left_grouping(self):
-
         col = Column("x", self.MyType())
 
         self.assert_compile(col[8] != None, "(x -> :x_1) IS NOT NULL")  # noqa
 
     def test_getindex_sqlexpr_both_grouping(self):
-
         col = Column("x", self.MyType())
         col2 = Column("y", Integer())
 
@@ -1101,7 +1167,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_(col[5][6][7].type._type_affinity, Integer)
 
     def test_getindex_literal(self):
-
         col = Column("x", self.MyType())
 
         self.assert_compile(col[5], "x[:x_1]", checkparams={"x_1": 5})
@@ -1116,7 +1181,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getindex_sqlexpr(self):
-
         col = Column("x", self.MyType())
         col2 = Column("y", Integer())
 
@@ -1127,7 +1191,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getslice_literal(self):
-
         col = Column("x", self.MyType())
 
         self.assert_compile(
@@ -1135,7 +1198,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getslice_sqlexpr(self):
-
         col = Column("x", self.MyType())
         col2 = Column("y", Integer())
 
@@ -1144,13 +1206,11 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getindex_literal_zeroind(self):
-
         col = Column("x", self.MyType(zero_indexes=True))
 
         self.assert_compile(col[5], "x[:x_1]", checkparams={"x_1": 6})
 
     def test_getindex_sqlexpr_zeroind(self):
-
         col = Column("x", self.MyType(zero_indexes=True))
         col2 = Column("y", Integer())
 
@@ -1163,7 +1223,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getslice_literal_zeroind(self):
-
         col = Column("x", self.MyType(zero_indexes=True))
 
         self.assert_compile(
@@ -1171,7 +1230,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
     def test_getslice_sqlexpr_zeroind(self):
-
         col = Column("x", self.MyType(zero_indexes=True))
         col2 = Column("y", Integer())
 
@@ -2466,7 +2524,7 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         id_="iaa",
     )
     def test_math_op(self, py_op, sql_op):
-        for (lhs, rhs, res) in (
+        for lhs, rhs, res in (
             (5, self.table1.c.myid, ":myid_1 %s mytable.myid"),
             (5, literal(5), ":param_1 %s :param_2"),
             (self.table1.c.myid, "b", "mytable.myid %s :myid_1"),
@@ -2567,7 +2625,7 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     )
     def test_comparison_op(self, py_op, fwd_op, rev_op):
         dt = datetime.datetime(2012, 5, 10, 15, 27, 18)
-        for (lhs, rhs, l_sql, r_sql) in (
+        for lhs, rhs, l_sql, r_sql in (
             ("a", self.table1.c.myid, ":myid_1", "mytable.myid"),
             ("a", literal("b"), ":param_2", ":param_1"),  # note swap!
             (self.table1.c.myid, "b", "mytable.myid", ":myid_1"),
@@ -2584,7 +2642,6 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             (dt, literal("b"), ":param_2", ":param_1"),
             (literal("b"), dt, ":param_1", ":param_2"),
         ):
-
             # the compiled clause should match either (e.g.):
             # 'a' < 'b' -or- 'b' > 'a'.
             compiled = str(py_op(lhs, rhs))
@@ -2663,7 +2720,7 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     table1 = table("mytable", column("myid", Integer), column("name", String))
 
     def test_negate_operators_1(self):
-        for (py_op, op) in ((operator.neg, "-"), (operator.inv, "NOT ")):
+        for py_op, op in ((operator.neg, "-"), (operator.inv, "NOT ")):
             for expr, expected in (
                 (self.table1.c.myid, "mytable.myid"),
                 (literal("foo"), ":param_1"),
@@ -4400,7 +4457,6 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     )
     @testing.combinations("int", "array", argnames="datatype")
     def test_any_generic_null(self, datatype, expr, t_fixture):
-
         col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval
 
         self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name)