]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactor any_ / all_
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Jan 2024 16:16:13 +0000 (11:16 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Jan 2024 20:11:32 +0000 (15:11 -0500)
Improved compilation of :func:`_sql.any_` / :func:`_sql.all_` in the
context of a negation of boolean comparison, will now render ``NOT (expr)``
rather than reversing the equality operator to not equals, allowing
finer-grained control of negations for these non-typical operators.

Fixes: #10817
Change-Id: If0b324b1220ad3c7f053af91e8a61c81015f312a

doc/build/changelog/unreleased_20/10817.rst [new file with mode: 0644]
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_20/10817.rst b/doc/build/changelog/unreleased_20/10817.rst
new file mode 100644 (file)
index 0000000..69634d0
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 10817
+
+    Improved compilation of :func:`_sql.any_` / :func:`_sql.all_` in the
+    context of a negation of boolean comparison, will now render ``NOT (expr)``
+    rather than reversing the equality operator to not equals, allowing
+    finer-grained control of negations for these non-typical operators.
index 939b14c5d4cf024f59a301d7e8dcc446407bd37f..072acafed304aadb9de455827d1538a65b6d8184 100644 (file)
@@ -56,7 +56,6 @@ def _boolean_compare(
     negate_op: Optional[OperatorType] = None,
     reverse: bool = False,
     _python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
-    _any_all_expr: bool = False,
     result_type: Optional[TypeEngine[bool]] = None,
     **kwargs: Any,
 ) -> OperatorExpression[bool]:
@@ -90,7 +89,7 @@ def _boolean_compare(
                 negate=negate_op,
                 modifiers=kwargs,
             )
-        elif _any_all_expr:
+        elif expr._is_collection_aggregate:
             obj = coercions.expect(
                 roles.ConstExprRole, element=obj, operator=op, expr=expr
             )
index e6d7ad7da8d9b3ce44c2aaa57ed64db0160a1659..45eb8f3c55be13d47ae7550b672dff2a2b47b1b3 100644 (file)
@@ -803,6 +803,7 @@ class CompilerColumnElement(
     __slots__ = ()
 
     _propagate_attrs = util.EMPTY_DICT
+    _is_collection_aggregate = False
 
 
 # SQLCoreOperations should be suiting the ExpressionElementRole
@@ -1407,6 +1408,7 @@ class ColumnElement(
     _is_column_element = True
     _insert_sentinel: bool = False
     _omit_from_statements = False
+    _is_collection_aggregate = False
 
     foreign_keys: AbstractSet[ForeignKey] = frozenset()
 
@@ -2361,6 +2363,8 @@ class TextClause(
 
     _omit_from_statements = False
 
+    _is_collection_aggregate = False
+
     @property
     def _hide_froms(self) -> Iterable[FromClause]:
         return ()
@@ -2966,6 +2970,9 @@ class OperatorExpression(ColumnElement[_T]):
                     *(left_flattened + right_flattened),
                 )
 
+        if right._is_collection_aggregate:
+            negate = None
+
         return BinaryExpression(
             left, right, op, type_=type_, negate=negate, modifiers=modifiers
         )
@@ -3804,6 +3811,7 @@ class CollectionAggregate(UnaryExpression[_T]):
     """
 
     inherit_cache = True
+    _is_collection_aggregate = True
 
     @classmethod
     def _create_any(
@@ -3845,7 +3853,7 @@ class CollectionAggregate(UnaryExpression[_T]):
             raise exc.ArgumentError(
                 "Only comparison operators may be used with ANY/ALL"
             )
-        kwargs["reverse"] = kwargs["_any_all_expr"] = True
+        kwargs["reverse"] = True
         return self.comparator.operate(operators.mirror(op), *other, **kwargs)
 
     def reverse_operate(self, op, other, **kwargs):
@@ -4033,7 +4041,7 @@ class BinaryExpression(OperatorExpression[_T]):
                 modifiers=self.modifiers,
             )
         else:
-            return super()._negate()
+            return self.self_group()._negate()
 
 
 class Slice(ColumnElement[Any]):
index d91f7607063ae0af866eec5953d4e6d04263cd51..53fad3ea21176e1236dd784448e6176de4676e6d 100644 (file)
@@ -1819,10 +1819,10 @@ class ColumnOperators(Operators):
         See the documentation for :func:`_sql.any_` for examples.
 
         .. note:: be sure to not confuse the newer
-            :meth:`_sql.ColumnOperators.any_` method with its older
-            :class:`_types.ARRAY`-specific counterpart, the
-            :meth:`_types.ARRAY.Comparator.any` method, which a different
-            calling syntax and usage pattern.
+            :meth:`_sql.ColumnOperators.any_` method with the **legacy**
+            version of this method, the :meth:`_types.ARRAY.Comparator.any`
+            method that's specific to :class:`_types.ARRAY`, which uses a
+            different calling style.
 
         """
         return self.operate(any_op)
@@ -1834,10 +1834,10 @@ class ColumnOperators(Operators):
         See the documentation for :func:`_sql.all_` for examples.
 
         .. note:: be sure to not confuse the newer
-            :meth:`_sql.ColumnOperators.all_` method with its older
-            :class:`_types.ARRAY`-specific counterpart, the
-            :meth:`_types.ARRAY.Comparator.all` method, which a different
-            calling syntax and usage pattern.
+            :meth:`_sql.ColumnOperators.all_` method with the **legacy**
+            version of this method, the :meth:`_types.ARRAY.Comparator.all`
+            method that's specific to :class:`_types.ARRAY`, which uses a
+            different calling style.
 
         """
         return self.operate(all_op)
index 91e382de6944ca4b26e907ddd9b5e339a876c760..0963e8ed200244a2ce41b22a14a7cb8ba2c597f8 100644 (file)
@@ -2924,7 +2924,7 @@ class ARRAY(
         def any(self, other, operator=None):
             """Return ``other operator ANY (array)`` clause.
 
-            .. note:: This method is an :class:`_types.ARRAY` - specific
+            .. legacy:: This method is an :class:`_types.ARRAY` - specific
                 construct that is now superseded by the :func:`_sql.any_`
                 function, which features a different calling style. The
                 :func:`_sql.any_` function is also mirrored at the method level
@@ -2958,9 +2958,8 @@ class ARRAY(
 
             arr_type = self.type
 
-            # send plain BinaryExpression so that negate remains at None,
-            # leading to NOT expr for negation.
-            return elements.BinaryExpression(
+            return elements.CollectionAggregate._create_any(self.expr).operate(
+                operators.mirror(operator),
                 coercions.expect(
                     roles.BinaryElementRole,
                     element=other,
@@ -2968,19 +2967,17 @@ class ARRAY(
                     expr=self.expr,
                     bindparam_type=arr_type.item_type,
                 ),
-                elements.CollectionAggregate._create_any(self.expr),
-                operator,
             )
 
         @util.preload_module("sqlalchemy.sql.elements")
         def all(self, other, operator=None):
             """Return ``other operator ALL (array)`` clause.
 
-            .. note:: This method is an :class:`_types.ARRAY` - specific
-                construct that is now superseded by the :func:`_sql.any_`
+            .. legacy:: This method is an :class:`_types.ARRAY` - specific
+                construct that is now superseded by the :func:`_sql.all_`
                 function, which features a different calling style. The
-                :func:`_sql.any_` function is also mirrored at the method level
-                via the :meth:`_sql.ColumnOperators.any_` method.
+                :func:`_sql.all_` function is also mirrored at the method level
+                via the :meth:`_sql.ColumnOperators.all_` method.
 
             Usage of array-specific :meth:`_types.ARRAY.Comparator.all`
             is as follows::
@@ -3010,9 +3007,8 @@ class ARRAY(
 
             arr_type = self.type
 
-            # send plain BinaryExpression so that negate remains at None,
-            # leading to NOT expr for negation.
-            return elements.BinaryExpression(
+            return elements.CollectionAggregate._create_all(self.expr).operate(
+                operators.mirror(operator),
                 coercions.expect(
                     roles.BinaryElementRole,
                     element=other,
@@ -3020,8 +3016,6 @@ class ARRAY(
                     expr=self.expr,
                     bindparam_type=arr_type.item_type,
                 ),
-                elements.CollectionAggregate._create_all(self.expr),
-                operator,
             )
 
     comparator_factory = Comparator
index af51010c761a0169db4326ac419d8cc74d24fd33..7e61920aa292f6ae90da1c53fa19ff6a06d52191 100644 (file)
@@ -4540,7 +4540,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
         return t
 
-    @testing.combinations(
+    null_comparisons = testing.combinations(
         lambda col: any_(col) == None,
         lambda col: col.any_() == None,
         lambda col: any_(col) == null(),
@@ -4551,12 +4551,23 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         lambda col: None == col.any_(),
         argnames="expr",
     )
+
+    @null_comparisons
     @testing.combinations("int", "array", argnames="datatype")
     def test_any_generic_null(self, datatype, expr, t_fixture):
         col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval
 
         self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name)
 
+    @null_comparisons
+    @testing.combinations("int", "array", argnames="datatype")
+    def test_any_generic_null_negate(self, datatype, expr, t_fixture):
+        col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval
+
+        self.assert_compile(
+            ~expr(col), "NOT (NULL = ANY (tab1.%s))" % col.name
+        )
+
     @testing.fixture(
         params=[
             ("ANY", any_),
@@ -4565,48 +4576,78 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             ("ALL", lambda x: x.all_()),
         ]
     )
-    def operator(self, request):
+    def any_all_operators(self, request):
         return request.param
 
+    # test legacy array any() / all().  these are superseded by the
+    # any_() / all_() versions
     @testing.fixture(
         params=[
             ("ANY", lambda x, *o: x.any(*o)),
             ("ALL", lambda x, *o: x.all(*o)),
         ]
     )
-    def array_op(self, request):
+    def legacy_any_all_operators(self, request):
         return request.param
 
-    def test_array(self, t_fixture, operator):
+    def test_array(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
         self.assert_compile(
             5 == fn(t.c.arrval),
             f":param_1 = {op} (tab1.arrval)",
             checkparams={"param_1": 5},
         )
 
-    def test_comparator_array(self, t_fixture, operator):
+    def test_comparator_inline_negate(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
+        self.assert_compile(
+            5 != fn(t.c.arrval),
+            f":param_1 != {op} (tab1.arrval)",
+            checkparams={"param_1": 5},
+        )
+
+    @testing.combinations(
+        (operator.eq, "="),
+        (operator.ne, "!="),
+        (operator.gt, ">"),
+        (operator.le, "<="),
+        argnames="operator,opstring",
+    )
+    def test_comparator_outer_negate(
+        self, t_fixture, any_all_operators, operator, opstring
+    ):
+        """test #10817"""
+        t = t_fixture
+        op, fn = any_all_operators
+        self.assert_compile(
+            ~(operator(5, fn(t.c.arrval))),
+            f"NOT (:param_1 {opstring} {op} (tab1.arrval))",
+            checkparams={"param_1": 5},
+        )
+
+    def test_comparator_array(self, t_fixture, any_all_operators):
+        t = t_fixture
+        op, fn = any_all_operators
         self.assert_compile(
             5 > fn(t.c.arrval),
             f":param_1 > {op} (tab1.arrval)",
             checkparams={"param_1": 5},
         )
 
-    def test_comparator_array_wexpr(self, t_fixture, operator):
+    def test_comparator_array_wexpr(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
         self.assert_compile(
             t.c.data > fn(t.c.arrval),
             f"tab1.data > {op} (tab1.arrval)",
             checkparams={},
         )
 
-    def test_illegal_ops(self, t_fixture, operator):
+    def test_illegal_ops(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
 
         assert_raises_message(
             exc.ArgumentError,
@@ -4622,10 +4663,10 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             t.c.data + fn(t.c.arrval), f"tab1.data + {op} (tab1.arrval)"
         )
 
-    def test_bindparam_coercion(self, t_fixture, array_op):
+    def test_bindparam_coercion(self, t_fixture, legacy_any_all_operators):
         """test #7979"""
         t = t_fixture
-        op, fn = array_op
+        op, fn = legacy_any_all_operators
 
         expr = fn(t.c.arrval, bindparam("param"))
         expected = f"%(param)s = {op} (tab1.arrval)"
@@ -4633,9 +4674,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(expr, expected, dialect="postgresql")
 
-    def test_array_comparator_accessor(self, t_fixture, array_op):
+    def test_array_comparator_accessor(
+        self, t_fixture, legacy_any_all_operators
+    ):
         t = t_fixture
-        op, fn = array_op
+        op, fn = legacy_any_all_operators
 
         self.assert_compile(
             fn(t.c.arrval, 5, operator.gt),
@@ -4643,9 +4686,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"arrval_1": 5},
         )
 
-    def test_array_comparator_negate_accessor(self, t_fixture, array_op):
+    def test_array_comparator_negate_accessor(
+        self, t_fixture, legacy_any_all_operators
+    ):
         t = t_fixture
-        op, fn = array_op
+        op, fn = legacy_any_all_operators
 
         self.assert_compile(
             ~fn(t.c.arrval, 5, operator.gt),
@@ -4653,9 +4698,9 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"arrval_1": 5},
         )
 
-    def test_array_expression(self, t_fixture, operator):
+    def test_array_expression(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
 
         self.assert_compile(
             5 == fn(t.c.arrval[5:6] + postgresql.array([3, 4])),
@@ -4671,9 +4716,9 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             dialect="postgresql",
         )
 
-    def test_subq(self, t_fixture, operator):
+    def test_subq(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
 
         self.assert_compile(
             5 == fn(select(t.c.data).where(t.c.data < 10).scalar_subquery()),
@@ -4682,9 +4727,9 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"data_1": 10, "param_1": 5},
         )
 
-    def test_scalar_values(self, t_fixture, operator):
+    def test_scalar_values(self, t_fixture, any_all_operators):
         t = t_fixture
-        op, fn = operator
+        op, fn = any_all_operators
 
         self.assert_compile(
             5 == fn(values(t.c.data).data([(1,), (42,)]).scalar_values()),