From: Mike Bayer Date: Thu, 4 Jun 2026 14:17:06 +0000 (-0400) Subject: Register func.any(), func.all(), func.some() as collection aggregates X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=b14415fa43dd676dd215fec0c18e756db43a70ff;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Register func.any(), func.all(), func.some() as collection aggregates Added CollectionAggregateFunction base class that sets _is_collection_aggregate = True, and registered any_, all_, some_ as subclasses so that func.any(), func.all(), and func.some() correctly prevent operator flipping on negation. Previously ~(col == func.any(arr)) would incorrectly compile to col != any(arr) instead of NOT (col = any(arr)), which has different semantics for collection aggregate comparison modifiers. Also extended the _construct_for_op guard to check both left and right operands for _is_collection_aggregate, since func.any(arr) can appear on either side of a comparison unlike the standalone any_() construct which auto-reverses operands. Fixes: #13343 Change-Id: Id4774938876dc7f1f38cf143dccfe3c8ddba464d --- diff --git a/doc/build/changelog/unreleased_21/13343.rst b/doc/build/changelog/unreleased_21/13343.rst new file mode 100644 index 0000000000..cae9b12458 --- /dev/null +++ b/doc/build/changelog/unreleased_21/13343.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, sql + :tickets: 13343 + + Fixed issue where negation of comparison expressions involving + ``func.any()``, ``func.all()``, and ``func.some()`` SQL functions would + incorrectly flip the comparison operator (e.g. ``=`` to ``!=``) rather + than wrapping the expression with ``NOT``. These functions are now + registered as collection aggregate functions that prevent operator + flipping on negation, consistent with the behavior of the standalone + :func:`_expression.any_` and :func:`_expression.all_` constructs. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index acfeadced6..1e04d8b254 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3190,7 +3190,7 @@ class OperatorExpression(ColumnElement[_T]): *(left_flattened + right_flattened), ) - if right._is_collection_aggregate: + if left._is_collection_aggregate or right._is_collection_aggregate: negate = None return BinaryExpression( diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 7b56720f20..37e079aa60 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1046,9 +1046,17 @@ class _FunctionGenerator: @property def aggregate_strings(self) -> Type[_aggregate_strings_func]: ... + @property + def all(self) -> Type[_all__func[Any]]: # noqa: A001 + ... + @property def ansifunction(self) -> Type[_AnsiFunction_func[Any]]: ... + @property + def any(self) -> Type[_any__func[Any]]: # noqa: A001 + ... + # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from # _ColumnExpressionArgument. Seems somewhat related to the covariant @@ -1307,6 +1315,9 @@ class _FunctionGenerator: @property def session_user(self) -> Type[_session_user_func]: ... + @property + def some(self) -> Type[_some_func[Any]]: ... + # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from # _ColumnExpressionArgument. Seems somewhat related to the covariant @@ -1702,6 +1713,23 @@ class AnsiFunction(GenericFunction[_T]): GenericFunction.__init__(self, *args, **kwargs) +class CollectionAggregateFunction(GenericFunction[_T]): + """Define a function that acts as a collection aggregate modifier. + + Collection aggregate functions such as ``ANY``, ``ALL``, and ``SOME`` + modify the semantics of comparison operators, so negation of comparisons + involving these functions must use ``NOT`` rather than flipping the + comparison operator. + + .. versionadded:: 2.1 + + """ + + _is_collection_aggregate = True + _register = False + inherit_cache = True + + class ReturnTypeFromArgs(GenericFunction[_T]): """Define a function whose return type is bound to the type of its arguments. @@ -2032,6 +2060,50 @@ class array_agg(ReturnTypeFromArgs[Sequence[_T]]): super().__init__(*fn_args, **kwargs) +class any_(CollectionAggregateFunction[_T]): + """The SQL ANY() collection aggregate function. + + .. versionadded:: 2.1 + + .. seealso:: + + :func:`_expression.any_` - standalone ANY expression + + """ + + name = "any" + identifier = "any" + inherit_cache = True + + +class all_(CollectionAggregateFunction[_T]): + """The SQL ALL() collection aggregate function. + + .. versionadded:: 2.1 + + .. seealso:: + + :func:`_expression.all_` - standalone ALL expression + + """ + + name = "all" + identifier = "all" + inherit_cache = True + + +class some(CollectionAggregateFunction[_T]): + """The SQL SOME() collection aggregate function. + + SOME is a synonym for ANY in the SQL standard. + + .. versionadded:: 2.1 + + """ + + inherit_cache = True + + class OrderedSetAgg(GenericFunction[_T]): """Define a function where the return type is based on the sort expression type as defined by the expression passed to the @@ -2267,7 +2339,9 @@ class aggregate_strings(GenericFunction[str]): # name. See https://github.com/sqlalchemy/sqlalchemy/issues/13167 # START GENERATED FUNCTION ALIASES _aggregate_strings_func: TypeAlias = aggregate_strings +_all__func: TypeAlias = all_[_T] _AnsiFunction_func: TypeAlias = AnsiFunction[_T] +_any__func: TypeAlias = any_[_T] _array_agg_func: TypeAlias = array_agg[_T] _Cast_func: TypeAlias = Cast[_T] _char_length_func: TypeAlias = char_length @@ -2299,6 +2373,7 @@ _random_func: TypeAlias = random _rank_func: TypeAlias = rank _rollup_func: TypeAlias = rollup[_T] _session_user_func: TypeAlias = session_user +_some_func: TypeAlias = some[_T] _sum_func: TypeAlias = sum[_T] _sysdate_func: TypeAlias = sysdate _user_func: TypeAlias = user diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index b7969858ae..cbe1a10b1a 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -19,6 +19,7 @@ from sqlalchemy import JSON from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData +from sqlalchemy import not_ from sqlalchemy import Numeric from sqlalchemy import select from sqlalchemy import Sequence @@ -1259,6 +1260,77 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) +class CollectionAggregateFunctionTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + @testing.combinations( + ("any", func.any), + ("all", func.all), + ("some", func.some), + id_="ia", + ) + def test_is_collection_aggregate(self, fn): + c = column("x", Integer) + expr = fn(c) + is_(expr._is_collection_aggregate, True) + + @testing.combinations( + ("any", func.any, "any"), + ("all", func.all, "all"), + ("some", func.some, "some"), + id_="iaa", + ) + def test_negate_rhs(self, fn, name): + c = column("x", Integer) + arr = column("arr", Integer) + self.assert_compile( + ~(c == fn(arr)), + "NOT (x = %s(arr))" % name, + ) + + @testing.combinations( + ("any", func.any, "any"), + ("all", func.all, "all"), + ("some", func.some, "some"), + id_="iaa", + ) + def test_negate_lhs(self, fn, name): + c = column("x", Integer) + arr = column("arr", Integer) + self.assert_compile( + ~(fn(arr) == c), + "NOT (%s(arr) = x)" % name, + ) + + @testing.combinations( + ("any", func.any, "any"), + ("all", func.all, "all"), + ("some", func.some, "some"), + id_="iaa", + ) + def test_not_function(self, fn, name): + c = column("x", Integer) + arr = column("arr", Integer) + self.assert_compile( + not_(c == fn(arr)), + "NOT (x = %s(arr))" % name, + ) + + @testing.combinations( + ("any", func.any, "any"), + ("all", func.all, "all"), + ("some", func.some, "some"), + id_="iaa", + ) + def test_ne_not_affected(self, fn, name): + c = column("x", Integer) + arr = column("arr", Integer) + self.assert_compile( + c != fn(arr), + "x != %s(arr)" % name, + ) + + class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): def test_array_agg(self): expr = func.array_agg(column("data", Integer))