From aaa85f707e312bbf21b21926b0901aa14e3f3856 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 20 Sep 2025 14:08:55 -0400 Subject: [PATCH] Use ARRAY type for any_(), all_() coercion Fixed issue where the :func:`_sql.any_` and :func:`_sql.all_` aggregation operators would not correctly coerce the datatype of the compared value, in those cases where the compared value were not a simple int/str etc., such as a Python ``Enum`` or other custom value. This would lead to execution time errors for these values. This issue is essentially the same as :ticket:`6515` which was for the now-legacy :meth:`.ARRAY.any` and :meth:`.ARRAY.all` methods. Fixes: #12874 Change-Id: I980894c23b9974bc84d584a1a4c5fae72dded6d3 --- doc/build/changelog/unreleased_20/12874.rst | 11 +++ lib/sqlalchemy/sql/elements.py | 35 +++++++++ test/dialect/postgresql/test_types.py | 21 +++++ test/sql/test_operators.py | 86 +++++++++++++++++++++ 4 files changed, 153 insertions(+) create mode 100644 doc/build/changelog/unreleased_20/12874.rst diff --git a/doc/build/changelog/unreleased_20/12874.rst b/doc/build/changelog/unreleased_20/12874.rst new file mode 100644 index 0000000000..2d802203ec --- /dev/null +++ b/doc/build/changelog/unreleased_20/12874.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, postgresql + :tickets: 12874 + + Fixed issue where the :func:`_sql.any_` and :func:`_sql.all_` aggregation + operators would not correctly coerce the datatype of the compared value, in + those cases where the compared value were not a simple int/str etc., such + as a Python ``Enum`` or other custom value. This would lead to execution + time errors for these values. This issue is essentially the same as + :ticket:`6515` which was for the now-legacy :meth:`.ARRAY.any` and + :meth:`.ARRAY.all` methods. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index fbb2f8632b..e8a830b2b4 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3925,6 +3925,8 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_any( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: + """create CollectionAggregate for the legacy + ARRAY.Comparator.any() method""" col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, @@ -3940,6 +3942,8 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_all( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: + """create CollectionAggregate for the legacy + ARRAY.Comparator.all() method""" col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, @@ -3951,6 +3955,37 @@ class CollectionAggregate(UnaryExpression[_T]): type_=type_api.BOOLEANTYPE, ) + @util.preload_module("sqlalchemy.sql.sqltypes") + def _bind_param( + self, + operator: operators.OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + ) -> BindParameter[_T]: + """For new style any_(), all_(), ensure compared literal value + receives appropriate bound parameter type.""" + + # a CollectionAggregate is specific to ARRAY or int + # only. So for ARRAY case, make sure we use correct element type + sqltypes = util.preloaded.sql_sqltypes + if self.element.type._type_affinity is sqltypes.ARRAY: + compared_to_type = cast( + sqltypes.ARRAY[Any], self.element.type + ).item_type + else: + compared_to_type = self.element.type + + return BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=compared_to_type, + unique=True, + expanding=expanding, + ) + # operate and reverse_operate are hardwired to # dispatch onto the type comparator directly, so that we can # ensure "reversed" behavior. diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index a3642003da..5cacf015ec 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -10,6 +10,7 @@ import re import uuid import sqlalchemy as sa +from sqlalchemy import all_ from sqlalchemy import any_ from sqlalchemy import ARRAY from sqlalchemy import cast @@ -3236,6 +3237,26 @@ class ArrayEnum(fixtures.TestBase): @testing.combinations("all", "any", argnames="fn") def test_any_all_roundtrip( self, array_of_enum_fixture, connection, array_cls, enum_cls, fn + ): + """test for #12874. originally from the legacy use case in #6515""" + + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) + + if fn == "all": + expr = MyEnum.b == all_(tbl.c.pyenum_col) + result = [([MyEnum.b],)] + elif fn == "any": + expr = MyEnum.b == any_(tbl.c.pyenum_col) + result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)] + else: + assert False + sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id) + eq_(connection.execute(sel).fetchall(), result) + + @_enum_combinations + @testing.combinations("all", "any", argnames="fn") + def test_any_all_legacy_roundtrip( + self, array_of_enum_fixture, connection, array_cls, enum_cls, fn ): """test #6515""" diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 7ce305de01..51046d2458 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -1,5 +1,6 @@ import collections.abc as collections_abc import datetime +import enum import operator import pickle import re @@ -13,6 +14,7 @@ from sqlalchemy import bindparam from sqlalchemy import bitwise_not from sqlalchemy import desc from sqlalchemy import distinct +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import Integer @@ -4834,6 +4836,12 @@ class InSelectableTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) +class MyEnum(enum.Enum): + ONE = enum.auto() + TWO = enum.auto() + THREE = enum.auto() + + class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -4845,6 +4853,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): "tab1", m, Column("arrval", ARRAY(Integer)), + Column("arrenum", ARRAY(Enum(MyEnum))), + Column("arrstring", ARRAY(String)), Column("data", Integer), ) return t @@ -4877,6 +4887,82 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): ~expr(col), "NOT (NULL = ANY (tab1.%s))" % col.name ) + @testing.variation("operator", ["any", "all"]) + @testing.variation( + "datatype", ["int", "array", "arraystring", "arrayenum"] + ) + def test_what_type_is_any_all( + self, + datatype: testing.Variation, + t_fixture, + operator: testing.Variation, + ): + """test for #12874""" + + if datatype.int: + col = t_fixture.c.data + value = 5 + expected_type_affinity = Integer + elif datatype.array: + col = t_fixture.c.arrval + value = 25 + expected_type_affinity = Integer + elif datatype.arraystring: + col = t_fixture.c.arrstring + value = "a string" + expected_type_affinity = String + elif datatype.arrayenum: + col = t_fixture.c.arrenum + value = MyEnum.TWO + expected_type_affinity = Enum + else: + datatype.fail() + + if operator.any: + boolean_expr = value == any_(col) + elif operator.all: + boolean_expr = value == all_(col) + else: + operator.fail() + + # using isinstance so things work out for Enum which has type affinity + # of String + assert isinstance(boolean_expr.left.type, expected_type_affinity) + + @testing.variation("operator", ["any", "all"]) + @testing.variation("datatype", ["array", "arraystring", "arrayenum"]) + def test_what_type_is_legacy_any_all( + self, + datatype: testing.Variation, + t_fixture, + operator: testing.Variation, + ): + if datatype.array: + col = t_fixture.c.arrval + value = 25 + expected_type_affinity = Integer + elif datatype.arraystring: + col = t_fixture.c.arrstring + value = "a string" + expected_type_affinity = String + elif datatype.arrayenum: + col = t_fixture.c.arrenum + value = MyEnum.TWO + expected_type_affinity = Enum + else: + datatype.fail() + + if operator.any: + boolean_expr = col.any(value) + elif operator.all: + boolean_expr = col.all(value) + else: + operator.fail() + + # using isinstance so things work out for Enum which has type affinity + # of String + assert isinstance(boolean_expr.left.type, expected_type_affinity) + @testing.fixture( params=[ ("ANY", any_), -- 2.47.3