]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use ARRAY type for any_(), all_() coercion
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Sep 2025 18:08:55 +0000 (14:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Sep 2025 18:24:48 +0000 (14:24 -0400)
Fixed issue where the :func:`_sql.any_` and :func:`_sql.all_` aggregation
operators would not correctly coerce the datatype of the compared value, in
those cases where the compared value were not a simple int/str etc., such
as a Python ``Enum`` or other custom value.   This would lead to execution
time errors for these values.  This issue is essentially the same as
:ticket:`6515` which was for the now-legacy :meth:`.ARRAY.any` and
:meth:`.ARRAY.all` methods.

Fixes: #12874
Change-Id: I980894c23b9974bc84d584a1a4c5fae72dded6d3

doc/build/changelog/unreleased_20/12874.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/dialect/postgresql/test_types.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_20/12874.rst b/doc/build/changelog/unreleased_20/12874.rst
new file mode 100644 (file)
index 0000000..2d80220
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 12874
+
+    Fixed issue where the :func:`_sql.any_` and :func:`_sql.all_` aggregation
+    operators would not correctly coerce the datatype of the compared value, in
+    those cases where the compared value were not a simple int/str etc., such
+    as a Python ``Enum`` or other custom value.   This would lead to execution
+    time errors for these values.  This issue is essentially the same as
+    :ticket:`6515` which was for the now-legacy :meth:`.ARRAY.any` and
+    :meth:`.ARRAY.all` methods.
index fbb2f8632b7d00f6d022932524e944ae3e254531..e8a830b2b46da053c428df530829415c019d938d 100644 (file)
@@ -3925,6 +3925,8 @@ class CollectionAggregate(UnaryExpression[_T]):
     def _create_any(
         cls, expr: _ColumnExpressionArgument[_T]
     ) -> CollectionAggregate[bool]:
+        """create CollectionAggregate for the legacy
+        ARRAY.Comparator.any() method"""
         col_expr: ColumnElement[_T] = coercions.expect(
             roles.ExpressionElementRole,
             expr,
@@ -3940,6 +3942,8 @@ class CollectionAggregate(UnaryExpression[_T]):
     def _create_all(
         cls, expr: _ColumnExpressionArgument[_T]
     ) -> CollectionAggregate[bool]:
+        """create CollectionAggregate for the legacy
+        ARRAY.Comparator.all() method"""
         col_expr: ColumnElement[_T] = coercions.expect(
             roles.ExpressionElementRole,
             expr,
@@ -3951,6 +3955,37 @@ class CollectionAggregate(UnaryExpression[_T]):
             type_=type_api.BOOLEANTYPE,
         )
 
+    @util.preload_module("sqlalchemy.sql.sqltypes")
+    def _bind_param(
+        self,
+        operator: operators.OperatorType,
+        obj: Any,
+        type_: Optional[TypeEngine[_T]] = None,
+        expanding: bool = False,
+    ) -> BindParameter[_T]:
+        """For new style any_(), all_(), ensure compared literal value
+        receives appropriate bound parameter type."""
+
+        # a CollectionAggregate is specific to ARRAY or int
+        # only.  So for ARRAY case, make sure we use correct element type
+        sqltypes = util.preloaded.sql_sqltypes
+        if self.element.type._type_affinity is sqltypes.ARRAY:
+            compared_to_type = cast(
+                sqltypes.ARRAY[Any], self.element.type
+            ).item_type
+        else:
+            compared_to_type = self.element.type
+
+        return BindParameter(
+            None,
+            obj,
+            _compared_to_operator=operator,
+            type_=type_,
+            _compared_to_type=compared_to_type,
+            unique=True,
+            expanding=expanding,
+        )
+
     # operate and reverse_operate are hardwired to
     # dispatch onto the type comparator directly, so that we can
     # ensure "reversed" behavior.
index a3642003da87f64735ade443d09aa6b40b89d9c4..5cacf015ec08296b7e889fe953dc24bfc4af8305 100644 (file)
@@ -10,6 +10,7 @@ import re
 import uuid
 
 import sqlalchemy as sa
+from sqlalchemy import all_
 from sqlalchemy import any_
 from sqlalchemy import ARRAY
 from sqlalchemy import cast
@@ -3236,6 +3237,26 @@ class ArrayEnum(fixtures.TestBase):
     @testing.combinations("all", "any", argnames="fn")
     def test_any_all_roundtrip(
         self, array_of_enum_fixture, connection, array_cls, enum_cls, fn
+    ):
+        """test for #12874. originally from the legacy use case in #6515"""
+
+        tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+        if fn == "all":
+            expr = MyEnum.b == all_(tbl.c.pyenum_col)
+            result = [([MyEnum.b],)]
+        elif fn == "any":
+            expr = MyEnum.b == any_(tbl.c.pyenum_col)
+            result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)]
+        else:
+            assert False
+        sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id)
+        eq_(connection.execute(sel).fetchall(), result)
+
+    @_enum_combinations
+    @testing.combinations("all", "any", argnames="fn")
+    def test_any_all_legacy_roundtrip(
+        self, array_of_enum_fixture, connection, array_cls, enum_cls, fn
     ):
         """test #6515"""
 
index 7ce305de01efe35508e7848fef80030686645938..51046d2458394e8dd43780610f648b24be43ef7d 100644 (file)
@@ -1,5 +1,6 @@
 import collections.abc as collections_abc
 import datetime
+import enum
 import operator
 import pickle
 import re
@@ -13,6 +14,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import bitwise_not
 from sqlalchemy import desc
 from sqlalchemy import distinct
+from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import Integer
@@ -4834,6 +4836,12 @@ class InSelectableTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
 
+class MyEnum(enum.Enum):
+    ONE = enum.auto()
+    TWO = enum.auto()
+    THREE = enum.auto()
+
+
 class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
 
@@ -4845,6 +4853,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "tab1",
             m,
             Column("arrval", ARRAY(Integer)),
+            Column("arrenum", ARRAY(Enum(MyEnum))),
+            Column("arrstring", ARRAY(String)),
             Column("data", Integer),
         )
         return t
@@ -4877,6 +4887,82 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             ~expr(col), "NOT (NULL = ANY (tab1.%s))" % col.name
         )
 
+    @testing.variation("operator", ["any", "all"])
+    @testing.variation(
+        "datatype", ["int", "array", "arraystring", "arrayenum"]
+    )
+    def test_what_type_is_any_all(
+        self,
+        datatype: testing.Variation,
+        t_fixture,
+        operator: testing.Variation,
+    ):
+        """test for #12874"""
+
+        if datatype.int:
+            col = t_fixture.c.data
+            value = 5
+            expected_type_affinity = Integer
+        elif datatype.array:
+            col = t_fixture.c.arrval
+            value = 25
+            expected_type_affinity = Integer
+        elif datatype.arraystring:
+            col = t_fixture.c.arrstring
+            value = "a string"
+            expected_type_affinity = String
+        elif datatype.arrayenum:
+            col = t_fixture.c.arrenum
+            value = MyEnum.TWO
+            expected_type_affinity = Enum
+        else:
+            datatype.fail()
+
+        if operator.any:
+            boolean_expr = value == any_(col)
+        elif operator.all:
+            boolean_expr = value == all_(col)
+        else:
+            operator.fail()
+
+        # using isinstance so things work out for Enum which has type affinity
+        # of String
+        assert isinstance(boolean_expr.left.type, expected_type_affinity)
+
+    @testing.variation("operator", ["any", "all"])
+    @testing.variation("datatype", ["array", "arraystring", "arrayenum"])
+    def test_what_type_is_legacy_any_all(
+        self,
+        datatype: testing.Variation,
+        t_fixture,
+        operator: testing.Variation,
+    ):
+        if datatype.array:
+            col = t_fixture.c.arrval
+            value = 25
+            expected_type_affinity = Integer
+        elif datatype.arraystring:
+            col = t_fixture.c.arrstring
+            value = "a string"
+            expected_type_affinity = String
+        elif datatype.arrayenum:
+            col = t_fixture.c.arrenum
+            value = MyEnum.TWO
+            expected_type_affinity = Enum
+        else:
+            datatype.fail()
+
+        if operator.any:
+            boolean_expr = col.any(value)
+        elif operator.all:
+            boolean_expr = col.all(value)
+        else:
+            operator.fail()
+
+        # using isinstance so things work out for Enum which has type affinity
+        # of String
+        assert isinstance(boolean_expr.left.type, expected_type_affinity)
+
     @testing.fixture(
         params=[
             ("ANY", any_),