]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Register func.any(), func.all(), func.some() as collection aggregates
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Jun 2026 14:17:06 +0000 (10:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Jun 2026 14:25:33 +0000 (10:25 -0400)
Added CollectionAggregateFunction base class that sets
_is_collection_aggregate = True, and registered any_, all_, some_
as subclasses so that func.any(), func.all(), and func.some() correctly
prevent operator flipping on negation. Previously ~(col == func.any(arr))
would incorrectly compile to col != any(arr) instead of
NOT (col = any(arr)), which has different semantics for collection
aggregate comparison modifiers.

Also extended the _construct_for_op guard to check both left and right
operands for _is_collection_aggregate, since func.any(arr) can appear
on either side of a comparison unlike the standalone any_() construct
which auto-reverses operands.

Fixes: #13343
Change-Id: Id4774938876dc7f1f38cf143dccfe3c8ddba464d

doc/build/changelog/unreleased_21/13343.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_21/13343.rst b/doc/build/changelog/unreleased_21/13343.rst
new file mode 100644 (file)
index 0000000..cae9b12
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 13343
+
+    Fixed issue where negation of comparison expressions involving
+    ``func.any()``, ``func.all()``, and ``func.some()`` SQL functions would
+    incorrectly flip the comparison operator (e.g. ``=`` to ``!=``) rather
+    than wrapping the expression with ``NOT``. These functions are now
+    registered as collection aggregate functions that prevent operator
+    flipping on negation, consistent with the behavior of the standalone
+    :func:`_expression.any_` and :func:`_expression.all_` constructs.
index acfeadced6cd9378f8efa0e87868a4c3ad67a90f..1e04d8b254d99337e682613da3da4e74cfd043de 100644 (file)
@@ -3190,7 +3190,7 @@ class OperatorExpression(ColumnElement[_T]):
                     *(left_flattened + right_flattened),
                 )
 
-        if right._is_collection_aggregate:
+        if left._is_collection_aggregate or right._is_collection_aggregate:
             negate = None
 
         return BinaryExpression(
index 7b56720f207ab43334de8806724fe0c538b3b35e..37e079aa601f061fa84a835a1e6661b9387dcc9c 100644 (file)
@@ -1046,9 +1046,17 @@ class _FunctionGenerator:
         @property
         def aggregate_strings(self) -> Type[_aggregate_strings_func]: ...
 
+        @property
+        def all(self) -> Type[_all__func[Any]]:  # noqa: A001
+            ...
+
         @property
         def ansifunction(self) -> Type[_AnsiFunction_func[Any]]: ...
 
+        @property
+        def any(self) -> Type[_any__func[Any]]:  # noqa: A001
+            ...
+
         # set ColumnElement[_T] as a separate overload, to appease
         # mypy which seems to not want to accept _T from
         # _ColumnExpressionArgument. Seems somewhat related to the covariant
@@ -1307,6 +1315,9 @@ class _FunctionGenerator:
         @property
         def session_user(self) -> Type[_session_user_func]: ...
 
+        @property
+        def some(self) -> Type[_some_func[Any]]: ...
+
         # set ColumnElement[_T] as a separate overload, to appease
         # mypy which seems to not want to accept _T from
         # _ColumnExpressionArgument. Seems somewhat related to the covariant
@@ -1702,6 +1713,23 @@ class AnsiFunction(GenericFunction[_T]):
         GenericFunction.__init__(self, *args, **kwargs)
 
 
+class CollectionAggregateFunction(GenericFunction[_T]):
+    """Define a function that acts as a collection aggregate modifier.
+
+    Collection aggregate functions such as ``ANY``, ``ALL``, and ``SOME``
+    modify the semantics of comparison operators, so negation of comparisons
+    involving these functions must use ``NOT`` rather than flipping the
+    comparison operator.
+
+    .. versionadded:: 2.1
+
+    """
+
+    _is_collection_aggregate = True
+    _register = False
+    inherit_cache = True
+
+
 class ReturnTypeFromArgs(GenericFunction[_T]):
     """Define a function whose return type is bound to the type of its
     arguments.
@@ -2032,6 +2060,50 @@ class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
         super().__init__(*fn_args, **kwargs)
 
 
+class any_(CollectionAggregateFunction[_T]):
+    """The SQL ANY() collection aggregate function.
+
+    .. versionadded:: 2.1
+
+    .. seealso::
+
+        :func:`_expression.any_` - standalone ANY expression
+
+    """
+
+    name = "any"
+    identifier = "any"
+    inherit_cache = True
+
+
+class all_(CollectionAggregateFunction[_T]):
+    """The SQL ALL() collection aggregate function.
+
+    .. versionadded:: 2.1
+
+    .. seealso::
+
+        :func:`_expression.all_` - standalone ALL expression
+
+    """
+
+    name = "all"
+    identifier = "all"
+    inherit_cache = True
+
+
+class some(CollectionAggregateFunction[_T]):
+    """The SQL SOME() collection aggregate function.
+
+    SOME is a synonym for ANY in the SQL standard.
+
+    .. versionadded:: 2.1
+
+    """
+
+    inherit_cache = True
+
+
 class OrderedSetAgg(GenericFunction[_T]):
     """Define a function where the return type is based on the sort
     expression type as defined by the expression passed to the
@@ -2267,7 +2339,9 @@ class aggregate_strings(GenericFunction[str]):
 # name. See https://github.com/sqlalchemy/sqlalchemy/issues/13167
 # START GENERATED FUNCTION ALIASES
 _aggregate_strings_func: TypeAlias = aggregate_strings
+_all__func: TypeAlias = all_[_T]
 _AnsiFunction_func: TypeAlias = AnsiFunction[_T]
+_any__func: TypeAlias = any_[_T]
 _array_agg_func: TypeAlias = array_agg[_T]
 _Cast_func: TypeAlias = Cast[_T]
 _char_length_func: TypeAlias = char_length
@@ -2299,6 +2373,7 @@ _random_func: TypeAlias = random
 _rank_func: TypeAlias = rank
 _rollup_func: TypeAlias = rollup[_T]
 _session_user_func: TypeAlias = session_user
+_some_func: TypeAlias = some[_T]
 _sum_func: TypeAlias = sum[_T]
 _sysdate_func: TypeAlias = sysdate
 _user_func: TypeAlias = user
index b7969858ae2ab4c1acd13eab751f327b05a5b6bd..cbe1a10b1aebc91542b9f494a55133a4e88c7409 100644 (file)
@@ -19,6 +19,7 @@ from sqlalchemy import JSON
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
+from sqlalchemy import not_
 from sqlalchemy import Numeric
 from sqlalchemy import select
 from sqlalchemy import Sequence
@@ -1259,6 +1260,77 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
+class CollectionAggregateFunctionTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    @testing.combinations(
+        ("any", func.any),
+        ("all", func.all),
+        ("some", func.some),
+        id_="ia",
+    )
+    def test_is_collection_aggregate(self, fn):
+        c = column("x", Integer)
+        expr = fn(c)
+        is_(expr._is_collection_aggregate, True)
+
+    @testing.combinations(
+        ("any", func.any, "any"),
+        ("all", func.all, "all"),
+        ("some", func.some, "some"),
+        id_="iaa",
+    )
+    def test_negate_rhs(self, fn, name):
+        c = column("x", Integer)
+        arr = column("arr", Integer)
+        self.assert_compile(
+            ~(c == fn(arr)),
+            "NOT (x = %s(arr))" % name,
+        )
+
+    @testing.combinations(
+        ("any", func.any, "any"),
+        ("all", func.all, "all"),
+        ("some", func.some, "some"),
+        id_="iaa",
+    )
+    def test_negate_lhs(self, fn, name):
+        c = column("x", Integer)
+        arr = column("arr", Integer)
+        self.assert_compile(
+            ~(fn(arr) == c),
+            "NOT (%s(arr) = x)" % name,
+        )
+
+    @testing.combinations(
+        ("any", func.any, "any"),
+        ("all", func.all, "all"),
+        ("some", func.some, "some"),
+        id_="iaa",
+    )
+    def test_not_function(self, fn, name):
+        c = column("x", Integer)
+        arr = column("arr", Integer)
+        self.assert_compile(
+            not_(c == fn(arr)),
+            "NOT (x = %s(arr))" % name,
+        )
+
+    @testing.combinations(
+        ("any", func.any, "any"),
+        ("all", func.all, "all"),
+        ("some", func.some, "some"),
+        id_="iaa",
+    )
+    def test_ne_not_affected(self, fn, name):
+        c = column("x", Integer)
+        arr = column("arr", Integer)
+        self.assert_compile(
+            c != fn(arr),
+            "x != %s(arr)" % name,
+        )
+
+
 class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
     def test_array_agg(self):
         expr = func.array_agg(column("data", Integer))